diff --git a/src-tauri/src/stats.rs b/src-tauri/src/stats.rs index 19eb422..2613226 100644 --- a/src-tauri/src/stats.rs +++ b/src-tauri/src/stats.rs @@ -154,6 +154,16 @@ pub struct UsageStats { // Achievement tracking #[serde(skip)] pub achievements: AchievementProgress, + + // Track current in-flight request for cost estimation on interrupt + #[serde(skip)] + pub current_request_input: Option, + #[serde(skip)] + pub current_request_output_chars: u64, + #[serde(skip)] + pub current_request_thinking_chars: u64, + #[serde(skip)] + pub current_request_tools: Vec, } impl UsageStats { @@ -452,6 +462,10 @@ impl UsageStats { potential_cache_hits: self.potential_cache_hits, potential_cache_savings_tokens: self.potential_cache_savings_tokens, achievements: AchievementProgress::new(), // Dummy for copy + current_request_input: None, // Don't copy tracking fields + current_request_output_chars: 0, + current_request_thinking_chars: 0, + current_request_tools: Vec::new(), }; check_achievements(&stats_copy, &mut self.achievements) } diff --git a/src-tauri/src/wsl_bridge.rs b/src-tauri/src/wsl_bridge.rs index 68019ba..1fa7102 100644 --- a/src-tauri/src/wsl_bridge.rs +++ b/src-tauri/src/wsl_bridge.rs @@ -355,6 +355,15 @@ impl WslBridge { pub fn send_message(&mut self, message: &str) -> Result<(), String> { let stdin = self.stdin.as_mut().ok_or("Process not running")?; + // Track input for cost estimation on interrupt + { + let mut stats = self.stats.write(); + stats.current_request_input = Some(message.to_string()); + stats.current_request_output_chars = 0; + stats.current_request_thinking_chars = 0; + stats.current_request_tools.clear(); + } + let input = serde_json::json!({ "type": "user", "message": { @@ -419,6 +428,9 @@ impl WslBridge { // we have to kill the process. This is the only reliable way to stop it. // See: https://github.com/anthropics/claude-code/issues/3455 if let Some(mut process) = self.process.take() { + // Estimate cost for interrupted request before killing + self.estimate_interrupted_request_cost(app); + // Kill the process immediately let _ = process.kill(); let _ = process.wait(); @@ -426,6 +438,15 @@ impl WslBridge { // Clear stdin self.stdin = None; + // Clear tracking fields + { + let mut stats = self.stats.write(); + stats.current_request_input = None; + stats.current_request_output_chars = 0; + stats.current_request_thinking_chars = 0; + stats.current_request_tools.clear(); + } + // Keep session_id and working directory for user reference // The user will see what session was interrupted @@ -442,6 +463,99 @@ impl WslBridge { } } + fn estimate_interrupted_request_cost(&mut self, app: &AppHandle) { + // Read tracking data from stats + let (input_chars, output_chars, thinking_chars, tools, model) = { + let stats = self.stats.read(); + + // Only estimate if we have tracked content + if stats.current_request_input.is_none() + && stats.current_request_output_chars == 0 + && stats.current_request_thinking_chars == 0 + && stats.current_request_tools.is_empty() { + return; + } + + let input_chars = stats.current_request_input.as_ref().map(|s| s.len() as u64).unwrap_or(0); + let model = stats.model.clone().unwrap_or_else(|| "claude-sonnet-4-5-20250929".to_string()); + + (input_chars, stats.current_request_output_chars, stats.current_request_thinking_chars, stats.current_request_tools.clone(), model) + }; + + println!("[COST ESTIMATION] Estimating cost for interrupted request"); + + // Use conservative 3.5 chars/token for estimation (vs standard 4) + let estimated_input_tokens = (input_chars as f64 / 3.5).ceil() as u64; + let estimated_output_tokens = ((output_chars as f64 / 3.5).ceil() as u64) + + ((thinking_chars as f64 / 3.5).ceil() as u64); + + // Add tool overhead based on session averages + let mut tool_overhead_tokens = 0u64; + { + let stats = self.stats.read(); + for tool_name in &tools { + if let Some(tool_stats) = stats.session_tools_usage.get(tool_name) { + if tool_stats.call_count > 0 { + // Use session average tokens per call for this tool + let avg_tokens = (tool_stats.estimated_input_tokens + tool_stats.estimated_output_tokens) + / tool_stats.call_count as u64; + tool_overhead_tokens += avg_tokens; + println!("[COST ESTIMATION] Tool {} average: {} tokens", tool_name, avg_tokens); + } + } + } + } + + let total_estimated_input = estimated_input_tokens + tool_overhead_tokens; + let total_estimated_output = estimated_output_tokens; + + // Add 20% safety margin to overestimate + let safety_margin = 1.2; + let conservative_input = (total_estimated_input as f64 * safety_margin).ceil() as u64; + let conservative_output = (total_estimated_output as f64 * safety_margin).ceil() as u64; + + println!("[COST ESTIMATION] Input: {} chars → {} tokens (+ {} tool overhead) × 1.2 safety = {} tokens", + input_chars, estimated_input_tokens, tool_overhead_tokens, conservative_input); + println!("[COST ESTIMATION] Output: {} chars → {} tokens × 1.2 safety = {} tokens", + output_chars + thinking_chars, + estimated_output_tokens, conservative_output); + + // Calculate cost (no cache tokens for interrupted requests) + let estimated_cost = calculate_cost( + conservative_input, + conservative_output, + &model, + None, + None, + ); + + println!("[COST ESTIMATION] Estimated cost: ${:.4} (conservative)", estimated_cost); + + // Add to stats with estimated flag + { + let mut stats_guard = self.stats.write(); + stats_guard.add_usage( + conservative_input, + conservative_output, + &model, + None, + None, + ); + } + + // Emit stats update + let stats_update_event = StatsUpdateEvent { + stats: self.stats.read().clone(), + }; + let _ = app.emit("claude:stats", stats_update_event); + + // Record to historical cost tracking (mark as estimated) + let app_clone = app.clone(); + tauri::async_runtime::spawn(async move { + record_cost(&app_clone, conservative_input, conservative_output, estimated_cost).await; + }); + } + pub fn stop(&mut self, app: &AppHandle) { if let Some(mut process) = self.process.take() { let _ = process.kill(); @@ -734,6 +848,14 @@ fn process_json_line( ClaudeMessage::StreamEvent { event } => { if event.event_type == "content_block_start" { if let Some(block) = &event.content_block { + // Track tool calls for cost estimation + if block.block_type == "tool_use" { + if let Some(name) = &block.name { + let mut stats_guard = stats.write(); + stats_guard.current_request_tools.push(name.clone()); + } + } + let state = match block.block_type.as_str() { "thinking" => CharacterState::Thinking, "text" => CharacterState::Typing, @@ -751,7 +873,16 @@ fn process_json_line( } else if event.event_type == "content_block_delta" { if let Some(delta) = &event.delta { if let Some(text) = &delta.text { + // Track output characters for cost estimation + { + let mut stats_guard = stats.write(); + stats_guard.current_request_output_chars += text.len() as u64; + } let _ = app.emit("claude:stream", text.clone()); + } else if let Some(thinking) = &delta.thinking { + // Track thinking characters for cost estimation + let mut stats_guard = stats.write(); + stats_guard.current_request_thinking_chars += thinking.len() as u64; } } } @@ -813,6 +944,15 @@ fn process_json_line( }); } + // Clear tracking fields since request completed successfully + { + let mut stats_guard = stats.write(); + stats_guard.current_request_input = None; + stats_guard.current_request_output_chars = 0; + stats_guard.current_request_thinking_chars = 0; + stats_guard.current_request_tools.clear(); + } + // Always emit updated stats on result message (less frequent) // This includes the latest session duration let newly_unlocked = { diff --git a/src/lib/commands/slashCommands.ts b/src/lib/commands/slashCommands.ts index 4df82a3..acd8034 100644 --- a/src/lib/commands/slashCommands.ts +++ b/src/lib/commands/slashCommands.ts @@ -131,8 +131,6 @@ async function startNewConversation(): Promise { }); // Update Discord RPC when starting new conversation - const config = configStore.getConfig(); - const activeConversation = get(conversationsStore.activeConversation); if (activeConversation) { await updateDiscordRpc( activeConversation.name, diff --git a/src/lib/components/InputBar.svelte b/src/lib/components/InputBar.svelte index 2d55fd6..4c3da2e 100644 --- a/src/lib/components/InputBar.svelte +++ b/src/lib/components/InputBar.svelte @@ -358,8 +358,6 @@ User: ${formattedMessage}`; }); // Update Discord RPC when reconnecting - const config = configStore.getConfig(); - const activeConversation = get(conversationsStore.activeConversation); if (activeConversation) { await updateDiscordRpc( activeConversation.name,