feat: estimate costs when interruption prevents getting cost results

This commit is contained in:
2026-02-06 10:05:15 -08:00
parent 4aa538db43
commit c471c7ad3d
4 changed files with 154 additions and 4 deletions
+14
View File
@@ -154,6 +154,16 @@ pub struct UsageStats {
// Achievement tracking // Achievement tracking
#[serde(skip)] #[serde(skip)]
pub achievements: AchievementProgress, pub achievements: AchievementProgress,
// Track current in-flight request for cost estimation on interrupt
#[serde(skip)]
pub current_request_input: Option<String>,
#[serde(skip)]
pub current_request_output_chars: u64,
#[serde(skip)]
pub current_request_thinking_chars: u64,
#[serde(skip)]
pub current_request_tools: Vec<String>,
} }
impl UsageStats { impl UsageStats {
@@ -452,6 +462,10 @@ impl UsageStats {
potential_cache_hits: self.potential_cache_hits, potential_cache_hits: self.potential_cache_hits,
potential_cache_savings_tokens: self.potential_cache_savings_tokens, potential_cache_savings_tokens: self.potential_cache_savings_tokens,
achievements: AchievementProgress::new(), // Dummy for copy achievements: AchievementProgress::new(), // Dummy for copy
current_request_input: None, // Don't copy tracking fields
current_request_output_chars: 0,
current_request_thinking_chars: 0,
current_request_tools: Vec::new(),
}; };
check_achievements(&stats_copy, &mut self.achievements) check_achievements(&stats_copy, &mut self.achievements)
} }
+140
View File
@@ -355,6 +355,15 @@ impl WslBridge {
pub fn send_message(&mut self, message: &str) -> Result<(), String> { pub fn send_message(&mut self, message: &str) -> Result<(), String> {
let stdin = self.stdin.as_mut().ok_or("Process not running")?; let stdin = self.stdin.as_mut().ok_or("Process not running")?;
// Track input for cost estimation on interrupt
{
let mut stats = self.stats.write();
stats.current_request_input = Some(message.to_string());
stats.current_request_output_chars = 0;
stats.current_request_thinking_chars = 0;
stats.current_request_tools.clear();
}
let input = serde_json::json!({ let input = serde_json::json!({
"type": "user", "type": "user",
"message": { "message": {
@@ -419,6 +428,9 @@ impl WslBridge {
// we have to kill the process. This is the only reliable way to stop it. // we have to kill the process. This is the only reliable way to stop it.
// See: https://github.com/anthropics/claude-code/issues/3455 // See: https://github.com/anthropics/claude-code/issues/3455
if let Some(mut process) = self.process.take() { if let Some(mut process) = self.process.take() {
// Estimate cost for interrupted request before killing
self.estimate_interrupted_request_cost(app);
// Kill the process immediately // Kill the process immediately
let _ = process.kill(); let _ = process.kill();
let _ = process.wait(); let _ = process.wait();
@@ -426,6 +438,15 @@ impl WslBridge {
// Clear stdin // Clear stdin
self.stdin = None; self.stdin = None;
// Clear tracking fields
{
let mut stats = self.stats.write();
stats.current_request_input = None;
stats.current_request_output_chars = 0;
stats.current_request_thinking_chars = 0;
stats.current_request_tools.clear();
}
// Keep session_id and working directory for user reference // Keep session_id and working directory for user reference
// The user will see what session was interrupted // The user will see what session was interrupted
@@ -442,6 +463,99 @@ impl WslBridge {
} }
} }
fn estimate_interrupted_request_cost(&mut self, app: &AppHandle) {
// Read tracking data from stats
let (input_chars, output_chars, thinking_chars, tools, model) = {
let stats = self.stats.read();
// Only estimate if we have tracked content
if stats.current_request_input.is_none()
&& stats.current_request_output_chars == 0
&& stats.current_request_thinking_chars == 0
&& stats.current_request_tools.is_empty() {
return;
}
let input_chars = stats.current_request_input.as_ref().map(|s| s.len() as u64).unwrap_or(0);
let model = stats.model.clone().unwrap_or_else(|| "claude-sonnet-4-5-20250929".to_string());
(input_chars, stats.current_request_output_chars, stats.current_request_thinking_chars, stats.current_request_tools.clone(), model)
};
println!("[COST ESTIMATION] Estimating cost for interrupted request");
// Use conservative 3.5 chars/token for estimation (vs standard 4)
let estimated_input_tokens = (input_chars as f64 / 3.5).ceil() as u64;
let estimated_output_tokens = ((output_chars as f64 / 3.5).ceil() as u64)
+ ((thinking_chars as f64 / 3.5).ceil() as u64);
// Add tool overhead based on session averages
let mut tool_overhead_tokens = 0u64;
{
let stats = self.stats.read();
for tool_name in &tools {
if let Some(tool_stats) = stats.session_tools_usage.get(tool_name) {
if tool_stats.call_count > 0 {
// Use session average tokens per call for this tool
let avg_tokens = (tool_stats.estimated_input_tokens + tool_stats.estimated_output_tokens)
/ tool_stats.call_count as u64;
tool_overhead_tokens += avg_tokens;
println!("[COST ESTIMATION] Tool {} average: {} tokens", tool_name, avg_tokens);
}
}
}
}
let total_estimated_input = estimated_input_tokens + tool_overhead_tokens;
let total_estimated_output = estimated_output_tokens;
// Add 20% safety margin to overestimate
let safety_margin = 1.2;
let conservative_input = (total_estimated_input as f64 * safety_margin).ceil() as u64;
let conservative_output = (total_estimated_output as f64 * safety_margin).ceil() as u64;
println!("[COST ESTIMATION] Input: {} chars → {} tokens (+ {} tool overhead) × 1.2 safety = {} tokens",
input_chars, estimated_input_tokens, tool_overhead_tokens, conservative_input);
println!("[COST ESTIMATION] Output: {} chars → {} tokens × 1.2 safety = {} tokens",
output_chars + thinking_chars,
estimated_output_tokens, conservative_output);
// Calculate cost (no cache tokens for interrupted requests)
let estimated_cost = calculate_cost(
conservative_input,
conservative_output,
&model,
None,
None,
);
println!("[COST ESTIMATION] Estimated cost: ${:.4} (conservative)", estimated_cost);
// Add to stats with estimated flag
{
let mut stats_guard = self.stats.write();
stats_guard.add_usage(
conservative_input,
conservative_output,
&model,
None,
None,
);
}
// Emit stats update
let stats_update_event = StatsUpdateEvent {
stats: self.stats.read().clone(),
};
let _ = app.emit("claude:stats", stats_update_event);
// Record to historical cost tracking (mark as estimated)
let app_clone = app.clone();
tauri::async_runtime::spawn(async move {
record_cost(&app_clone, conservative_input, conservative_output, estimated_cost).await;
});
}
pub fn stop(&mut self, app: &AppHandle) { pub fn stop(&mut self, app: &AppHandle) {
if let Some(mut process) = self.process.take() { if let Some(mut process) = self.process.take() {
let _ = process.kill(); let _ = process.kill();
@@ -734,6 +848,14 @@ fn process_json_line(
ClaudeMessage::StreamEvent { event } => { ClaudeMessage::StreamEvent { event } => {
if event.event_type == "content_block_start" { if event.event_type == "content_block_start" {
if let Some(block) = &event.content_block { if let Some(block) = &event.content_block {
// Track tool calls for cost estimation
if block.block_type == "tool_use" {
if let Some(name) = &block.name {
let mut stats_guard = stats.write();
stats_guard.current_request_tools.push(name.clone());
}
}
let state = match block.block_type.as_str() { let state = match block.block_type.as_str() {
"thinking" => CharacterState::Thinking, "thinking" => CharacterState::Thinking,
"text" => CharacterState::Typing, "text" => CharacterState::Typing,
@@ -751,7 +873,16 @@ fn process_json_line(
} else if event.event_type == "content_block_delta" { } else if event.event_type == "content_block_delta" {
if let Some(delta) = &event.delta { if let Some(delta) = &event.delta {
if let Some(text) = &delta.text { if let Some(text) = &delta.text {
// Track output characters for cost estimation
{
let mut stats_guard = stats.write();
stats_guard.current_request_output_chars += text.len() as u64;
}
let _ = app.emit("claude:stream", text.clone()); let _ = app.emit("claude:stream", text.clone());
} else if let Some(thinking) = &delta.thinking {
// Track thinking characters for cost estimation
let mut stats_guard = stats.write();
stats_guard.current_request_thinking_chars += thinking.len() as u64;
} }
} }
} }
@@ -813,6 +944,15 @@ fn process_json_line(
}); });
} }
// Clear tracking fields since request completed successfully
{
let mut stats_guard = stats.write();
stats_guard.current_request_input = None;
stats_guard.current_request_output_chars = 0;
stats_guard.current_request_thinking_chars = 0;
stats_guard.current_request_tools.clear();
}
// Always emit updated stats on result message (less frequent) // Always emit updated stats on result message (less frequent)
// This includes the latest session duration // This includes the latest session duration
let newly_unlocked = { let newly_unlocked = {
-2
View File
@@ -131,8 +131,6 @@ async function startNewConversation(): Promise<void> {
}); });
// Update Discord RPC when starting new conversation // Update Discord RPC when starting new conversation
const config = configStore.getConfig();
const activeConversation = get(conversationsStore.activeConversation);
if (activeConversation) { if (activeConversation) {
await updateDiscordRpc( await updateDiscordRpc(
activeConversation.name, activeConversation.name,
-2
View File
@@ -358,8 +358,6 @@ User: ${formattedMessage}`;
}); });
// Update Discord RPC when reconnecting // Update Discord RPC when reconnecting
const config = configStore.getConfig();
const activeConversation = get(conversationsStore.activeConversation);
if (activeConversation) { if (activeConversation) {
await updateDiscordRpc( await updateDiscordRpc(
activeConversation.name, activeConversation.name,