diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index fd68779..9e26d9b 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1636,7 +1636,7 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hikari-desktop" -version = "1.2.0" +version = "1.3.0" dependencies = [ "chrono", "dirs 5.0.1", diff --git a/src-tauri/src/achievements.rs b/src-tauri/src/achievements.rs index 16d24f2..f4991a4 100644 --- a/src-tauri/src/achievements.rs +++ b/src-tauri/src/achievements.rs @@ -2329,6 +2329,10 @@ mod tests { context_utilisation_percent: 0.0, potential_cache_hits: 0, potential_cache_savings_tokens: 0, + current_request_input: None, + current_request_output_chars: 0, + current_request_thinking_chars: 0, + current_request_tools: Vec::new(), achievements: AchievementProgress::new(), } } diff --git a/src-tauri/src/stats.rs b/src-tauri/src/stats.rs index 346d073..989170b 100644 --- a/src-tauri/src/stats.rs +++ b/src-tauri/src/stats.rs @@ -154,6 +154,16 @@ pub struct UsageStats { // Achievement tracking #[serde(skip)] pub achievements: AchievementProgress, + + // Track current in-flight request for cost estimation on interrupt + #[serde(skip)] + pub current_request_input: Option<String>, + #[serde(skip)] + pub current_request_output_chars: u64, + #[serde(skip)] + pub current_request_thinking_chars: u64, + #[serde(skip)] + pub current_request_tools: Vec<String>, } impl UsageStats { @@ -163,13 +173,26 @@ impl UsageStats { stats } - pub fn add_usage(&mut self, input_tokens: u64, output_tokens: u64, model: &str) { + pub fn add_usage( + &mut self, + input_tokens: u64, + output_tokens: u64, + model: &str, + cache_creation_tokens: Option<u64>, + cache_read_tokens: Option<u64>, + ) { self.total_input_tokens += input_tokens; self.total_output_tokens += output_tokens; self.session_input_tokens += input_tokens; self.session_output_tokens += output_tokens; - let cost = calculate_cost(input_tokens, output_tokens, model); + let cost = calculate_cost( + 
input_tokens, + output_tokens, + model, + cache_creation_tokens, + cache_read_tokens, + ); self.total_cost_usd += cost; self.session_cost_usd += cost; @@ -439,6 +462,10 @@ impl UsageStats { potential_cache_hits: self.potential_cache_hits, potential_cache_savings_tokens: self.potential_cache_savings_tokens, achievements: AchievementProgress::new(), // Dummy for copy + current_request_input: None, // Don't copy tracking fields + current_request_output_chars: 0, + current_request_thinking_chars: 0, + current_request_tools: Vec::new(), }; check_achievements(&stats_copy, &mut self.achievements) } @@ -462,7 +489,14 @@ fn is_consecutive_day(prev_date: &str, current_date: &str) -> bool { // Pricing as of February 2026 // https://platform.claude.com/docs/en/about-claude/models/overview -pub fn calculate_cost(input_tokens: u64, output_tokens: u64, model: &str) -> f64 { +// Cache pricing: https://platform.claude.com/docs/en/build-with-claude/prompt-caching +pub fn calculate_cost( + input_tokens: u64, + output_tokens: u64, + model: &str, + cache_creation_tokens: Option<u64>, + cache_read_tokens: Option<u64>, +) -> f64 { let (input_price_per_million, output_price_per_million) = match model { // Current generation (Claude 4.5) "claude-opus-4-5-20251101" => (5.0, 25.0), @@ -487,10 +521,25 @@ pub fn calculate_cost(input_tokens: u64, output_tokens: u64, model: &str) -> f64 { _ => (3.0, 15.0), }; + // Regular input/output tokens let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price_per_million; let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price_per_million; - input_cost + output_cost + // Cache write tokens (cache creation) cost 1.25x the base input price + let cache_write_cost = if let Some(cache_creation) = cache_creation_tokens { + (cache_creation as f64 / 1_000_000.0) * input_price_per_million * 1.25 + } else { + 0.0 + }; + + // Cache read tokens cost 0.1x (10%) the base input price + let cache_read_cost = if let Some(cache_read) = cache_read_tokens { + 
(cache_read as f64 / 1_000_000.0) * input_price_per_million * 0.1 + } else { + 0.0 + }; + + input_cost + output_cost + cache_write_cost + cache_read_cost } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -609,7 +658,7 @@ mod tests { #[test] fn test_cost_calculation_sonnet() { - let cost = calculate_cost(1000, 2000, "claude-sonnet-4-20250514"); + let cost = calculate_cost(1000, 2000, "claude-sonnet-4-20250514", None, None); // 1000 input * $3/M = $0.003 // 2000 output * $15/M = $0.030 // Total = $0.033 @@ -618,7 +667,7 @@ mod tests { #[test] fn test_cost_calculation_opus() { - let cost = calculate_cost(1000, 2000, "claude-opus-4-20250514"); + let cost = calculate_cost(1000, 2000, "claude-opus-4-20250514", None, None); // 1000 input * $15/M = $0.015 // 2000 output * $75/M = $0.150 // Total = $0.165 @@ -627,7 +676,7 @@ mod tests { #[test] fn test_cost_calculation_opus_45() { - let cost = calculate_cost(1000, 2000, "claude-opus-4-5-20251101"); + let cost = calculate_cost(1000, 2000, "claude-opus-4-5-20251101", None, None); // Opus 4.5 pricing: $5/MTok input, $25/MTok output // 1000 input tokens = $0.005, 2000 output tokens = $0.05 // Total = $0.055 @@ -636,7 +685,7 @@ mod tests { #[test] fn test_cost_calculation_haiku() { - let cost = calculate_cost(1000, 2000, "claude-3-5-haiku-20241022"); + let cost = calculate_cost(1000, 2000, "claude-3-5-haiku-20241022", None, None); // 1000 input * $1/M = $0.001 // 2000 output * $5/M = $0.010 // Total = $0.011 @@ -645,14 +694,14 @@ mod tests { #[test] fn test_cost_calculation_unknown_defaults_to_sonnet() { - let cost = calculate_cost(1000, 2000, "some-unknown-model"); + let cost = calculate_cost(1000, 2000, "some-unknown-model", None, None); // Should default to Sonnet pricing assert!((cost - 0.033).abs() < 0.0001); } #[test] fn test_cost_calculation_legacy_sonnet() { - let cost = calculate_cost(1000, 2000, "claude-3-5-sonnet-20241022"); + let cost = calculate_cost(1000, 2000, "claude-3-5-sonnet-20241022", None, None); // Same 
as Sonnet 4 pricing assert!((cost - 0.033).abs() < 0.0001); } @@ -660,7 +709,7 @@ mod tests { #[test] fn test_usage_stats_accumulation() { let mut stats = UsageStats::new(); - stats.add_usage(1000, 2000, "claude-sonnet-4-20250514"); + stats.add_usage(1000, 2000, "claude-sonnet-4-20250514", None, None); assert_eq!(stats.total_input_tokens, 1000); assert_eq!(stats.total_output_tokens, 2000); @@ -672,8 +721,8 @@ mod tests { #[test] fn test_usage_stats_multiple_accumulations() { let mut stats = UsageStats::new(); - stats.add_usage(1000, 1000, "claude-sonnet-4-20250514"); - stats.add_usage(500, 500, "claude-sonnet-4-20250514"); + stats.add_usage(1000, 1000, "claude-sonnet-4-20250514", None, None); + stats.add_usage(500, 500, "claude-sonnet-4-20250514", None, None); assert_eq!(stats.total_input_tokens, 1500); assert_eq!(stats.total_output_tokens, 1500); @@ -684,17 +733,17 @@ mod tests { #[test] fn test_usage_stats_model_updated() { let mut stats = UsageStats::new(); - stats.add_usage(1000, 1000, "claude-sonnet-4-20250514"); + stats.add_usage(1000, 1000, "claude-sonnet-4-20250514", None, None); assert_eq!(stats.model, Some("claude-sonnet-4-20250514".to_string())); - stats.add_usage(500, 500, "claude-opus-4-20250514"); + stats.add_usage(500, 500, "claude-opus-4-20250514", None, None); assert_eq!(stats.model, Some("claude-opus-4-20250514".to_string())); } #[test] fn test_session_reset() { let mut stats = UsageStats::new(); - stats.add_usage(1000, 2000, "claude-sonnet-4-20250514"); + stats.add_usage(1000, 2000, "claude-sonnet-4-20250514", None, None); stats.reset_session(); assert_eq!(stats.total_input_tokens, 1000); @@ -921,7 +970,7 @@ mod tests { #[test] fn test_usage_stats_serialization() { let mut stats = UsageStats::new(); - stats.add_usage(1000, 2000, "claude-sonnet-4-20250514"); + stats.add_usage(1000, 2000, "claude-sonnet-4-20250514", None, None); stats.increment_messages(); // UsageStats should be serializable (for events) @@ -950,7 +999,7 @@ mod tests { #[test] fn 
test_stats_update_event_serialization() { let mut stats = UsageStats::new(); - stats.add_usage(100, 200, "claude-sonnet-4-20250514"); + stats.add_usage(100, 200, "claude-sonnet-4-20250514", None, None); let event = StatsUpdateEvent { stats }; let json = serde_json::to_string(&event).expect("Failed to serialize"); @@ -1004,7 +1053,7 @@ mod tests { #[test] fn test_context_tracking_update() { let mut stats = UsageStats::new(); - stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514"); + stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514", None, None); assert_eq!(stats.context_tokens_used, 50_000); assert_eq!(stats.context_window_limit, 200_000); @@ -1014,8 +1063,8 @@ #[test] fn test_context_tracking_accumulates() { let mut stats = UsageStats::new(); - stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514"); - stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514"); + stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514", None, None); + stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514", None, None); assert_eq!(stats.context_tokens_used, 100_000); assert!((stats.context_utilisation_percent - 50.0).abs() < 0.1); @@ -1079,7 +1128,7 @@ mod tests { #[test] fn test_context_reset_on_session_reset() { let mut stats = UsageStats::new(); - stats.add_usage(100_000, 20_000, "claude-sonnet-4-20250514"); + stats.add_usage(100_000, 20_000, "claude-sonnet-4-20250514", None, None); assert!(stats.context_tokens_used > 0); assert!(stats.context_utilisation_percent > 0.0); diff --git a/src-tauri/src/types.rs b/src-tauri/src/types.rs index 3607d89..7560a63 100644 --- a/src-tauri/src/types.rs +++ b/src-tauri/src/types.rs @@ -4,6 +4,10 @@ use serde::{Deserialize, Serialize}; pub struct UsageInfo { pub input_tokens: u64, pub output_tokens: u64, + #[serde(default)] + pub cache_creation_input_tokens: Option<u64>, + #[serde(default)] + pub cache_read_input_tokens: Option<u64>, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] diff --git 
a/src-tauri/src/wsl_bridge.rs b/src-tauri/src/wsl_bridge.rs index 5f99b16..d925b09 100644 --- a/src-tauri/src/wsl_bridge.rs +++ b/src-tauri/src/wsl_bridge.rs @@ -355,6 +355,15 @@ impl WslBridge { pub fn send_message(&mut self, message: &str) -> Result<(), String> { let stdin = self.stdin.as_mut().ok_or("Process not running")?; + // Track input for cost estimation on interrupt + { + let mut stats = self.stats.write(); + stats.current_request_input = Some(message.to_string()); + stats.current_request_output_chars = 0; + stats.current_request_thinking_chars = 0; + stats.current_request_tools.clear(); + } + let input = serde_json::json!({ "type": "user", "message": { @@ -419,6 +428,9 @@ impl WslBridge { // we have to kill the process. This is the only reliable way to stop it. // See: https://github.com/anthropics/claude-code/issues/3455 if let Some(mut process) = self.process.take() { + // Estimate cost for interrupted request before killing + self.estimate_interrupted_request_cost(app); + // Kill the process immediately let _ = process.kill(); let _ = process.wait(); @@ -426,6 +438,15 @@ impl WslBridge { // Clear stdin self.stdin = None; + // Clear tracking fields + { + let mut stats = self.stats.write(); + stats.current_request_input = None; + stats.current_request_output_chars = 0; + stats.current_request_thinking_chars = 0; + stats.current_request_tools.clear(); + } + // Keep session_id and working directory for user reference // The user will see what session was interrupted @@ -442,6 +463,99 @@ impl WslBridge { } } + fn estimate_interrupted_request_cost(&mut self, app: &AppHandle) { + // Read tracking data from stats + let (input_chars, output_chars, thinking_chars, tools, model) = { + let stats = self.stats.read(); + + // Only estimate if we have tracked content + if stats.current_request_input.is_none() + && stats.current_request_output_chars == 0 + && stats.current_request_thinking_chars == 0 + && stats.current_request_tools.is_empty() { + return; + } + + let 
input_chars = stats.current_request_input.as_ref().map(|s| s.len() as u64).unwrap_or(0); + let model = stats.model.clone().unwrap_or_else(|| "claude-sonnet-4-5-20250929".to_string()); + + (input_chars, stats.current_request_output_chars, stats.current_request_thinking_chars, stats.current_request_tools.clone(), model) + }; + + println!("[COST ESTIMATION] Estimating cost for interrupted request"); + + // Use conservative 3.5 chars/token for estimation (vs standard 4) + let estimated_input_tokens = (input_chars as f64 / 3.5).ceil() as u64; + let estimated_output_tokens = ((output_chars as f64 / 3.5).ceil() as u64) + + ((thinking_chars as f64 / 3.5).ceil() as u64); + + // Add tool overhead based on session averages + let mut tool_overhead_tokens = 0u64; + { + let stats = self.stats.read(); + for tool_name in &tools { + if let Some(tool_stats) = stats.session_tools_usage.get(tool_name) { + if tool_stats.call_count > 0 { + // Use session average tokens per call for this tool + let avg_tokens = (tool_stats.estimated_input_tokens + tool_stats.estimated_output_tokens) + / tool_stats.call_count; + tool_overhead_tokens += avg_tokens; + println!("[COST ESTIMATION] Tool {} average: {} tokens", tool_name, avg_tokens); + } + } + } + } + + let total_estimated_input = estimated_input_tokens + tool_overhead_tokens; + let total_estimated_output = estimated_output_tokens; + + // Add 20% safety margin to overestimate + let safety_margin = 1.2; + let conservative_input = (total_estimated_input as f64 * safety_margin).ceil() as u64; + let conservative_output = (total_estimated_output as f64 * safety_margin).ceil() as u64; + + println!("[COST ESTIMATION] Input: {} chars → {} tokens (+ {} tool overhead) × 1.2 safety = {} tokens", + input_chars, estimated_input_tokens, tool_overhead_tokens, conservative_input); + println!("[COST ESTIMATION] Output: {} chars → {} tokens × 1.2 safety = {} tokens", + output_chars + thinking_chars, + estimated_output_tokens, conservative_output); + + // 
Calculate cost (no cache tokens for interrupted requests) + let estimated_cost = calculate_cost( + conservative_input, + conservative_output, + &model, + None, + None, + ); + + println!("[COST ESTIMATION] Estimated cost: ${:.4} (conservative)", estimated_cost); + + // Add to stats with estimated flag + { + let mut stats_guard = self.stats.write(); + stats_guard.add_usage( + conservative_input, + conservative_output, + &model, + None, + None, + ); + } + + // Emit stats update + let stats_update_event = StatsUpdateEvent { + stats: self.stats.read().clone(), + }; + let _ = app.emit("claude:stats", stats_update_event); + + // Record to historical cost tracking (mark as estimated) + let app_clone = app.clone(); + tauri::async_runtime::spawn(async move { + record_cost(&app_clone, conservative_input, conservative_output, estimated_cost).await; + }); + } + pub fn stop(&mut self, app: &AppHandle) { if let Some(mut process) = self.process.take() { let _ = process.kill(); @@ -603,48 +717,72 @@ fn process_json_line( // Only update stats if we have usage information if let Some(usage) = &message.usage { - if let Some(model) = &message.model { - // Calculate cost for historical tracking - let cost_usd = calculate_cost(usage.input_tokens, usage.output_tokens, model); + // Get model from message, or fall back to last known model from stats + let model = message.model.clone().or_else(|| { + let stats_guard = stats.read(); + stats_guard.model.clone() + }).unwrap_or_else(|| { + println!("[WARNING] No model info available for cost calculation, using default"); + "claude-sonnet-4-5-20250929".to_string() + }); - // Store cost for later use in output events - message_cost = Some(MessageCost { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - cost_usd, - }); + // Calculate cost for historical tracking (including cache tokens) + let cost_usd = calculate_cost( + usage.input_tokens, + usage.output_tokens, + &model, + usage.cache_creation_input_tokens, + 
usage.cache_read_input_tokens, + ); - // Batch all stats updates in a single write lock - { - let mut stats_guard = stats.write(); - stats_guard.increment_messages(); - stats_guard.add_usage(usage.input_tokens, usage.output_tokens, model); - stats_guard.get_session_duration(); + println!("Assistant message tokens - input: {}, output: {}, cache_creation: {:?}, cache_read: {:?}, cost: ${:.4}", + usage.input_tokens, + usage.output_tokens, + usage.cache_creation_input_tokens, + usage.cache_read_input_tokens, + cost_usd + ); - // Attribute tokens to tools if any tools were used in this message - if !tools_in_message.is_empty() { - let per_tool_input = usage.input_tokens / tools_in_message.len() as u64; - let per_tool_output = usage.output_tokens / tools_in_message.len() as u64; - for tool in &tools_in_message { - stats_guard.add_tool_tokens(tool, per_tool_input, per_tool_output); - } + // Store cost for later use in output events + message_cost = Some(MessageCost { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + cost_usd, + }); + + // Batch all stats updates in a single write lock + { + let mut stats_guard = stats.write(); + stats_guard.increment_messages(); + stats_guard.add_usage( + usage.input_tokens, + usage.output_tokens, + &model, + usage.cache_creation_input_tokens, + usage.cache_read_input_tokens, + ); + stats_guard.get_session_duration(); + + // Attribute tokens to tools if any tools were used in this message + if !tools_in_message.is_empty() { + let per_tool_input = usage.input_tokens / tools_in_message.len() as u64; + let per_tool_output = usage.output_tokens / tools_in_message.len() as u64; + for tool in &tools_in_message { + stats_guard.add_tool_tokens(tool, per_tool_input, per_tool_output); } } - - // Record to historical cost tracking - let app_clone = app.clone(); - let input = usage.input_tokens; - let output = usage.output_tokens; - tauri::async_runtime::spawn(async move { - record_cost(&app_clone, input, output, 
cost_usd).await; - }); - - // Don't emit here - we'll emit on Result message instead - // This reduces the frequency of updates - } else { - // Just increment message count if no usage info - stats.write().increment_messages(); } + + // Record to historical cost tracking + let app_clone = app.clone(); + let input = usage.input_tokens; + let output = usage.output_tokens; + tauri::async_runtime::spawn(async move { + record_cost(&app_clone, input, output, cost_usd).await; + }); + + // Don't emit here - we'll emit on Result message instead + // This reduces the frequency of updates } else { // Just increment message count if no usage info stats.write().increment_messages(); @@ -722,6 +860,14 @@ fn process_json_line( ClaudeMessage::StreamEvent { event } => { if event.event_type == "content_block_start" { if let Some(block) = &event.content_block { + // Track tool calls for cost estimation + if block.block_type == "tool_use" { + if let Some(name) = &block.name { + let mut stats_guard = stats.write(); + stats_guard.current_request_tools.push(name.clone()); + } + } + let state = match block.block_type.as_str() { "thinking" => CharacterState::Thinking, "text" => CharacterState::Typing, @@ -739,7 +885,16 @@ fn process_json_line( } else if event.event_type == "content_block_delta" { if let Some(delta) = &event.delta { if let Some(text) = &delta.text { + // Track output characters for cost estimation + { + let mut stats_guard = stats.write(); + stats_guard.current_request_output_chars += text.len() as u64; + } let _ = app.emit("claude:stream", text.clone()); + } else if let Some(thinking) = &delta.thinking { + // Track thinking characters for cost estimation + let mut stats_guard = stats.write(); + stats_guard.current_request_thinking_chars += thinking.len() as u64; } } } @@ -768,12 +923,29 @@ fn process_json_line( stats_guard.model.clone().unwrap_or_else(|| "claude-opus-4-20250514".to_string()) }; - // Calculate cost for historical tracking - let cost_usd = 
calculate_cost(usage_info.input_tokens, usage_info.output_tokens, &model); + // Calculate cost for historical tracking (including cache tokens) + let cost_usd = calculate_cost( + usage_info.input_tokens, + usage_info.output_tokens, + &model, + usage_info.cache_creation_input_tokens, + usage_info.cache_read_input_tokens, + ); let mut stats_guard = stats.write(); - stats_guard.add_usage(usage_info.input_tokens, usage_info.output_tokens, &model); - println!("Result message tokens - input: {}, output: {}", usage_info.input_tokens, usage_info.output_tokens); + stats_guard.add_usage( + usage_info.input_tokens, + usage_info.output_tokens, + &model, + usage_info.cache_creation_input_tokens, + usage_info.cache_read_input_tokens, + ); + println!("Result message tokens - input: {}, output: {}, cache_creation: {:?}, cache_read: {:?}", + usage_info.input_tokens, + usage_info.output_tokens, + usage_info.cache_creation_input_tokens, + usage_info.cache_read_input_tokens + ); // Record to historical cost tracking let app_clone = app.clone(); @@ -784,6 +956,15 @@ fn process_json_line( }); } + // Clear tracking fields since request completed successfully + { + let mut stats_guard = stats.write(); + stats_guard.current_request_input = None; + stats_guard.current_request_output_chars = 0; + stats_guard.current_request_thinking_chars = 0; + stats_guard.current_request_tools.clear(); + } + // Always emit updated stats on result message (less frequent) // This includes the latest session duration let newly_unlocked = { diff --git a/src/lib/commands/slashCommands.ts b/src/lib/commands/slashCommands.ts index 90271a0..acd8034 100644 --- a/src/lib/commands/slashCommands.ts +++ b/src/lib/commands/slashCommands.ts @@ -37,6 +37,12 @@ async function changeDirectory(path: string): Promise { // Capture conversation history before disconnecting const conversationHistory = claudeStore.getConversationHistory(); + // Get currently granted tools and config auto-granted tools + const activeConversation = 
get(conversationsStore.activeConversation); + const grantedTools = activeConversation ? Array.from(activeConversation.grantedTools) : []; + const config = configStore.getConfig(); + const allAllowedTools = [...new Set([...grantedTools, ...config.auto_granted_tools])]; + await invoke("stop_claude", { conversationId }); // Wait for clean shutdown @@ -50,12 +56,11 @@ async function changeDirectory(path: string): Promise { conversationId, options: { working_dir: validatedPath, + allowed_tools: allAllowedTools, }, }); // Update Discord RPC when reconnecting after directory change - const config = configStore.getConfig(); - const activeConversation = get(conversationsStore.activeConversation); if (activeConversation) { await updateDiscordRpc( activeConversation.name, @@ -102,6 +107,12 @@ async function startNewConversation(): Promise { conversationId, }); + // Get granted tools before interrupting + const activeConversation = get(conversationsStore.activeConversation); + const grantedTools = activeConversation ? 
Array.from(activeConversation.grantedTools) : []; + const config = configStore.getConfig(); + const allAllowedTools = [...new Set([...grantedTools, ...config.auto_granted_tools])]; + claudeStore.addLine("system", "Starting new conversation..."); characterState.setState("thinking"); @@ -115,12 +126,11 @@ async function startNewConversation(): Promise { conversationId, options: { working_dir: workingDir, + allowed_tools: allAllowedTools, }, }); // Update Discord RPC when starting new conversation - const config = configStore.getConfig(); - const activeConversation = get(conversationsStore.activeConversation); if (activeConversation) { await updateDiscordRpc( activeConversation.name, diff --git a/src/lib/components/CompactMode.svelte b/src/lib/components/CompactMode.svelte index 5dbbbd1..ca172e8 100644 --- a/src/lib/components/CompactMode.svelte +++ b/src/lib/components/CompactMode.svelte @@ -5,6 +5,7 @@ import { characterState, characterInfo } from "$lib/stores/character"; import { isStreamerMode } from "$lib/stores/config"; import { handleNewUserMessage } from "$lib/notifications/rules"; + import { setSkipNextGreeting } from "$lib/tauri"; import type { CharacterState, CharacterStateInfo } from "$lib/types/states"; interface Props { @@ -127,6 +128,9 @@ const conversationId = get(claudeStore.activeConversationId); if (!conversationId) return; + // Set flag to preserve stats/permissions (don't treat next connect as new session) + setSkipNextGreeting(true); + await invoke("interrupt_claude", { conversationId }); claudeStore.addLine("system", "Interrupted"); characterState.setState("idle"); diff --git a/src/lib/components/InputBar.svelte b/src/lib/components/InputBar.svelte index bbb09ac..e50bc79 100644 --- a/src/lib/components/InputBar.svelte +++ b/src/lib/components/InputBar.svelte @@ -338,23 +338,28 @@ User: ${formattedMessage}`; throw new Error("No active conversation"); } - // Get current working directory before reconnecting + // Get current working directory and 
granted tools before reconnecting const workingDir = await invoke("get_working_directory", { conversationId }); + const activeConversation = get(conversationsStore.activeConversation); + const grantedTools = activeConversation + ? Array.from(activeConversation.grantedTools) + : []; + const config = configStore.getConfig(); + const allAllowedTools = [...new Set([...grantedTools, ...config.auto_granted_tools])]; // Set the flag to skip greeting on next connection setSkipNextGreeting(true); - // Reconnect to Claude + // Reconnect to Claude with preserved permissions await invoke("start_claude", { conversationId, options: { working_dir: workingDir, + allowed_tools: allAllowedTools, }, }); // Update Discord RPC when reconnecting - const config = configStore.getConfig(); - const activeConversation = get(conversationsStore.activeConversation); if (activeConversation) { await updateDiscordRpc( activeConversation.name, diff --git a/src/lib/components/PermissionModal.svelte b/src/lib/components/PermissionModal.svelte index e96afd5..c41ffe8 100644 --- a/src/lib/components/PermissionModal.svelte +++ b/src/lib/components/PermissionModal.svelte @@ -1,22 +1,17 @@