generated from nhcarrigan/template
fix: ensure permission/stats persist until explicit disconnect (#110)
Also includes cached tokens in cost calculations to provide more accurate billing estimates. Reviewed-on: #110 Co-authored-by: Naomi Carrigan <commits@nhcarrigan.com> Co-committed-by: Naomi Carrigan <commits@nhcarrigan.com>
This commit was merged in pull request #110.
This commit is contained in:
+221
-40
@@ -355,6 +355,15 @@ impl WslBridge {
|
||||
pub fn send_message(&mut self, message: &str) -> Result<(), String> {
|
||||
let stdin = self.stdin.as_mut().ok_or("Process not running")?;
|
||||
|
||||
// Track input for cost estimation on interrupt
|
||||
{
|
||||
let mut stats = self.stats.write();
|
||||
stats.current_request_input = Some(message.to_string());
|
||||
stats.current_request_output_chars = 0;
|
||||
stats.current_request_thinking_chars = 0;
|
||||
stats.current_request_tools.clear();
|
||||
}
|
||||
|
||||
let input = serde_json::json!({
|
||||
"type": "user",
|
||||
"message": {
|
||||
@@ -419,6 +428,9 @@ impl WslBridge {
|
||||
// we have to kill the process. This is the only reliable way to stop it.
|
||||
// See: https://github.com/anthropics/claude-code/issues/3455
|
||||
if let Some(mut process) = self.process.take() {
|
||||
// Estimate cost for interrupted request before killing
|
||||
self.estimate_interrupted_request_cost(app);
|
||||
|
||||
// Kill the process immediately
|
||||
let _ = process.kill();
|
||||
let _ = process.wait();
|
||||
@@ -426,6 +438,15 @@ impl WslBridge {
|
||||
// Clear stdin
|
||||
self.stdin = None;
|
||||
|
||||
// Clear tracking fields
|
||||
{
|
||||
let mut stats = self.stats.write();
|
||||
stats.current_request_input = None;
|
||||
stats.current_request_output_chars = 0;
|
||||
stats.current_request_thinking_chars = 0;
|
||||
stats.current_request_tools.clear();
|
||||
}
|
||||
|
||||
// Keep session_id and working directory for user reference
|
||||
// The user will see what session was interrupted
|
||||
|
||||
@@ -442,6 +463,99 @@ impl WslBridge {
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_interrupted_request_cost(&mut self, app: &AppHandle) {
|
||||
// Read tracking data from stats
|
||||
let (input_chars, output_chars, thinking_chars, tools, model) = {
|
||||
let stats = self.stats.read();
|
||||
|
||||
// Only estimate if we have tracked content
|
||||
if stats.current_request_input.is_none()
|
||||
&& stats.current_request_output_chars == 0
|
||||
&& stats.current_request_thinking_chars == 0
|
||||
&& stats.current_request_tools.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let input_chars = stats.current_request_input.as_ref().map(|s| s.len() as u64).unwrap_or(0);
|
||||
let model = stats.model.clone().unwrap_or_else(|| "claude-sonnet-4-5-20250929".to_string());
|
||||
|
||||
(input_chars, stats.current_request_output_chars, stats.current_request_thinking_chars, stats.current_request_tools.clone(), model)
|
||||
};
|
||||
|
||||
println!("[COST ESTIMATION] Estimating cost for interrupted request");
|
||||
|
||||
// Use conservative 3.5 chars/token for estimation (vs standard 4)
|
||||
let estimated_input_tokens = (input_chars as f64 / 3.5).ceil() as u64;
|
||||
let estimated_output_tokens = ((output_chars as f64 / 3.5).ceil() as u64)
|
||||
+ ((thinking_chars as f64 / 3.5).ceil() as u64);
|
||||
|
||||
// Add tool overhead based on session averages
|
||||
let mut tool_overhead_tokens = 0u64;
|
||||
{
|
||||
let stats = self.stats.read();
|
||||
for tool_name in &tools {
|
||||
if let Some(tool_stats) = stats.session_tools_usage.get(tool_name) {
|
||||
if tool_stats.call_count > 0 {
|
||||
// Use session average tokens per call for this tool
|
||||
let avg_tokens = (tool_stats.estimated_input_tokens + tool_stats.estimated_output_tokens)
|
||||
/ tool_stats.call_count;
|
||||
tool_overhead_tokens += avg_tokens;
|
||||
println!("[COST ESTIMATION] Tool {} average: {} tokens", tool_name, avg_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_estimated_input = estimated_input_tokens + tool_overhead_tokens;
|
||||
let total_estimated_output = estimated_output_tokens;
|
||||
|
||||
// Add 20% safety margin to overestimate
|
||||
let safety_margin = 1.2;
|
||||
let conservative_input = (total_estimated_input as f64 * safety_margin).ceil() as u64;
|
||||
let conservative_output = (total_estimated_output as f64 * safety_margin).ceil() as u64;
|
||||
|
||||
println!("[COST ESTIMATION] Input: {} chars → {} tokens (+ {} tool overhead) × 1.2 safety = {} tokens",
|
||||
input_chars, estimated_input_tokens, tool_overhead_tokens, conservative_input);
|
||||
println!("[COST ESTIMATION] Output: {} chars → {} tokens × 1.2 safety = {} tokens",
|
||||
output_chars + thinking_chars,
|
||||
estimated_output_tokens, conservative_output);
|
||||
|
||||
// Calculate cost (no cache tokens for interrupted requests)
|
||||
let estimated_cost = calculate_cost(
|
||||
conservative_input,
|
||||
conservative_output,
|
||||
&model,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
println!("[COST ESTIMATION] Estimated cost: ${:.4} (conservative)", estimated_cost);
|
||||
|
||||
// Add to stats with estimated flag
|
||||
{
|
||||
let mut stats_guard = self.stats.write();
|
||||
stats_guard.add_usage(
|
||||
conservative_input,
|
||||
conservative_output,
|
||||
&model,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
// Emit stats update
|
||||
let stats_update_event = StatsUpdateEvent {
|
||||
stats: self.stats.read().clone(),
|
||||
};
|
||||
let _ = app.emit("claude:stats", stats_update_event);
|
||||
|
||||
// Record to historical cost tracking (mark as estimated)
|
||||
let app_clone = app.clone();
|
||||
tauri::async_runtime::spawn(async move {
|
||||
record_cost(&app_clone, conservative_input, conservative_output, estimated_cost).await;
|
||||
});
|
||||
}
|
||||
|
||||
pub fn stop(&mut self, app: &AppHandle) {
|
||||
if let Some(mut process) = self.process.take() {
|
||||
let _ = process.kill();
|
||||
@@ -603,48 +717,72 @@ fn process_json_line(
|
||||
|
||||
// Only update stats if we have usage information
|
||||
if let Some(usage) = &message.usage {
|
||||
if let Some(model) = &message.model {
|
||||
// Calculate cost for historical tracking
|
||||
let cost_usd = calculate_cost(usage.input_tokens, usage.output_tokens, model);
|
||||
// Get model from message, or fall back to last known model from stats
|
||||
let model = message.model.clone().or_else(|| {
|
||||
let stats_guard = stats.read();
|
||||
stats_guard.model.clone()
|
||||
}).unwrap_or_else(|| {
|
||||
println!("[WARNING] No model info available for cost calculation, using default");
|
||||
"claude-sonnet-4-5-20250929".to_string()
|
||||
});
|
||||
|
||||
// Store cost for later use in output events
|
||||
message_cost = Some(MessageCost {
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
cost_usd,
|
||||
});
|
||||
// Calculate cost for historical tracking (including cache tokens)
|
||||
let cost_usd = calculate_cost(
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
&model,
|
||||
usage.cache_creation_input_tokens,
|
||||
usage.cache_read_input_tokens,
|
||||
);
|
||||
|
||||
// Batch all stats updates in a single write lock
|
||||
{
|
||||
let mut stats_guard = stats.write();
|
||||
stats_guard.increment_messages();
|
||||
stats_guard.add_usage(usage.input_tokens, usage.output_tokens, model);
|
||||
stats_guard.get_session_duration();
|
||||
println!("Assistant message tokens - input: {}, output: {}, cache_creation: {:?}, cache_read: {:?}, cost: ${:.4}",
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
usage.cache_creation_input_tokens,
|
||||
usage.cache_read_input_tokens,
|
||||
cost_usd
|
||||
);
|
||||
|
||||
// Attribute tokens to tools if any tools were used in this message
|
||||
if !tools_in_message.is_empty() {
|
||||
let per_tool_input = usage.input_tokens / tools_in_message.len() as u64;
|
||||
let per_tool_output = usage.output_tokens / tools_in_message.len() as u64;
|
||||
for tool in &tools_in_message {
|
||||
stats_guard.add_tool_tokens(tool, per_tool_input, per_tool_output);
|
||||
}
|
||||
// Store cost for later use in output events
|
||||
message_cost = Some(MessageCost {
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
cost_usd,
|
||||
});
|
||||
|
||||
// Batch all stats updates in a single write lock
|
||||
{
|
||||
let mut stats_guard = stats.write();
|
||||
stats_guard.increment_messages();
|
||||
stats_guard.add_usage(
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
&model,
|
||||
usage.cache_creation_input_tokens,
|
||||
usage.cache_read_input_tokens,
|
||||
);
|
||||
stats_guard.get_session_duration();
|
||||
|
||||
// Attribute tokens to tools if any tools were used in this message
|
||||
if !tools_in_message.is_empty() {
|
||||
let per_tool_input = usage.input_tokens / tools_in_message.len() as u64;
|
||||
let per_tool_output = usage.output_tokens / tools_in_message.len() as u64;
|
||||
for tool in &tools_in_message {
|
||||
stats_guard.add_tool_tokens(tool, per_tool_input, per_tool_output);
|
||||
}
|
||||
}
|
||||
|
||||
// Record to historical cost tracking
|
||||
let app_clone = app.clone();
|
||||
let input = usage.input_tokens;
|
||||
let output = usage.output_tokens;
|
||||
tauri::async_runtime::spawn(async move {
|
||||
record_cost(&app_clone, input, output, cost_usd).await;
|
||||
});
|
||||
|
||||
// Don't emit here - we'll emit on Result message instead
|
||||
// This reduces the frequency of updates
|
||||
} else {
|
||||
// Just increment message count if no usage info
|
||||
stats.write().increment_messages();
|
||||
}
|
||||
|
||||
// Record to historical cost tracking
|
||||
let app_clone = app.clone();
|
||||
let input = usage.input_tokens;
|
||||
let output = usage.output_tokens;
|
||||
tauri::async_runtime::spawn(async move {
|
||||
record_cost(&app_clone, input, output, cost_usd).await;
|
||||
});
|
||||
|
||||
// Don't emit here - we'll emit on Result message instead
|
||||
// This reduces the frequency of updates
|
||||
} else {
|
||||
// Just increment message count if no usage info
|
||||
stats.write().increment_messages();
|
||||
@@ -722,6 +860,14 @@ fn process_json_line(
|
||||
ClaudeMessage::StreamEvent { event } => {
|
||||
if event.event_type == "content_block_start" {
|
||||
if let Some(block) = &event.content_block {
|
||||
// Track tool calls for cost estimation
|
||||
if block.block_type == "tool_use" {
|
||||
if let Some(name) = &block.name {
|
||||
let mut stats_guard = stats.write();
|
||||
stats_guard.current_request_tools.push(name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let state = match block.block_type.as_str() {
|
||||
"thinking" => CharacterState::Thinking,
|
||||
"text" => CharacterState::Typing,
|
||||
@@ -739,7 +885,16 @@ fn process_json_line(
|
||||
} else if event.event_type == "content_block_delta" {
|
||||
if let Some(delta) = &event.delta {
|
||||
if let Some(text) = &delta.text {
|
||||
// Track output characters for cost estimation
|
||||
{
|
||||
let mut stats_guard = stats.write();
|
||||
stats_guard.current_request_output_chars += text.len() as u64;
|
||||
}
|
||||
let _ = app.emit("claude:stream", text.clone());
|
||||
} else if let Some(thinking) = &delta.thinking {
|
||||
// Track thinking characters for cost estimation
|
||||
let mut stats_guard = stats.write();
|
||||
stats_guard.current_request_thinking_chars += thinking.len() as u64;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -768,12 +923,29 @@ fn process_json_line(
|
||||
stats_guard.model.clone().unwrap_or_else(|| "claude-opus-4-20250514".to_string())
|
||||
};
|
||||
|
||||
// Calculate cost for historical tracking
|
||||
let cost_usd = calculate_cost(usage_info.input_tokens, usage_info.output_tokens, &model);
|
||||
// Calculate cost for historical tracking (including cache tokens)
|
||||
let cost_usd = calculate_cost(
|
||||
usage_info.input_tokens,
|
||||
usage_info.output_tokens,
|
||||
&model,
|
||||
usage_info.cache_creation_input_tokens,
|
||||
usage_info.cache_read_input_tokens,
|
||||
);
|
||||
|
||||
let mut stats_guard = stats.write();
|
||||
stats_guard.add_usage(usage_info.input_tokens, usage_info.output_tokens, &model);
|
||||
println!("Result message tokens - input: {}, output: {}", usage_info.input_tokens, usage_info.output_tokens);
|
||||
stats_guard.add_usage(
|
||||
usage_info.input_tokens,
|
||||
usage_info.output_tokens,
|
||||
&model,
|
||||
usage_info.cache_creation_input_tokens,
|
||||
usage_info.cache_read_input_tokens,
|
||||
);
|
||||
println!("Result message tokens - input: {}, output: {}, cache_creation: {:?}, cache_read: {:?}",
|
||||
usage_info.input_tokens,
|
||||
usage_info.output_tokens,
|
||||
usage_info.cache_creation_input_tokens,
|
||||
usage_info.cache_read_input_tokens
|
||||
);
|
||||
|
||||
// Record to historical cost tracking
|
||||
let app_clone = app.clone();
|
||||
@@ -784,6 +956,15 @@ fn process_json_line(
|
||||
});
|
||||
}
|
||||
|
||||
// Clear tracking fields since request completed successfully
|
||||
{
|
||||
let mut stats_guard = stats.write();
|
||||
stats_guard.current_request_input = None;
|
||||
stats_guard.current_request_output_chars = 0;
|
||||
stats_guard.current_request_thinking_chars = 0;
|
||||
stats_guard.current_request_tools.clear();
|
||||
}
|
||||
|
||||
// Always emit updated stats on result message (less frequent)
|
||||
// This includes the latest session duration
|
||||
let newly_unlocked = {
|
||||
|
||||
Reference in New Issue
Block a user