diff --git a/src-tauri/src/achievements.rs b/src-tauri/src/achievements.rs index f19456a..16d24f2 100644 --- a/src-tauri/src/achievements.rs +++ b/src-tauri/src/achievements.rs @@ -1935,6 +1935,7 @@ pub fn check_achievements( let search_count: u64 = search_tools .iter() .filter_map(|tool| stats.tools_usage.get(*tool)) + .map(|t| t.call_count) .sum(); if search_count >= 50 && progress.unlock(AchievementId::Explorer) { newly_unlocked.push(AchievementId::Explorer); @@ -1988,25 +1989,25 @@ pub fn check_achievements( // TODO: Track different Claude models used // Tool mastery achievements - if let Some(bash_count) = stats.tools_usage.get("Bash") { - if *bash_count >= 50 && progress.unlock(AchievementId::BashMaster) { + if let Some(bash_stats) = stats.tools_usage.get("Bash") { + if bash_stats.call_count >= 50 && progress.unlock(AchievementId::BashMaster) { newly_unlocked.push(AchievementId::BashMaster); } } - if let Some(read_count) = stats.tools_usage.get("Read") { - if *read_count >= 100 && progress.unlock(AchievementId::FileExplorer) { + if let Some(read_stats) = stats.tools_usage.get("Read") { + if read_stats.call_count >= 100 && progress.unlock(AchievementId::FileExplorer) { newly_unlocked.push(AchievementId::FileExplorer); } } - if let Some(grep_count) = stats.tools_usage.get("Grep") { - if *grep_count >= 50 && progress.unlock(AchievementId::SearchExpert) { + if let Some(grep_stats) = stats.tools_usage.get("Grep") { + if grep_stats.call_count >= 50 && progress.unlock(AchievementId::SearchExpert) { newly_unlocked.push(AchievementId::SearchExpert); } } // Git Guru - check git command usage in Bash - if let Some(bash_count) = stats.tools_usage.get("Bash") { - if *bash_count >= 10 && progress.unlock(AchievementId::GitGuru) { + if let Some(bash_stats) = stats.tools_usage.get("Bash") { + if bash_stats.call_count >= 10 && progress.unlock(AchievementId::GitGuru) { // TODO: More specific git command tracking newly_unlocked.push(AchievementId::GitGuru); } @@ -2055,28 +2056,28 @@ pub fn check_achievements( } // More tool mastery achievements - if let Some(edit_count) = stats.tools_usage.get("Edit") { - if *edit_count >= 100 && progress.unlock(AchievementId::EditMaster) { + if let Some(edit_stats) = stats.tools_usage.get("Edit") { + if edit_stats.call_count >= 100 && progress.unlock(AchievementId::EditMaster) { newly_unlocked.push(AchievementId::EditMaster); } } - if let Some(write_count) = stats.tools_usage.get("Write") { - if *write_count >= 50 && progress.unlock(AchievementId::WriteMaster) { + if let Some(write_stats) = stats.tools_usage.get("Write") { + if write_stats.call_count >= 50 && progress.unlock(AchievementId::WriteMaster) { newly_unlocked.push(AchievementId::WriteMaster); } } - if let Some(glob_count) = stats.tools_usage.get("Glob") { - if *glob_count >= 100 && progress.unlock(AchievementId::GlobMaster) { + if let Some(glob_stats) = stats.tools_usage.get("Glob") { + if glob_stats.call_count >= 100 && progress.unlock(AchievementId::GlobMaster) { newly_unlocked.push(AchievementId::GlobMaster); } } - if let Some(task_count) = stats.tools_usage.get("Task") { - if *task_count >= 50 && progress.unlock(AchievementId::TaskMaster) { + if let Some(task_stats) = stats.tools_usage.get("Task") { + if task_stats.call_count >= 50 && progress.unlock(AchievementId::TaskMaster) { newly_unlocked.push(AchievementId::TaskMaster); } } - if let Some(web_count) = stats.tools_usage.get("WebFetch") { - if *web_count >= 20 && progress.unlock(AchievementId::WebFetcher) { + if let Some(web_stats) = stats.tools_usage.get("WebFetch") { + if web_stats.call_count >= 20 && progress.unlock(AchievementId::WebFetcher) { newly_unlocked.push(AchievementId::WebFetcher); } } @@ -2085,7 +2086,7 @@ pub fn check_achievements( .tools_usage .iter() .filter(|(name, _)| name.starts_with("mcp__")) - .map(|(_, count)| count) + .map(|(_, tool_stats)| tool_stats.call_count) .sum(); if mcp_count >= 50 && progress.unlock(AchievementId::McpExplorer) { newly_unlocked.push(AchievementId::McpExplorer); @@ -2323,6 +2324,11 @@ mod tests { morning_sessions: 0, night_sessions: 0, last_session_date: None, + context_tokens_used: 0, + context_window_limit: 200_000, + context_utilisation_percent: 0.0, + potential_cache_hits: 0, + potential_cache_savings_tokens: 0, achievements: AchievementProgress::new(), } } @@ -2733,12 +2739,21 @@ mod tests { // check_achievements tests - Tool Usage // ===================== + // Helper function to create a ToolTokenStats with just call count for tests + fn tool_stats(call_count: u64) -> crate::stats::ToolTokenStats { + crate::stats::ToolTokenStats { + call_count, + estimated_input_tokens: 0, + estimated_output_tokens: 0, + } + } + #[test] fn test_check_achievements_first_tool() { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Read".to_string(), 1); + stats.tools_usage.insert("Read".to_string(), tool_stats(1)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::FirstTool)); @@ -2749,11 +2764,11 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Read".to_string(), 1); - stats.tools_usage.insert("Write".to_string(), 1); - stats.tools_usage.insert("Edit".to_string(), 1); - stats.tools_usage.insert("Bash".to_string(), 1); - stats.tools_usage.insert("Grep".to_string(), 1); + stats.tools_usage.insert("Read".to_string(), tool_stats(1)); + stats.tools_usage.insert("Write".to_string(), tool_stats(1)); + stats.tools_usage.insert("Edit".to_string(), tool_stats(1)); + stats.tools_usage.insert("Bash".to_string(), tool_stats(1)); + stats.tools_usage.insert("Grep".to_string(), tool_stats(1)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::Toolsmith)); @@ -2765,7 +2780,7 @@ mod tests { let mut progress = AchievementProgress::new(); for i in 0..10 { - stats.tools_usage.insert(format!("Tool{}", i), 1); + stats.tools_usage.insert(format!("Tool{}", i), tool_stats(1)); } let newly = check_achievements(&stats, &mut progress); @@ -2777,7 +2792,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Bash".to_string(), 50); + stats.tools_usage.insert("Bash".to_string(), tool_stats(50)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::BashMaster)); @@ -2788,7 +2803,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Read".to_string(), 100); + stats.tools_usage.insert("Read".to_string(), tool_stats(100)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::FileExplorer)); @@ -2799,7 +2814,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Grep".to_string(), 50); + stats.tools_usage.insert("Grep".to_string(), tool_stats(50)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::SearchExpert)); @@ -2810,7 +2825,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Edit".to_string(), 100); + stats.tools_usage.insert("Edit".to_string(), tool_stats(100)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::EditMaster)); @@ -2821,7 +2836,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Write".to_string(), 50); + stats.tools_usage.insert("Write".to_string(), tool_stats(50)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::WriteMaster)); @@ -2832,7 +2847,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Glob".to_string(), 100); + stats.tools_usage.insert("Glob".to_string(), tool_stats(100)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::GlobMaster)); @@ -2843,7 +2858,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Task".to_string(), 50); + stats.tools_usage.insert("Task".to_string(), tool_stats(50)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::TaskMaster)); @@ -2854,7 +2869,7 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("WebFetch".to_string(), 20); + stats.tools_usage.insert("WebFetch".to_string(), tool_stats(20)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::WebFetcher)); @@ -2865,8 +2880,8 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("mcp__github__create_issue".to_string(), 25); - stats.tools_usage.insert("mcp__notion__search".to_string(), 25); + stats.tools_usage.insert("mcp__github__create_issue".to_string(), tool_stats(25)); + stats.tools_usage.insert("mcp__notion__search".to_string(), tool_stats(25)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::McpExplorer)); @@ -2881,8 +2896,8 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Grep".to_string(), 30); - stats.tools_usage.insert("Glob".to_string(), 20); + stats.tools_usage.insert("Grep".to_string(), tool_stats(30)); + stats.tools_usage.insert("Glob".to_string(), tool_stats(20)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::Explorer)); @@ -2893,9 +2908,9 @@ mod tests { let mut stats = create_test_stats(); let mut progress = AchievementProgress::new(); - stats.tools_usage.insert("Grep".to_string(), 200); - stats.tools_usage.insert("Glob".to_string(), 200); - stats.tools_usage.insert("Task".to_string(), 100); + stats.tools_usage.insert("Grep".to_string(), tool_stats(200)); + stats.tools_usage.insert("Glob".to_string(), tool_stats(200)); + stats.tools_usage.insert("Task".to_string(), tool_stats(100)); let newly = check_achievements(&stats, &mut progress); assert!(newly.contains(&AchievementId::MasterSearcher)); diff --git a/src-tauri/src/bridge_manager.rs b/src-tauri/src/bridge_manager.rs index 4a2370b..72654c6 100644 --- a/src-tauri/src/bridge_manager.rs +++ b/src-tauri/src/bridge_manager.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::sync::Arc; use tauri::AppHandle; +use crate::commands::record_session; use crate::config::ClaudeStartOptions; use crate::stats::UsageStats; use crate::wsl_bridge::WslBridge; @@ -53,7 +54,12 @@ impl BridgeManager { .or_insert_with(|| WslBridge::new_with_conversation_id(conversation_id.to_string())); // Start the Claude process - bridge.start(app, options)?; + bridge.start(app.clone(), options)?; + + // Record session start for cost tracking + tauri::async_runtime::spawn(async move { + record_session(&app).await; + }); Ok(()) } diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index f71b265..a63a374 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -556,6 +556,97 @@ pub async fn rename_path(old_path: String, new_path: String) -> Result<(), Strin Ok(()) } +// ==================== Cost Tracking Commands ==================== + +const COST_HISTORY_STORE_KEY: &str = "cost_history"; + +#[tauri::command] +pub async fn get_cost_summary(app: AppHandle, days: u32) -> Result { + let history = load_cost_history(&app).await; + Ok(history.get_summary(days)) +} + +#[tauri::command] +pub async fn get_cost_alerts(app: AppHandle) -> Result, String> { + let mut history = load_cost_history(&app).await; + let alerts = history.check_alerts(); + + // Save updated alert state + save_cost_history(&app, &history).await?; + + Ok(alerts) +} + +#[tauri::command] +pub async fn set_cost_alert_thresholds( + app: AppHandle, + daily: Option, + weekly: Option, + monthly: Option, +) -> Result<(), String> { + let mut history = load_cost_history(&app).await; + history.set_alert_thresholds(daily, weekly, monthly); + save_cost_history(&app, &history).await +} + +#[tauri::command] +pub async fn export_cost_csv(app: AppHandle, days: u32) -> Result { + let history = load_cost_history(&app).await; + Ok(history.export_csv(days)) +} + +#[tauri::command] +pub async fn get_today_cost(app: AppHandle) -> Result { + let history = load_cost_history(&app).await; + Ok(history.get_today_cost()) +} + +#[tauri::command] +pub async fn get_week_cost(app: AppHandle) -> Result { + let history = load_cost_history(&app).await; + Ok(history.get_week_cost()) +} + +#[tauri::command] +pub async fn get_month_cost(app: AppHandle) -> Result { + let history = load_cost_history(&app).await; + Ok(history.get_month_cost()) +} + +/// Add cost to history (called internally when stats are updated) +pub async fn record_cost(app: &AppHandle, input_tokens: u64, output_tokens: u64, cost_usd: f64) { + let mut history = load_cost_history(app).await; + history.add_cost(input_tokens, output_tokens, cost_usd); + let _ = save_cost_history(app, &history).await; +} + +/// Record a new session +pub async fn record_session(app: &AppHandle) { + let mut history = load_cost_history(app).await; + history.increment_sessions(); + let _ = save_cost_history(app, &history).await; +} + +async fn load_cost_history(app: &AppHandle) -> crate::cost_tracking::CostHistory { + let store = match app.store("hikari-cost-history.json") { + Ok(s) => s, + Err(_) => return crate::cost_tracking::CostHistory::new(), + }; + + match store.get(COST_HISTORY_STORE_KEY) { + Some(value) => serde_json::from_value(value.clone()).unwrap_or_default(), + None => crate::cost_tracking::CostHistory::new(), + } +} + +async fn save_cost_history(app: &AppHandle, history: &crate::cost_tracking::CostHistory) -> Result<(), String> { + let store = app.store("hikari-cost-history.json").map_err(|e| e.to_string())?; + let value = serde_json::to_value(history).map_err(|e| e.to_string())?; + store.set(COST_HISTORY_STORE_KEY, value); + store.save().map_err(|e| e.to_string())?; + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index 6ca733e..20ff574 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -96,6 +96,22 @@ pub struct HikariConfig { // Custom theme colors #[serde(default)] pub custom_theme_colors: CustomThemeColors, + + // Token budget settings + #[serde(default)] + pub budget_enabled: bool, + + #[serde(default)] + pub session_token_budget: Option, + + #[serde(default)] + pub session_cost_budget: Option, + + #[serde(default = "default_budget_action")] + pub budget_action: BudgetAction, + + #[serde(default = "default_budget_warning_threshold")] + pub budget_warning_threshold: f32, } impl Default for HikariConfig { @@ -123,6 +139,11 @@ impl Default for HikariConfig { profile_avatar_path: None, profile_bio: None, custom_theme_colors: CustomThemeColors::default(), + budget_enabled: false, + session_token_budget: None, + session_cost_budget: None, + budget_action: BudgetAction::Warn, + budget_warning_threshold: 0.8, } } } @@ -147,6 +168,22 @@ fn default_font_size() -> u32 { 14 } +fn default_budget_action() -> BudgetAction { + BudgetAction::Warn +} + +fn default_budget_warning_threshold() -> f32 { + 0.8 +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum BudgetAction { + #[default] + Warn, + Block, +} + #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Theme { @@ -205,6 +242,11 @@ mod tests { assert!(config.profile_avatar_path.is_none()); assert!(config.profile_bio.is_none()); assert_eq!(config.custom_theme_colors, CustomThemeColors::default()); + assert!(!config.budget_enabled); + assert!(config.session_token_budget.is_none()); + assert!(config.session_cost_budget.is_none()); + assert_eq!(config.budget_action, BudgetAction::Warn); + assert!((config.budget_warning_threshold - 0.8).abs() < f32::EPSILON); } #[test] @@ -232,6 +274,11 @@ mod tests { profile_avatar_path: None, profile_bio: Some("A test bio".to_string()), custom_theme_colors: CustomThemeColors::default(), + budget_enabled: true, + session_token_budget: Some(100000), + session_cost_budget: Some(1.50), + budget_action: BudgetAction::Block, + budget_warning_threshold: 0.75, }; let json = serde_json::to_string(&config).unwrap(); diff --git a/src-tauri/src/cost_tracking.rs b/src-tauri/src/cost_tracking.rs new file mode 100644 index 0000000..7d20203 --- /dev/null +++ b/src-tauri/src/cost_tracking.rs @@ -0,0 +1,376 @@ +use chrono::{Datelike, Local, NaiveDate, Weekday}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Represents a single day's cost data +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct DailyCost { + pub date: String, // ISO date string (YYYY-MM-DD) + pub input_tokens: u64, + pub output_tokens: u64, + pub cost_usd: f64, + pub messages_sent: u64, + pub sessions_count: u64, +} + +/// Historical cost tracking data +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CostHistory { + /// Daily costs indexed by date string (YYYY-MM-DD) + pub daily_costs: HashMap, + /// Cost alert thresholds + pub daily_alert_threshold: Option, + pub weekly_alert_threshold: Option, + pub monthly_alert_threshold: Option, + /// Whether alerts have been triggered today + pub daily_alert_triggered: bool, + pub weekly_alert_triggered: bool, + pub monthly_alert_triggered: bool, + pub last_alert_reset_date: Option, +} + +impl CostHistory { + pub fn new() -> Self { + Self::default() + } + + /// Get today's date as a string + fn today_str() -> String { + Local::now().format("%Y-%m-%d").to_string() + } + + /// Get the start of the current week (Monday) + fn week_start() -> NaiveDate { + let today = Local::now().date_naive(); + let days_since_monday = today.weekday().num_days_from_monday(); + today - chrono::Duration::days(days_since_monday as i64) + } + + /// Get the start of the current month + fn month_start() -> NaiveDate { + let today = Local::now().date_naive(); + NaiveDate::from_ymd_opt(today.year(), today.month(), 1).unwrap_or(today) + } + + /// Add cost for today + pub fn add_cost(&mut self, input_tokens: u64, output_tokens: u64, cost_usd: f64) { + let today = Self::today_str(); + + // Reset alert flags if it's a new day + if self.last_alert_reset_date.as_ref() != Some(&today) { + self.daily_alert_triggered = false; + // Reset weekly on Monday + if Local::now().weekday() == Weekday::Mon { + self.weekly_alert_triggered = false; + } + // Reset monthly on the 1st + if Local::now().day() == 1 { + self.monthly_alert_triggered = false; + } + self.last_alert_reset_date = Some(today.clone()); + } + + let daily = self.daily_costs.entry(today).or_default(); + daily.input_tokens += input_tokens; + daily.output_tokens += output_tokens; + daily.cost_usd += cost_usd; + daily.messages_sent += 1; + } + + /// Increment session count for today + pub fn increment_sessions(&mut self) { + let today = Self::today_str(); + let daily = self.daily_costs.entry(today.clone()).or_insert_with(|| DailyCost { + date: today, + ..Default::default() + }); + daily.sessions_count += 1; + } + + /// Get today's cost + pub fn get_today_cost(&self) -> f64 { + self.daily_costs + .get(&Self::today_str()) + .map(|d| d.cost_usd) + .unwrap_or(0.0) + } + + /// Get this week's cost (Monday to Sunday) + pub fn get_week_cost(&self) -> f64 { + let week_start = Self::week_start(); + self.daily_costs + .values() + .filter(|d| { + NaiveDate::parse_from_str(&d.date, "%Y-%m-%d") + .map(|date| date >= week_start) + .unwrap_or(false) + }) + .map(|d| d.cost_usd) + .sum() + } + + /// Get this month's cost + pub fn get_month_cost(&self) -> f64 { + let month_start = Self::month_start(); + self.daily_costs + .values() + .filter(|d| { + NaiveDate::parse_from_str(&d.date, "%Y-%m-%d") + .map(|date| date >= month_start) + .unwrap_or(false) + }) + .map(|d| d.cost_usd) + .sum() + } + + /// Get cost summary for a date range + pub fn get_summary(&self, days: u32) -> CostSummary { + let today = Local::now().date_naive(); + let start_date = today - chrono::Duration::days(days as i64 - 1); + + let mut total_input_tokens = 0u64; + let mut total_output_tokens = 0u64; + let mut total_cost = 0.0f64; + let mut total_messages = 0u64; + let mut total_sessions = 0u64; + let mut daily_breakdown = Vec::new(); + + for i in 0..days { + let date = start_date + chrono::Duration::days(i as i64); + let date_str = date.format("%Y-%m-%d").to_string(); + + if let Some(daily) = self.daily_costs.get(&date_str) { + total_input_tokens += daily.input_tokens; + total_output_tokens += daily.output_tokens; + total_cost += daily.cost_usd; + total_messages += daily.messages_sent; + total_sessions += daily.sessions_count; + daily_breakdown.push(daily.clone()); + } else { + daily_breakdown.push(DailyCost { + date: date_str, + ..Default::default() + }); + } + } + + CostSummary { + period_days: days, + total_input_tokens, + total_output_tokens, + total_cost, + total_messages, + total_sessions, + average_daily_cost: if days > 0 { total_cost / days as f64 } else { 0.0 }, + daily_breakdown, + } + } + + /// Check if any alert thresholds are exceeded and return which ones + pub fn check_alerts(&mut self) -> Vec { + let mut alerts = Vec::new(); + + if let Some(threshold) = self.daily_alert_threshold { + let today_cost = self.get_today_cost(); + if today_cost >= threshold && !self.daily_alert_triggered { + self.daily_alert_triggered = true; + alerts.push(CostAlert { + alert_type: AlertType::Daily, + threshold, + current_cost: today_cost, + }); + } + } + + if let Some(threshold) = self.weekly_alert_threshold { + let week_cost = self.get_week_cost(); + if week_cost >= threshold && !self.weekly_alert_triggered { + self.weekly_alert_triggered = true; + alerts.push(CostAlert { + alert_type: AlertType::Weekly, + threshold, + current_cost: week_cost, + }); + } + } + + if let Some(threshold) = self.monthly_alert_threshold { + let month_cost = self.get_month_cost(); + if month_cost >= threshold && !self.monthly_alert_triggered { + self.monthly_alert_triggered = true; + alerts.push(CostAlert { + alert_type: AlertType::Monthly, + threshold, + current_cost: month_cost, + }); + } + } + + alerts + } + + /// Set alert thresholds + pub fn set_alert_thresholds( + &mut self, + daily: Option, + weekly: Option, + monthly: Option, + ) { + self.daily_alert_threshold = daily; + self.weekly_alert_threshold = weekly; + self.monthly_alert_threshold = monthly; + } + + /// Clean up old data (keep last N days) + #[allow(dead_code)] + pub fn cleanup_old_data(&mut self, keep_days: u32) { + let cutoff = Local::now().date_naive() - chrono::Duration::days(keep_days as i64); + self.daily_costs.retain(|date_str, _| { + NaiveDate::parse_from_str(date_str, "%Y-%m-%d") + .map(|date| date >= cutoff) + .unwrap_or(false) + }); + } + + /// Export to CSV format + pub fn export_csv(&self, days: u32) -> String { + let summary = self.get_summary(days); + let mut csv = String::from("Date,Input Tokens,Output Tokens,Cost (USD),Messages,Sessions\n"); + + for daily in &summary.daily_breakdown { + csv.push_str(&format!( + "{},{},{},{:.4},{},{}\n", + daily.date, + daily.input_tokens, + daily.output_tokens, + daily.cost_usd, + daily.messages_sent, + daily.sessions_count + )); + } + + // Add totals row + csv.push_str(&format!( + "TOTAL,{},{},{:.4},{},{}\n", + summary.total_input_tokens, + summary.total_output_tokens, + summary.total_cost, + summary.total_messages, + summary.total_sessions + )); + + csv + } +} + +/// Cost summary for a period +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostSummary { + pub period_days: u32, + pub total_input_tokens: u64, + pub total_output_tokens: u64, + pub total_cost: f64, + pub total_messages: u64, + pub total_sessions: u64, + pub average_daily_cost: f64, + pub daily_breakdown: Vec, +} + +/// Alert types +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum AlertType { + Daily, + Weekly, + Monthly, +} + +/// Cost alert notification +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostAlert { + pub alert_type: AlertType, + pub threshold: f64, + pub current_cost: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_add_cost() { + let mut history = CostHistory::new(); + history.add_cost(1000, 500, 0.05); + + let today_cost = history.get_today_cost(); + assert!((today_cost - 0.05).abs() < 0.0001); + } + + #[test] + fn test_accumulate_daily_cost() { + let mut history = CostHistory::new(); + history.add_cost(1000, 500, 0.05); + history.add_cost(2000, 1000, 0.10); + + let today_cost = history.get_today_cost(); + assert!((today_cost - 0.15).abs() < 0.0001); + } + + #[test] + fn test_summary() { + let mut history = CostHistory::new(); + history.add_cost(1000, 500, 0.05); + + let summary = history.get_summary(7); + assert_eq!(summary.period_days, 7); + assert!((summary.total_cost - 0.05).abs() < 0.0001); + } + + #[test] + fn test_daily_alert() { + let mut history = CostHistory::new(); + history.set_alert_thresholds(Some(0.10), None, None); + + history.add_cost(1000, 500, 0.05); + let alerts = history.check_alerts(); + assert!(alerts.is_empty()); + + history.add_cost(1000, 500, 0.06); + let alerts = history.check_alerts(); + assert_eq!(alerts.len(), 1); + assert_eq!(alerts[0].alert_type, AlertType::Daily); + } + + #[test] + fn test_alert_only_triggers_once() { + let mut history = CostHistory::new(); + history.set_alert_thresholds(Some(0.10), None, None); + + history.add_cost(1000, 500, 0.15); + let alerts = history.check_alerts(); + assert_eq!(alerts.len(), 1); + + // Second check should not trigger again + let alerts = history.check_alerts(); + assert!(alerts.is_empty()); + } + + #[test] + fn test_export_csv() { + let mut history = CostHistory::new(); + history.add_cost(1000, 500, 0.05); + + let csv = history.export_csv(1); + assert!(csv.contains("Date,Input Tokens")); + assert!(csv.contains("TOTAL")); + } + + #[test] + fn test_increment_sessions() { + let mut history = CostHistory::new(); + history.increment_sessions(); + history.increment_sessions(); + + let summary = history.get_summary(1); + assert_eq!(summary.total_sessions, 2); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 6db2820..4f5c82c 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -3,6 +3,7 @@ mod bridge_manager; mod clipboard; mod commands; mod config; +mod cost_tracking; mod git; mod notifications; mod quick_actions; @@ -10,6 +11,7 @@ mod sessions; mod snippets; mod stats; mod temp_manager; +mod tool_cache; mod tray; mod types; mod vbs_notification; @@ -159,6 +161,14 @@ pub fn run() { delete_file, delete_directory, rename_path, + // Cost tracking commands + get_cost_summary, + get_cost_alerts, + set_cost_alert_thresholds, + export_cost_csv, + get_today_cost, + get_week_cost, + get_month_cost, ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); diff --git a/src-tauri/src/stats.rs b/src-tauri/src/stats.rs index 36d08fe..346d073 100644 --- a/src-tauri/src/stats.rs +++ b/src-tauri/src/stats.rs @@ -5,6 +5,110 @@ use std::collections::HashMap; use std::time::Instant; use tauri_plugin_store::StoreExt; +/// Per-tool token usage statistics +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ToolTokenStats { + pub call_count: u64, + pub estimated_input_tokens: u64, + pub estimated_output_tokens: u64, +} + +impl ToolTokenStats { + #[allow(dead_code)] + pub fn new() -> Self { + Self::default() + } + + pub fn increment_call(&mut self) { + self.call_count += 1; + } + + pub fn add_tokens(&mut self, input: u64, output: u64) { + self.estimated_input_tokens += input; + self.estimated_output_tokens += output; + } + + #[allow(dead_code)] + pub fn total_tokens(&self) -> u64 { + self.estimated_input_tokens + self.estimated_output_tokens + } +} + +/// Warning levels for context window utilisation +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ContextWarning { + /// 50-74% utilisation - conversation is getting long + Moderate, + /// 75-89% utilisation - consider summarising + High, + /// 90%+ utilisation - approaching limit + Critical, +} + +/// Budget status indicating whether user is within their limits +#[allow(dead_code)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum BudgetStatus { + /// Within budget, no concerns + Ok, + /// Approaching budget limit (warning threshold reached) + Warning { + budget_type: BudgetType, + percent_used: f32, + }, + /// Budget exceeded + Exceeded { budget_type: BudgetType }, +} + +/// Type of budget limit +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum BudgetType { + Token, + Cost, +} + +impl ContextWarning { + #[allow(dead_code)] + pub fn message(&self) -> &'static str { + match self { + ContextWarning::Moderate => "Context window is 50%+ full. Consider starting a new conversation for better performance.", + ContextWarning::High => "Context window is 75%+ full. Responses may degrade. Consider summarising or starting fresh.", + ContextWarning::Critical => "Context window is nearly full (90%+)! Start a new conversation to avoid errors.", + } + } +} + +/// Get the context window limit (in tokens) for a given model +fn get_context_window_limit(model: &str) -> u64 { + match model { + // Claude 4.5 family - 200K standard context + "claude-opus-4-5-20251101" + | "claude-sonnet-4-5-20250929" + | "claude-haiku-4-5-20251001" => 200_000, + // Claude 4.x family - 200K standard context + "claude-opus-4-1-20250805" + | "claude-opus-4-20250514" + | "claude-sonnet-4-20250514" => 200_000, + // Claude 3.x family + "claude-3-7-sonnet-20250219" + | "claude-3-5-sonnet-20241022" + | "claude-3-5-sonnet-20240620" + | "claude-3-5-haiku-20241022" + | "claude-3-opus-20240229" + | "claude-3-sonnet-20240229" + | "claude-3-haiku-20240307" => 200_000, + // Default to 200K for unknown Claude models + _ if model.starts_with("claude") => 200_000, + // For non-Claude models (Ollama, etc.), use a conservative default + _ => 128_000, + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct UsageStats { pub total_input_tokens: u64, @@ -24,8 +128,8 @@ pub struct UsageStats { pub session_files_edited: u64, pub files_created: u64, pub session_files_created: u64, - pub tools_usage: HashMap, - pub session_tools_usage: HashMap, + pub tools_usage: HashMap, + pub session_tools_usage: HashMap, pub session_duration_seconds: u64, #[serde(skip)] pub session_start: Option, @@ -38,6 +142,15 @@ pub struct UsageStats { pub night_sessions: u64, // Sessions started after 10 PM pub last_session_date: Option, // ISO date string for streak tracking + // Context window tracking + pub context_tokens_used: u64, + pub context_window_limit: u64, + pub context_utilisation_percent: f32, + + // Cache analytics (tracks potential savings from repeated tool calls) + pub potential_cache_hits: u64, + pub potential_cache_savings_tokens: u64, + // Achievement tracking #[serde(skip)] pub achievements: AchievementProgress, @@ -61,6 +174,114 @@ impl UsageStats { self.session_cost_usd += cost; self.model = Some(model.to_string()); + + // Update context window tracking + self.update_context_tracking(model); + } + + pub fn update_context_tracking(&mut self, model: &str) { + // Get context window limit for the current model + self.context_window_limit = get_context_window_limit(model); + + // Context tokens = input tokens (the prompt/context sent to the model) + // We track cumulative session input as a proxy for context growth + self.context_tokens_used = self.session_input_tokens; + + // Calculate utilisation percentage + if self.context_window_limit > 0 { + self.context_utilisation_percent = + (self.context_tokens_used as f32 / self.context_window_limit as f32) * 100.0; + } + } + + #[allow(dead_code)] + pub fn get_context_warning(&self) -> Option { + if self.context_utilisation_percent >= 90.0 { + Some(ContextWarning::Critical) + } else if self.context_utilisation_percent >= 75.0 { + Some(ContextWarning::High) + } else if self.context_utilisation_percent >= 50.0 { + Some(ContextWarning::Moderate) + } else { + None + } + } + + #[allow(dead_code)] + pub fn estimate_remaining_tokens(&self) -> u64 { + self.context_window_limit + .saturating_sub(self.context_tokens_used) + } + + /// Check budget status given current usage and budget settings + #[allow(dead_code)] + pub fn check_budget( + &self, + budget_enabled: bool, + token_budget: Option, + cost_budget: Option, + warning_threshold: f32, + ) -> BudgetStatus { + if !budget_enabled { + return BudgetStatus::Ok; + } + + let session_tokens = self.session_input_tokens + self.session_output_tokens; + + // Check token budget + if let Some(limit) = token_budget { + if session_tokens >= limit { + return BudgetStatus::Exceeded { + budget_type: BudgetType::Token, + }; + } + let percent_used = session_tokens as f32 / limit as f32; + if percent_used >= warning_threshold { + return BudgetStatus::Warning { + budget_type: BudgetType::Token, + percent_used: percent_used * 100.0, + }; + } + } + + // Check cost budget + if let Some(limit) = cost_budget { + if self.session_cost_usd >= limit { + return BudgetStatus::Exceeded { + budget_type: BudgetType::Cost, + }; + } + let percent_used = (self.session_cost_usd / limit) as f32; + if percent_used >= warning_threshold { + return BudgetStatus::Warning { + budget_type: BudgetType::Cost, + percent_used: percent_used * 100.0, + }; + } + } + + BudgetStatus::Ok + } + + /// Get remaining token budget (None if no budget set) + #[allow(dead_code)] + pub fn get_remaining_token_budget(&self, token_budget: Option) -> Option { + token_budget.map(|limit| { + let used = self.session_input_tokens + self.session_output_tokens; + limit.saturating_sub(used) + }) + } + + /// Get remaining cost budget (None if no budget set) + #[allow(dead_code)] + pub fn get_remaining_cost_budget(&self, cost_budget: Option) -> Option { + cost_budget.map(|limit| { + if limit > self.session_cost_usd { + limit - self.session_cost_usd + } else { + 0.0 + } + }) } pub fn reset_session(&mut self) { @@ -76,6 +297,13 @@ impl UsageStats { self.session_start = Some(Instant::now()); self.achievements.start_session(); + // Reset context window tracking + self.context_tokens_used = 0; + self.context_utilisation_percent = 0.0; + + // Note: Cache analytics are NOT reset here - they're cumulative across sessions + // to show total potential savings over time + // Track session start for achievements self.track_session_start(); } @@ -139,11 +367,32 @@ impl UsageStats { } pub fn increment_tool_usage(&mut self, tool_name: &str) { - *self.tools_usage.entry(tool_name.to_string()).or_insert(0) += 1; - *self - .session_tools_usage + self.tools_usage .entry(tool_name.to_string()) - .or_insert(0) += 1; + .or_default() + .increment_call(); + self.session_tools_usage + .entry(tool_name.to_string()) + .or_default() + .increment_call(); + } + + pub fn add_tool_tokens(&mut self, tool_name: &str, input_tokens: u64, output_tokens: u64) { + self.tools_usage + .entry(tool_name.to_string()) + .or_default() + .add_tokens(input_tokens, output_tokens); + self.session_tools_usage + .entry(tool_name.to_string()) + .or_default() + .add_tokens(input_tokens, output_tokens); + } + + /// Record a potential cache hit (when the same tool call is made twice) + #[allow(dead_code)] + pub fn add_potential_cache_hit(&mut self, tokens_saved: u64) { + self.potential_cache_hits += 1; + self.potential_cache_savings_tokens += tokens_saved; } pub fn get_session_duration(&mut self) -> u64 { @@ -184,6 +433,11 @@ impl UsageStats { morning_sessions: self.morning_sessions, night_sessions: self.night_sessions, last_session_date: self.last_session_date.clone(), + context_tokens_used: self.context_tokens_used, + context_window_limit: self.context_window_limit, + context_utilisation_percent: self.context_utilisation_percent, + potential_cache_hits: self.potential_cache_hits, + potential_cache_savings_tokens: self.potential_cache_savings_tokens, achievements: AchievementProgress::new(), // Dummy for copy }; check_achievements(&stats_copy, &mut self.achievements) @@ -206,20 +460,22 @@ fn is_consecutive_day(prev_date: &str, current_date: &str) -> bool { } } -// Pricing as of January 2025 -// https://www.anthropic.com/pricing -fn calculate_cost(input_tokens: u64, output_tokens: u64, model: &str) -> f64 { +// Pricing as of February 2026 +// https://platform.claude.com/docs/en/about-claude/models/overview +pub fn calculate_cost(input_tokens: u64, output_tokens: u64, model: &str) -> f64 { let (input_price_per_million, output_price_per_million) = match model { - // Opus 4.5 - "claude-opus-4-5-20251101" => (15.0, 75.0), + // Current generation (Claude 4.5) + "claude-opus-4-5-20251101" => (5.0, 25.0), + "claude-sonnet-4-5-20250929" => (3.0, 15.0), + "claude-haiku-4-5-20251001" => (1.0, 5.0), - // Opus 4 + // Previous generation (Claude 4.x) + "claude-opus-4-1-20250805" => (15.0, 75.0), "claude-opus-4-20250514" => (15.0, 75.0), - - // Sonnet 4 "claude-sonnet-4-20250514" => (3.0, 15.0), - // Previous generation models + // Legacy (Claude 3.x) + "claude-3-7-sonnet-20250219" => (3.0, 15.0), "claude-3-5-sonnet-20241022" => (3.0, 15.0), "claude-3-5-sonnet-20240620" => (3.0, 15.0), "claude-3-5-haiku-20241022" => (1.0, 5.0), @@ -252,7 +508,7 @@ pub struct PersistedStats { pub code_blocks_generated: u64, pub files_edited: u64, pub files_created: u64, - pub tools_usage: HashMap, + pub tools_usage: HashMap, pub sessions_started: u64, pub consecutive_days: u64, pub total_days_used: u64, @@ -372,8 +628,10 @@ mod tests { #[test] fn test_cost_calculation_opus_45() { let cost = calculate_cost(1000, 2000, "claude-opus-4-5-20251101"); - // Same pricing as Opus 4 - assert!((cost - 0.165).abs() < 0.0001); + // Opus 4.5 pricing: $5/MTok input, $25/MTok output + // 1000 input tokens = $0.005, 2000 output tokens = $0.05 + // Total = $0.055 + assert!((cost - 0.055).abs() < 0.0001); } #[test] @@ -512,10 +770,33 @@ mod tests { stats.increment_tool_usage("Read"); stats.increment_tool_usage("Write"); - assert_eq!(stats.tools_usage.get("Read"), Some(&2)); - assert_eq!(stats.tools_usage.get("Write"), Some(&1)); - assert_eq!(stats.session_tools_usage.get("Read"), Some(&2)); - assert_eq!(stats.session_tools_usage.get("Write"), Some(&1)); + assert_eq!(stats.tools_usage.get("Read").map(|t| t.call_count), Some(2)); + assert_eq!(stats.tools_usage.get("Write").map(|t| t.call_count), Some(1)); + assert_eq!(stats.session_tools_usage.get("Read").map(|t| t.call_count), Some(2)); + assert_eq!(stats.session_tools_usage.get("Write").map(|t| t.call_count), Some(1)); + } + + #[test] + fn test_add_tool_tokens() { + let mut stats = UsageStats::new(); + stats.increment_tool_usage("Read"); + stats.add_tool_tokens("Read", 100, 50); + stats.add_tool_tokens("Read", 200, 100); + + let read_stats = stats.tools_usage.get("Read").unwrap(); + assert_eq!(read_stats.call_count, 1); + assert_eq!(read_stats.estimated_input_tokens, 300); + assert_eq!(read_stats.estimated_output_tokens, 150); + assert_eq!(read_stats.total_tokens(), 450); + } + + #[test] + fn test_tool_token_stats_default() { + let tool_stats = ToolTokenStats::new(); + assert_eq!(tool_stats.call_count, 0); + assert_eq!(tool_stats.estimated_input_tokens, 0); + assert_eq!(tool_stats.estimated_output_tokens, 0); + assert_eq!(tool_stats.total_tokens(), 0); } #[test] @@ -590,7 +871,11 @@ mod tests { files_created: 5, tools_usage: { let mut map = HashMap::new(); - map.insert("Read".to_string(), 50); + map.insert("Read".to_string(), ToolTokenStats { + call_count: 50, + estimated_input_tokens: 5000, + estimated_output_tokens: 2500, + }); map }, sessions_started: 10, @@ -608,7 +893,8 @@ mod tests { assert_eq!(stats.total_output_tokens, 20000); assert_eq!(stats.total_cost_usd, 5.50); assert_eq!(stats.messages_exchanged, 100); - assert_eq!(stats.tools_usage.get("Read"), Some(&50)); + assert_eq!(stats.tools_usage.get("Read").map(|t| t.call_count), Some(50)); + assert_eq!(stats.tools_usage.get("Read").map(|t| t.estimated_input_tokens), Some(5000)); assert_eq!(stats.consecutive_days, 7); assert_eq!(stats.morning_sessions, 3); assert_eq!(stats.last_session_date, Some("2024-06-15".to_string())); @@ -672,4 +958,351 @@ mod tests { assert!(json.contains("stats")); assert!(json.contains("total_input_tokens")); } + + // ===================== + // Context Window Tracking tests + // ===================== + + #[test] + fn test_context_window_limit_claude_4() { + assert_eq!(get_context_window_limit("claude-opus-4-5-20251101"), 200_000); + assert_eq!(get_context_window_limit("claude-opus-4-20250514"), 200_000); + assert_eq!(get_context_window_limit("claude-sonnet-4-20250514"), 200_000); + } + + #[test] + fn test_context_window_limit_claude_35() { + assert_eq!( + get_context_window_limit("claude-3-5-sonnet-20241022"), + 200_000 + ); + assert_eq!( + get_context_window_limit("claude-3-5-sonnet-20240620"), + 200_000 + ); + assert_eq!( + get_context_window_limit("claude-3-5-haiku-20241022"), + 200_000 + ); + } + + #[test] + fn test_context_window_limit_unknown_claude() { + assert_eq!( + get_context_window_limit("claude-some-future-model"), + 200_000 + ); + } + + #[test] + fn test_context_window_limit_non_claude() { + assert_eq!(get_context_window_limit("gpt-4"), 128_000); + assert_eq!(get_context_window_limit("llama-3"), 128_000); + assert_eq!(get_context_window_limit("unknown-model"), 128_000); + } + + #[test] + fn test_context_tracking_update() { + let mut stats = UsageStats::new(); + stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514"); + + assert_eq!(stats.context_tokens_used, 50_000); + assert_eq!(stats.context_window_limit, 200_000); + assert!((stats.context_utilisation_percent - 25.0).abs() < 0.1); + } + + #[test] + fn test_context_tracking_accumulates() { + let mut stats = UsageStats::new(); + stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514"); + stats.add_usage(50_000, 10_000, "claude-sonnet-4-20250514"); + + assert_eq!(stats.context_tokens_used, 100_000); + assert!((stats.context_utilisation_percent - 50.0).abs() < 0.1); + } + + #[test] + fn test_context_warning_none() { + let mut stats = UsageStats::new(); + stats.context_utilisation_percent = 40.0; + assert!(stats.get_context_warning().is_none()); + } + + #[test] + fn test_context_warning_moderate() { + let mut stats = UsageStats::new(); + stats.context_utilisation_percent = 55.0; + assert_eq!(stats.get_context_warning(), Some(ContextWarning::Moderate)); + } + + #[test] + fn test_context_warning_high() { + let mut stats = UsageStats::new(); + stats.context_utilisation_percent = 80.0; + assert_eq!(stats.get_context_warning(), Some(ContextWarning::High)); + } + + #[test] + fn test_context_warning_critical() { + let mut stats = UsageStats::new(); + stats.context_utilisation_percent = 95.0; + assert_eq!(stats.get_context_warning(), Some(ContextWarning::Critical)); + } + + #[test] + fn test_estimate_remaining_tokens() { + let mut stats = UsageStats::new(); + stats.context_tokens_used = 50_000; + stats.context_window_limit = 200_000; + + assert_eq!(stats.estimate_remaining_tokens(), 150_000); + } + + #[test] + fn test_estimate_remaining_tokens_at_limit() { + let mut stats = UsageStats::new(); + stats.context_tokens_used = 200_000; + stats.context_window_limit = 200_000; + + assert_eq!(stats.estimate_remaining_tokens(), 0); + } + + #[test] + fn test_estimate_remaining_tokens_over_limit() { + let mut stats = UsageStats::new(); + stats.context_tokens_used = 250_000; + stats.context_window_limit = 200_000; + + assert_eq!(stats.estimate_remaining_tokens(), 0); + } + + #[test] + fn test_context_reset_on_session_reset() { + let mut stats = UsageStats::new(); + stats.add_usage(100_000, 20_000, "claude-sonnet-4-20250514"); + + assert!(stats.context_tokens_used > 0); + assert!(stats.context_utilisation_percent > 0.0); + + stats.reset_session(); + + assert_eq!(stats.context_tokens_used, 0); + assert_eq!(stats.context_utilisation_percent, 0.0); + } + + #[test] + fn test_context_warning_message() { + assert_eq!( + ContextWarning::Moderate.message(), + "Context window is 50%+ full. Consider starting a new conversation for better performance." + ); + assert_eq!( + ContextWarning::High.message(), + "Context window is 75%+ full. Responses may degrade. Consider summarising or starting fresh." + ); + assert_eq!( + ContextWarning::Critical.message(), + "Context window is nearly full (90%+)! Start a new conversation to avoid errors." + ); + } + + #[test] + fn test_context_warning_serialization() { + let warning = ContextWarning::Critical; + let json = serde_json::to_string(&warning).expect("Failed to serialize"); + assert_eq!(json, "\"critical\""); + + let warning = ContextWarning::Moderate; + let json = serde_json::to_string(&warning).expect("Failed to serialize"); + assert_eq!(json, "\"moderate\""); + } + + // ===================== + // Budget Tracking tests + // ===================== + + #[test] + fn test_budget_disabled_returns_ok() { + let stats = UsageStats::new(); + let status = stats.check_budget(false, Some(1000), Some(1.0), 0.8); + assert_eq!(status, BudgetStatus::Ok); + } + + #[test] + fn test_budget_no_limits_returns_ok() { + let stats = UsageStats::new(); + let status = stats.check_budget(true, None, None, 0.8); + assert_eq!(status, BudgetStatus::Ok); + } + + #[test] + fn test_token_budget_within_limit() { + let mut stats = UsageStats::new(); + stats.session_input_tokens = 500; + stats.session_output_tokens = 300; + + let status = stats.check_budget(true, Some(10000), None, 0.8); + assert_eq!(status, BudgetStatus::Ok); + } + + #[test] + fn test_token_budget_warning() { + let mut stats = UsageStats::new(); + stats.session_input_tokens = 4500; + stats.session_output_tokens = 4000; + + let status = stats.check_budget(true, Some(10000), None, 0.8); + match status { + BudgetStatus::Warning { + budget_type, + percent_used, + } => { + assert_eq!(budget_type, BudgetType::Token); + assert!(percent_used >= 80.0); + } + _ => panic!("Expected Warning status"), + } + } + + #[test] + fn test_token_budget_exceeded() { + let mut stats = UsageStats::new(); + stats.session_input_tokens = 6000; + stats.session_output_tokens = 5000; + + let status = stats.check_budget(true, Some(10000), None, 0.8); + assert_eq!( + status, + BudgetStatus::Exceeded { + budget_type: BudgetType::Token + } + ); + } + + #[test] + fn test_cost_budget_within_limit() { + let mut stats = UsageStats::new(); + stats.session_cost_usd = 0.50; + + let status = stats.check_budget(true, None, Some(5.0), 0.8); + assert_eq!(status, BudgetStatus::Ok); + } + + #[test] + fn test_cost_budget_warning() { + let mut stats = UsageStats::new(); + stats.session_cost_usd = 4.25; + + let status = stats.check_budget(true, None, Some(5.0), 0.8); + match status { + BudgetStatus::Warning { + budget_type, + percent_used, + } => { + assert_eq!(budget_type, BudgetType::Cost); + assert!(percent_used >= 80.0); + } + _ => panic!("Expected Warning status"), + } + } + + #[test] + fn test_cost_budget_exceeded() { + let mut stats = UsageStats::new(); + stats.session_cost_usd = 5.50; + + let status = stats.check_budget(true, None, Some(5.0), 0.8); + assert_eq!( + status, + BudgetStatus::Exceeded { + budget_type: BudgetType::Cost + } + ); + } + + #[test] + fn test_token_budget_takes_priority() { + let mut stats = UsageStats::new(); + stats.session_input_tokens = 12000; + stats.session_output_tokens = 0; + stats.session_cost_usd = 0.01; + + // Token budget exceeded, cost budget OK + let status = stats.check_budget(true, Some(10000), Some(5.0), 0.8); + assert_eq!( + status, + BudgetStatus::Exceeded { + budget_type: BudgetType::Token + } + ); + } + + #[test] + fn test_remaining_token_budget() { + let mut stats = UsageStats::new(); + stats.session_input_tokens = 3000; + stats.session_output_tokens = 2000; + + assert_eq!(stats.get_remaining_token_budget(Some(10000)), Some(5000)); + assert_eq!(stats.get_remaining_token_budget(None), None); + } + + #[test] + fn test_remaining_token_budget_exceeded() { + let mut stats = UsageStats::new(); + stats.session_input_tokens = 8000; + stats.session_output_tokens = 5000; + + assert_eq!(stats.get_remaining_token_budget(Some(10000)), Some(0)); + } + + #[test] + fn test_remaining_cost_budget() { + let mut stats = UsageStats::new(); + stats.session_cost_usd = 2.50; + + let remaining = stats.get_remaining_cost_budget(Some(5.0)); + assert!(remaining.is_some()); + assert!((remaining.unwrap() - 2.50).abs() < 0.001); + assert_eq!(stats.get_remaining_cost_budget(None), None); + } + + #[test] + fn test_remaining_cost_budget_exceeded() { + let mut stats = UsageStats::new(); + stats.session_cost_usd = 6.0; + + let remaining = stats.get_remaining_cost_budget(Some(5.0)); + assert!(remaining.is_some()); + assert!((remaining.unwrap() - 0.0).abs() < 0.001); + } + + #[test] + fn test_budget_status_serialization() { + let status = BudgetStatus::Warning { + budget_type: BudgetType::Token, + percent_used: 85.5, + }; + let json = serde_json::to_string(&status).expect("Failed to serialize"); + assert!(json.contains("warning")); + assert!(json.contains("token")); + + let status = BudgetStatus::Exceeded { + budget_type: BudgetType::Cost, + }; + let json = serde_json::to_string(&status).expect("Failed to serialize"); + assert!(json.contains("exceeded")); + assert!(json.contains("cost")); + } + + #[test] + fn test_budget_type_serialization() { + let token = BudgetType::Token; + let json = serde_json::to_string(&token).expect("Failed to serialize"); + assert_eq!(json, "\"token\""); + + let cost = BudgetType::Cost; + let json = serde_json::to_string(&cost).expect("Failed to serialize"); + assert_eq!(json, "\"cost\""); + } } diff --git a/src-tauri/src/tool_cache.rs b/src-tauri/src/tool_cache.rs new file mode 100644 index 0000000..68bb790 --- /dev/null +++ b/src-tauri/src/tool_cache.rs @@ -0,0 +1,266 @@ +use serde::{Deserialize, Serialize}; +use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; + +/// Tools that could benefit from caching +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CacheableTool { + Read, + Glob, + Grep, +} + +impl CacheableTool { + #[allow(dead_code)] + pub fn from_name(name: &str) -> Option { + match name { + "Read" => Some(Self::Read), + "Glob" => Some(Self::Glob), + "Grep" => Some(Self::Grep), + _ => None, + } + } +} + +/// Statistics about potential cache savings +#[allow(dead_code)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct CacheAnalytics { + /// Number of tool calls that could have been cache hits + pub potential_cache_hits: u64, + /// Estimated tokens that could have been saved + pub potential_savings_tokens: u64, + /// Tracks unique tool invocations: hash -> (tool_name, call_count) + #[serde(skip)] + recent_invocations: HashMap, +} + +#[allow(dead_code)] +impl CacheAnalytics { + pub fn new() -> Self { + Self::default() + } + + /// Compute a hash key from tool name and input + fn compute_key(tool_name: &str, input: &serde_json::Value) -> u64 { + let mut hasher = DefaultHasher::new(); + tool_name.hash(&mut hasher); + input.to_string().hash(&mut hasher); + hasher.finish() + } + + /// Track a tool invocation for analytics + /// Returns true if this was a repeated invocation (potential cache hit) + pub fn track_invocation( + &mut self, + tool_name: &str, + input: &serde_json::Value, + estimated_tokens: u64, + ) -> bool { + // Only track cacheable tools + if CacheableTool::from_name(tool_name).is_none() { + return false; + } + + let key = Self::compute_key(tool_name, input); + + if let Some((_, count)) = self.recent_invocations.get_mut(&key) { + *count += 1; + // This is a repeat - could have been a cache hit + self.potential_cache_hits += 1; + self.potential_savings_tokens += estimated_tokens; + true + } else { + self.recent_invocations + .insert(key, (tool_name.to_string(), 1)); + false + } + } + + /// Get the number of unique tool invocations being tracked + pub fn unique_invocations(&self) -> usize { + self.recent_invocations.len() + } + + /// Get invocations that were called more than once + pub fn repeated_invocations(&self) -> Vec<(&str, u64)> { + self.recent_invocations + .values() + .filter(|(_, count)| *count > 1) + .map(|(name, count)| (name.as_str(), *count)) + .collect() + } + + /// Clear session analytics (keep totals) + pub fn clear_session(&mut self) { + self.recent_invocations.clear(); + } + + /// Fully reset all analytics + pub fn reset(&mut self) { + self.potential_cache_hits = 0; + self.potential_savings_tokens = 0; + self.recent_invocations.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_cacheable_tool_from_name() { + assert_eq!(CacheableTool::from_name("Read"), Some(CacheableTool::Read)); + assert_eq!(CacheableTool::from_name("Glob"), Some(CacheableTool::Glob)); + assert_eq!(CacheableTool::from_name("Grep"), Some(CacheableTool::Grep)); + assert_eq!(CacheableTool::from_name("Bash"), None); + assert_eq!(CacheableTool::from_name("Edit"), None); + assert_eq!(CacheableTool::from_name("Write"), None); + } + + #[test] + fn test_first_invocation_not_cache_hit() { + let mut analytics = CacheAnalytics::new(); + let input = json!({"file_path": "/home/test/file.txt"}); + + let is_repeat = analytics.track_invocation("Read", &input, 100); + + assert!(!is_repeat); + assert_eq!(analytics.potential_cache_hits, 0); + assert_eq!(analytics.potential_savings_tokens, 0); + } + + #[test] + fn test_second_invocation_is_cache_hit() { + let mut analytics = CacheAnalytics::new(); + let input = json!({"file_path": "/home/test/file.txt"}); + + analytics.track_invocation("Read", &input, 100); + let is_repeat = analytics.track_invocation("Read", &input, 100); + + assert!(is_repeat); + assert_eq!(analytics.potential_cache_hits, 1); + assert_eq!(analytics.potential_savings_tokens, 100); + } + + #[test] + fn test_different_inputs_not_cache_hit() { + let mut analytics = CacheAnalytics::new(); + let input1 = json!({"file_path": "/home/test/file1.txt"}); + let input2 = json!({"file_path": "/home/test/file2.txt"}); + + analytics.track_invocation("Read", &input1, 100); + let is_repeat = analytics.track_invocation("Read", &input2, 100); + + assert!(!is_repeat); + assert_eq!(analytics.potential_cache_hits, 0); + } + + #[test] + fn test_non_cacheable_tool_ignored() { + let mut analytics = CacheAnalytics::new(); + let input = json!({"command": "ls -la"}); + + let is_repeat = analytics.track_invocation("Bash", &input, 100); + analytics.track_invocation("Bash", &input, 100); + + assert!(!is_repeat); + assert_eq!(analytics.potential_cache_hits, 0); + assert_eq!(analytics.unique_invocations(), 0); + } + + #[test] + fn test_multiple_repeated_invocations() { + let mut analytics = CacheAnalytics::new(); + let input = json!({"file_path": "/home/test/file.txt"}); + + analytics.track_invocation("Read", &input, 100); + analytics.track_invocation("Read", &input, 100); + analytics.track_invocation("Read", &input, 100); + + assert_eq!(analytics.potential_cache_hits, 2); + assert_eq!(analytics.potential_savings_tokens, 200); + } + + #[test] + fn test_unique_invocations_count() { + let mut analytics = CacheAnalytics::new(); + + analytics.track_invocation("Read", &json!({"file_path": "/file1.txt"}), 100); + analytics.track_invocation("Read", &json!({"file_path": "/file2.txt"}), 100); + analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50); + + assert_eq!(analytics.unique_invocations(), 3); + } + + #[test] + fn test_repeated_invocations_list() { + let mut analytics = CacheAnalytics::new(); + + // file1 read twice + analytics.track_invocation("Read", &json!({"file_path": "/file1.txt"}), 100); + analytics.track_invocation("Read", &json!({"file_path": "/file1.txt"}), 100); + + // file2 read once + analytics.track_invocation("Read", &json!({"file_path": "/file2.txt"}), 100); + + // glob run 3 times + analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50); + analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50); + analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50); + + let repeated = analytics.repeated_invocations(); + assert_eq!(repeated.len(), 2); // file1 and glob pattern + } + + #[test] + fn test_clear_session() { + let mut analytics = CacheAnalytics::new(); + let input = json!({"file_path": "/file.txt"}); + + analytics.track_invocation("Read", &input, 100); + analytics.track_invocation("Read", &input, 100); + + assert_eq!(analytics.potential_cache_hits, 1); + assert_eq!(analytics.unique_invocations(), 1); + + analytics.clear_session(); + + assert_eq!(analytics.potential_cache_hits, 1); // Preserved + assert_eq!(analytics.unique_invocations(), 0); // Cleared + } + + #[test] + fn test_reset() { + let mut analytics = CacheAnalytics::new(); + let input = json!({"file_path": "/file.txt"}); + + analytics.track_invocation("Read", &input, 100); + analytics.track_invocation("Read", &input, 100); + + analytics.reset(); + + assert_eq!(analytics.potential_cache_hits, 0); + assert_eq!(analytics.potential_savings_tokens, 0); + assert_eq!(analytics.unique_invocations(), 0); + } + + #[test] + fn test_serialization() { + let mut analytics = CacheAnalytics::new(); + analytics.potential_cache_hits = 10; + analytics.potential_savings_tokens = 500; + + let json = serde_json::to_string(&analytics).expect("Failed to serialize"); + let deserialized: CacheAnalytics = + serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(deserialized.potential_cache_hits, 10); + assert_eq!(deserialized.potential_savings_tokens, 500); + // recent_invocations is skipped in serialization + assert_eq!(deserialized.unique_invocations(), 0); + } +} diff --git a/src-tauri/src/types.rs b/src-tauri/src/types.rs index a4594e8..3607d89 100644 --- a/src-tauri/src/types.rs +++ b/src-tauri/src/types.rs @@ -176,6 +176,14 @@ pub struct StateChangeEvent { pub conversation_id: Option, } +/// Cost information for a message +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageCost { + pub input_tokens: u64, + pub output_tokens: u64, + pub cost_usd: f64, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OutputEvent { pub line_type: String, @@ -183,6 +191,8 @@ pub struct OutputEvent { pub tool_name: Option, #[serde(skip_serializing_if = "Option::is_none")] pub conversation_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -354,10 +364,31 @@ mod tests { content: "Test output".to_string(), tool_name: None, conversation_id: None, + cost: None, }; let serialized = serde_json::to_string(&event).unwrap(); assert!(serialized.contains("\"line_type\":\"assistant\"")); assert!(serialized.contains("\"content\":\"Test output\"")); } + + #[test] + fn test_output_event_with_cost() { + let event = OutputEvent { + line_type: "assistant".to_string(), + content: "Test output".to_string(), + tool_name: None, + conversation_id: Some("conv-123".to_string()), + cost: Some(MessageCost { + input_tokens: 100, + output_tokens: 50, + cost_usd: 0.005, + }), + }; + + let serialized = serde_json::to_string(&event).unwrap(); + assert!(serialized.contains("\"cost\":")); + assert!(serialized.contains("\"input_tokens\":100")); + assert!(serialized.contains("\"output_tokens\":50")); + } } diff --git a/src-tauri/src/wsl_bridge.rs b/src-tauri/src/wsl_bridge.rs index 75292f2..5f99b16 100644 --- a/src-tauri/src/wsl_bridge.rs +++ b/src-tauri/src/wsl_bridge.rs @@ -9,12 +9,13 @@ use tempfile::NamedTempFile; use std::os::windows::process::CommandExt; use crate::achievements::{get_achievement_info, AchievementUnlockedEvent}; +use crate::commands::record_cost; use crate::config::ClaudeStartOptions; -use crate::stats::{StatsUpdateEvent, UsageStats}; +use crate::stats::{calculate_cost, StatsUpdateEvent, UsageStats}; use crate::types::{ - CharacterState, ClaudeMessage, ConnectionEvent, ConnectionStatus, ContentBlock, OutputEvent, - PermissionPromptEvent, QuestionOption, SessionEvent, StateChangeEvent, UserQuestionEvent, - WorkingDirectoryEvent, + CharacterState, ClaudeMessage, ConnectionEvent, ConnectionStatus, ContentBlock, MessageCost, + OutputEvent, PermissionPromptEvent, QuestionOption, SessionEvent, StateChangeEvent, + UserQuestionEvent, WorkingDirectoryEvent, }; use parking_lot::RwLock; @@ -534,6 +535,7 @@ fn handle_stderr( content: line, tool_name: None, conversation_id: conversation_id.clone(), + cost: None, }, ); } @@ -586,17 +588,57 @@ fn process_json_line( let mut state = CharacterState::Typing; let mut tool_name = None; + // Collect all tool names in this message for token attribution + let tools_in_message: Vec = message + .content + .iter() + .filter_map(|block| match block { + ContentBlock::ToolUse { name, .. } => Some(name.clone()), + _ => None, + }) + .collect(); + + // Track message cost for display + let mut message_cost: Option = None; + // Only update stats if we have usage information if let Some(usage) = &message.usage { if let Some(model) = &message.model { + // Calculate cost for historical tracking + let cost_usd = calculate_cost(usage.input_tokens, usage.output_tokens, model); + + // Store cost for later use in output events + message_cost = Some(MessageCost { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + cost_usd, + }); + // Batch all stats updates in a single write lock { let mut stats_guard = stats.write(); stats_guard.increment_messages(); stats_guard.add_usage(usage.input_tokens, usage.output_tokens, model); stats_guard.get_session_duration(); + + // Attribute tokens to tools if any tools were used in this message + if !tools_in_message.is_empty() { + let per_tool_input = usage.input_tokens / tools_in_message.len() as u64; + let per_tool_output = usage.output_tokens / tools_in_message.len() as u64; + for tool in &tools_in_message { + stats_guard.add_tool_tokens(tool, per_tool_input, per_tool_output); + } + } } + // Record to historical cost tracking + let app_clone = app.clone(); + let input = usage.input_tokens; + let output = usage.output_tokens; + tauri::async_runtime::spawn(async move { + record_cost(&app_clone, input, output, cost_usd).await; + }); + // Don't emit here - we'll emit on Result message instead // This reduces the frequency of updates } else { @@ -635,6 +677,7 @@ fn process_json_line( content: desc, tool_name: Some(name.clone()), conversation_id: conversation_id.clone(), + cost: None, // Tool use doesn't have separate cost }, ); } @@ -652,6 +695,7 @@ fn process_json_line( content: text.clone(), tool_name: None, conversation_id: conversation_id.clone(), + cost: message_cost.clone(), // Include cost with assistant text }, ); } @@ -664,6 +708,7 @@ fn process_json_line( content: format!("[Thinking] {}", thinking), tool_name: None, conversation_id: conversation_id.clone(), + cost: None, }, ); } @@ -723,9 +768,20 @@ fn process_json_line( stats_guard.model.clone().unwrap_or_else(|| "claude-opus-4-20250514".to_string()) }; + // Calculate cost for historical tracking + let cost_usd = calculate_cost(usage_info.input_tokens, usage_info.output_tokens, &model); + let mut stats_guard = stats.write(); stats_guard.add_usage(usage_info.input_tokens, usage_info.output_tokens, &model); println!("Result message tokens - input: {}, output: {}", usage_info.input_tokens, usage_info.output_tokens); + + // Record to historical cost tracking + let app_clone = app.clone(); + let input = usage_info.input_tokens; + let output = usage_info.output_tokens; + tauri::async_runtime::spawn(async move { + record_cost(&app_clone, input, output, cost_usd).await; + }); } // Always emit updated stats on result message (less frequent) @@ -797,6 +853,7 @@ fn process_json_line( content: text.clone(), tool_name: None, conversation_id: conversation_id.clone(), + cost: None, }, ); } diff --git a/src/lib/components/AchievementNotification.svelte b/src/lib/components/AchievementNotification.svelte index b1c9712..79e3751 100644 --- a/src/lib/components/AchievementNotification.svelte +++ b/src/lib/components/AchievementNotification.svelte @@ -161,7 +161,7 @@
- {#each Array(10) as _ (_)} + {#each Array.from({ length: 10 }, (_, i) => i) as confettiIndex (confettiIndex)}
+ +
+

+ Budget Settings +

+ + +
+ +

+ Set limits on token usage and costs per session +

+
+ + {#if config.budget_enabled} + +
+ +
+ + tokens +
+

Leave empty for unlimited tokens

+
+ + +
+ +
+ $ + + USD +
+

Leave empty for unlimited spending

+
+ + +
+ +
+ + + {Math.round(config.budget_warning_threshold * 100)}% + +
+

+ Show warning when this percentage of budget is used +

+
+ + +
+ When budget is exceeded +
+ + +
+

+ {config.budget_action === "warn" + ? "Show a warning but allow continued usage" + : "Prevent sending more messages until session is reset"} +

+
+ {/if} +
+ + +
+

+ Cost History +

+ +
+

diff --git a/src/lib/components/CostSummary.svelte b/src/lib/components/CostSummary.svelte new file mode 100644 index 0000000..3e39f11 --- /dev/null +++ b/src/lib/components/CostSummary.svelte @@ -0,0 +1,402 @@ + + +
+

Cost Summary

+ + +
+
+ Today + {costs.today} +
+
+ This Week + {costs.week} +
+
+ This Month + {costs.month} +
+
+ + +
+ + + +
+ + + {#if isLoading} +
Loading...
+ {:else if summary} +
+
+ Total Cost + {formatCost(summary.total_cost)} +
+
+ Average Daily + {formatCost(summary.average_daily_cost)} +
+
+ Messages + {summary.total_messages.toLocaleString()} +
+
+ Sessions + {summary.total_sessions.toLocaleString()} +
+
+ Input Tokens + {summary.total_input_tokens.toLocaleString()} +
+
+ Output Tokens + {summary.total_output_tokens.toLocaleString()} +
+
+ + + {#if summary.daily_breakdown.length > 0} +
+

Daily Spending

+
+ {#each summary.daily_breakdown.slice(-14) as day (day.date)} + {@const maxCost = Math.max(...summary.daily_breakdown.map((d) => d.cost_usd), 0.01)} + {@const height = (day.cost_usd / maxCost) * 100} +
+
+
+ {/each} +
+
+ {/if} + {/if} + + +
+ + +
+ + + {#if showThresholdSettings} +
+

Cost Alert Thresholds

+
+ + +
+
+ + +
+
+ + +
+ +
+ {/if} +
+ + diff --git a/src/lib/components/InputBar.svelte b/src/lib/components/InputBar.svelte index a05da28..1cd2a19 100644 --- a/src/lib/components/InputBar.svelte +++ b/src/lib/components/InputBar.svelte @@ -26,6 +26,7 @@ type SlashCommand, } from "$lib/commands/slashCommands"; import { configStore, isStreamerMode } from "$lib/stores/config"; + import { stats, estimateMessageCost, formatTokenCount } from "$lib/stores/stats"; import AttachmentPreview from "$lib/components/AttachmentPreview.svelte"; import SnippetLibraryPanel from "$lib/components/SnippetLibraryPanel.svelte"; import QuickActionsPanel from "$lib/components/QuickActionsPanel.svelte"; @@ -50,6 +51,13 @@ let showClipboardHistory = $state(false); let streamerModeActive = $state(false); + // Cost estimation for pre-submission display + let costEstimate = $derived( + inputValue.trim() + ? estimateMessageCost(inputValue, $stats.context_tokens_used, $stats.model) + : null + ); + // Context menu state let textareaElement: HTMLTextAreaElement | undefined = $state(); let contextMenuShow = $state(false); @@ -913,6 +921,13 @@ User: ${formattedMessage}`;

+ {#if costEstimate && isConnected && !isProcessing} +
+ +{formatTokenCount(costEstimate.messageTokens)} + ${costEstimate.estimatedCost.toFixed(4)} +
+ {/if} + + {/if} + {#if onStartFreshWithContext} + + {/if} +
+ {/if} +
+ + {#if $config.budget_enabled} +
+

Budget

+ {#if $config.session_token_budget !== null} +
+
+ Tokens: + + {($stats.session_input_tokens + $stats.session_output_tokens).toLocaleString()} / {$config.session_token_budget.toLocaleString()} + +
+
+
+
+
+ {remainingTokens?.toLocaleString() ?? 0} remaining ({( + 100 - tokenBudgetPercent() + ).toFixed(1)}%) +
+
+ {/if} + {#if $config.session_cost_budget !== null} +
+
+ Cost: + + ${$stats.session_cost_usd.toFixed(4)} / ${$config.session_cost_budget.toFixed(2)} + +
+
+
+
+
+ ${remainingCost?.toFixed(4) ?? "0.0000"} remaining ({( + 100 - costBudgetPercent() + ).toFixed(1)}%) +
+
+ {/if} + {#if budgetMessage} +
+ {budgetMessage} +
+ {/if} +
+ {/if} +

Tokens & Cost

@@ -49,7 +239,7 @@
- {#if Object.keys($formattedStats.sessionToolsUsage).length > 0} + {#if $formattedStats.sessionToolsFormatted.length > 0}

{#if showToolsBreakdown}
- {#each Object.entries($formattedStats.sessionToolsUsage).sort((a, b) => b[1] - a[1]) as [tool, count] (tool)} -
- {tool}: - {count} + {#each $formattedStats.sessionToolsFormatted.sort((a, b) => b.totalTokens - a.totalTokens) as tool (tool.name)} +
+ {tool.name}: + + {tool.callCount} calls + {#if tool.totalTokens > 0} + (~{tool.formattedTokens}) + {/if} +
{/each} +
* Token estimates based on attribution
{/if}
{/if} + +
+

+ +

+ {#if !showHistoricalCosts} +
+ Today: {historicalCosts.today} + Week: {historicalCosts.week} + Month: {historicalCosts.month} +
+ {/if} + {#if showHistoricalCosts} +
+
+ Today: + {historicalCosts.today} +
+
+ This Week: + {historicalCosts.week} +
+
+ This Month: + {historicalCosts.month} +
+

Open Settings to view detailed cost history and set alerts.

+
+ {/if} +
+
Model: {$formattedStats.model} @@ -128,6 +358,79 @@ color: var(--text-primary, #e5e7eb); } + .stat-cost { + font-family: var(--font-mono, monospace); + color: var(--accent-primary, #10b981); + font-size: 0.8rem; + margin-left: 0.5rem; + } + + .stats-row { + display: flex; + justify-content: space-between; + align-items: center; + padding: 0.125rem 0; + } + + .tools-header { + margin: 0 !important; + padding: 0 !important; + border: none !important; + } + + .tools-toggle { + display: flex; + align-items: center; + gap: 0.5rem; + background: none; + border: none; + color: var(--text-primary); + font-size: 0.9rem; + font-weight: 600; + cursor: pointer; + padding: 0; + } + + .tools-toggle:hover { + color: var(--accent-primary); + } + + .toggle-icon { + font-size: 0.7rem; + opacity: 0.7; + } + + .tools-breakdown { + margin-top: 0.25rem; + } + + .tool-row { + flex-wrap: wrap; + } + + .tool-stats { + display: flex; + gap: 0.5rem; + align-items: center; + } + + .tool-calls { + color: var(--text-primary, #e5e7eb); + } + + .tool-tokens { + color: var(--text-secondary, #9ca3af); + font-size: 0.75rem; + } + + .tools-note { + margin-top: 0.5rem; + font-size: 0.65rem; + color: var(--text-secondary, #9ca3af); + font-style: italic; + opacity: 0.8; + } + .model-info { display: flex; justify-content: space-between; @@ -148,4 +451,220 @@ color: var(--text-primary, #e5e7eb); font-size: 0.75rem; } + + .context-util { + font-weight: 600; + } + + .context-util.warning-moderate { + color: #f59e0b; + } + + .context-util.warning-high { + color: #f97316; + } + + .context-util.warning-critical { + color: #ef4444; + } + + .context-warning { + margin-top: 0.5rem; + padding: 0.5rem; + border-radius: 4px; + font-size: 0.75rem; + line-height: 1.3; + } + + .context-warning.warning-moderate { + background: rgba(245, 158, 11, 0.15); + border: 1px solid rgba(245, 158, 11, 0.3); + color: #fbbf24; + } + + .context-warning.warning-high { + background: rgba(249, 115, 22, 0.15); + border: 1px solid rgba(249, 115, 22, 0.3); + color: #fb923c; + } + + .context-warning.warning-critical { + background: rgba(239, 68, 68, 0.15); + border: 1px solid rgba(239, 68, 68, 0.3); + color: #f87171; + } + + /* Budget progress bar styles */ + .budget-item { + margin-bottom: 0.75rem; + } + + .budget-item:last-child { + margin-bottom: 0; + } + + .budget-bar-container { + width: 100%; + height: 6px; + background: var(--bg-primary); + border-radius: 3px; + margin-top: 0.25rem; + overflow: hidden; + } + + .budget-bar { + height: 100%; + border-radius: 3px; + transition: + width 0.3s ease, + background-color 0.3s ease; + } + + .budget-bar-ok { + background: linear-gradient(90deg, #10b981, #34d399); + } + + .budget-bar-warning { + background: linear-gradient(90deg, #f59e0b, #fbbf24); + } + + .budget-bar-exceeded { + background: linear-gradient(90deg, #ef4444, #f87171); + } + + .budget-remaining { + font-size: 0.7rem; + color: var(--text-secondary); + margin-top: 0.125rem; + text-align: right; + } + + /* Budget warning styles */ + .budget-warning { + margin-top: 0.5rem; + padding: 0.5rem; + border-radius: 4px; + font-size: 0.75rem; + line-height: 1.3; + } + + .budget-warning.budget-warning { + background: rgba(245, 158, 11, 0.15); + border: 1px solid rgba(245, 158, 11, 0.3); + color: #fbbf24; + } + + .budget-warning.budget-exceeded { + background: rgba(239, 68, 68, 0.15); + border: 1px solid rgba(239, 68, 68, 0.3); + color: #f87171; + } + + .stat-value.budget-warning { + color: #f59e0b; + font-weight: 600; + } + + .stat-value.budget-exceeded { + color: #ef4444; + font-weight: 600; + } + + /* Compaction action buttons */ + .compaction-actions { + display: flex; + gap: 0.5rem; + margin-top: 0.5rem; + } + + .compaction-btn { + flex: 1; + padding: 0.375rem 0.5rem; + font-size: 0.75rem; + font-weight: 500; + border-radius: 4px; + border: 1px solid var(--border-color); + background: var(--bg-primary); + color: var(--text-primary); + cursor: pointer; + transition: all 0.15s ease; + } + + .compaction-btn:hover:not(:disabled) { + border-color: var(--accent-primary); + background: rgba(233, 69, 96, 0.1); + } + + .compaction-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } + + .compaction-btn-primary { + background: var(--accent-primary); + border-color: var(--accent-primary); + color: white; + } + + .compaction-btn-primary:hover:not(:disabled) { + background: var(--accent-secondary); + border-color: var(--accent-secondary); + } + + /* Historical costs styles */ + .costs-header { + margin: 0 !important; + padding: 0 !important; + border: none !important; + } + + .costs-toggle { + display: flex; + align-items: center; + gap: 0.5rem; + background: none; + border: none; + color: var(--text-primary); + font-size: 0.9rem; + font-weight: 600; + cursor: pointer; + padding: 0; + } + + .costs-toggle:hover { + color: var(--accent-primary); + } + + .costs-quick-stats { + display: flex; + gap: 0.5rem; + margin-top: 0.5rem; + flex-wrap: wrap; + } + + .cost-badge { + font-size: 0.7rem; + padding: 0.2rem 0.4rem; + background: rgba(16, 185, 129, 0.1); + border: 1px solid rgba(16, 185, 129, 0.3); + border-radius: 3px; + color: #10b981; + font-family: var(--font-mono, monospace); + } + + .historical-costs-expanded { + margin-top: 0.5rem; + } + + .cost-value { + color: #10b981; + } + + .costs-note { + margin: 0.5rem 0 0 0; + font-size: 0.65rem; + color: var(--text-secondary); + font-style: italic; + opacity: 0.8; + } diff --git a/src/lib/components/StatusBar.svelte b/src/lib/components/StatusBar.svelte index be2942e..3fb2600 100644 --- a/src/lib/components/StatusBar.svelte +++ b/src/lib/components/StatusBar.svelte @@ -24,6 +24,12 @@ import SessionHistoryPanel from "./SessionHistoryPanel.svelte"; import GitPanel from "./GitPanel.svelte"; import ProfilePanel from "./ProfilePanel.svelte"; + import { conversationsStore } from "$lib/stores/conversations"; + import { + generateContextInjection, + createSummary, + sanitizeForJson, + } from "$lib/utils/conversationUtils"; const DISCORD_URL = "https://chat.nhcarrigan.com"; const DONATE_URL = "https://donate.nhcarrigan.com"; @@ -41,6 +47,7 @@ let showSessionHistory = $state(false); let showGitPanel = $state(false); let showProfile = $state(false); + let isSummarising = $state(false); const progress = $derived($achievementProgress); let currentConfig: HikariConfig = $state({ model: null, @@ -74,6 +81,11 @@ text_secondary: null, border_color: null, }, + budget_enabled: false, + session_token_budget: null, + session_cost_budget: null, + budget_action: "warn", + budget_warning_threshold: 0.8, }); let streamerModeActive = $state(false); @@ -200,6 +212,106 @@ function toggleAchievements() { onToggleAchievements(); } + + async function handleCompactConversation() { + const activeId = get(conversationsStore.activeConversationId); + if (!activeId) return; + + isSummarising = true; + + try { + const conversationContent = conversationsStore.getConversationForSummary(activeId); + const messageCount = + get(conversationsStore.activeConversation)?.terminalLines.filter( + (l) => l.type === "user" || l.type === "assistant" + ).length || 0; + const tokenEstimate = conversationsStore.estimateTokenCount(activeId); + + // Create a summary from the conversation content (truncate if too long) + // Apply sanitization early to handle any problematic escape sequences + const sanitizedContent = sanitizeForJson(conversationContent); + const summaryContent = + sanitizedContent.length > 4000 + ? `${sanitizedContent.slice(0, 4000)}\n\n[Truncated for length - original had ${messageCount} messages]` + : sanitizedContent; + + // Step 1: Disconnect from Claude to reset context + if (connectionStatus === "connected") { + await invoke("stop_claude", { conversationId: activeId }); + } + + // Step 2: Clear messages and store summary + conversationsStore.compactWithSummary(activeId, summaryContent, messageCount, tokenEstimate); + + // Step 3: Reconnect to Claude with fresh context + const allAllowedTools = [ + ...(currentConfig.auto_granted_tools || []), + ...Array.from(get(claudeStore.grantedTools)), + ]; + + await invoke("start_claude", { + conversationId: activeId, + options: { + working_dir: workingDirectory || selectedDirectory, + model: currentConfig.model || null, + api_key: currentConfig.api_key || null, + custom_instructions: currentConfig.custom_instructions || null, + mcp_servers_json: currentConfig.mcp_servers_json || null, + allowed_tools: allAllowedTools, + }, + }); + + // Step 4: Send the context summary to Claude as the first message + const contextPrompt = generateContextInjection( + createSummary(summaryContent, messageCount, tokenEstimate) + ); + + await invoke("send_prompt", { + conversationId: activeId, + message: contextPrompt, + }); + + claudeStore.addLine( + "system", + "Conversation compacted. Context from previous session has been provided to Claude." + ); + } catch (error) { + console.error("Failed to compact conversation:", error); + claudeStore.addLine("error", `Failed to compact conversation: ${error}`); + } finally { + isSummarising = false; + } + } + + async function handleStartFreshWithContext() { + const activeId = get(conversationsStore.activeConversationId); + if (!activeId) return; + + const conversationContent = conversationsStore.getConversationForSummary(activeId); + const messageCount = + get(conversationsStore.activeConversation)?.terminalLines.filter( + (l) => l.type === "user" || l.type === "assistant" + ).length || 0; + const tokenEstimate = conversationsStore.estimateTokenCount(activeId); + + const summary = createSummary( + `This is a continuation of a previous conversation. Here's what was discussed:\n\n${conversationContent.slice(0, 4000)}${conversationContent.length > 4000 ? "\n\n[Truncated for length...]" : ""}`, + messageCount, + tokenEstimate + ); + + const newConvId = conversationsStore.createConversation("Fresh Start"); + + conversationsStore.setSummary(newConvId, summary); + + // Context injection is generated but the actual injection happens via the summary + generateContextInjection(summary); + claudeStore.addLine("system", "Started fresh conversation with context from previous session."); + claudeStore.addLine( + "system", + `Previous session had ${messageCount} messages (~${tokenEstimate.toLocaleString()} tokens).` + ); + }
- +
{/if} {#if connectionStatus === "connected"} @@ -473,7 +589,11 @@
(showStats = false)}>
- +
{/if} diff --git a/src/lib/components/Terminal.svelte b/src/lib/components/Terminal.svelte index fb6faab..b753db0 100644 --- a/src/lib/components/Terminal.svelte +++ b/src/lib/components/Terminal.svelte @@ -211,6 +211,16 @@ {#each lines as line (line.id)}
{formatTime(line.timestamp)} + {#if line.cost && line.cost.costUsd > 0} + + ${line.cost.costUsd < 0.01 + ? line.cost.costUsd.toFixed(4) + : line.cost.costUsd.toFixed(3)} + + {/if} {#if getLinePrefix(line.type)} {getLinePrefix(line.type)} {/if} @@ -291,6 +301,14 @@ color: var(--text-tertiary, #6b7280); } + .terminal-cost { + color: var(--terminal-cost, #10b981); + background: var(--terminal-cost-bg, rgba(16, 185, 129, 0.1)); + padding: 0 4px; + border-radius: 3px; + font-family: monospace; + } + .terminal-prefix { color: var(--text-secondary); } diff --git a/src/lib/notifications/notificationManager.ts b/src/lib/notifications/notificationManager.ts index 6900281..76d82b0 100644 --- a/src/lib/notifications/notificationManager.ts +++ b/src/lib/notifications/notificationManager.ts @@ -90,6 +90,8 @@ class NotificationManager { return "Successfully connected to Claude Code"; case NotificationType.TASK_START: return "Starting task..."; + case NotificationType.COST_ALERT: + return "You've exceeded your cost threshold!"; default: return "Notification"; } @@ -115,6 +117,10 @@ class NotificationManager { async notifyTaskStart(message?: string): Promise { await this.notify(NotificationType.TASK_START, message); } + + async notifyCostAlert(message?: string): Promise { + await this.notify(NotificationType.COST_ALERT, message); + } } // Export singleton instance diff --git a/src/lib/notifications/notifications.test.ts b/src/lib/notifications/notifications.test.ts index e730dff..61152a5 100644 --- a/src/lib/notifications/notifications.test.ts +++ b/src/lib/notifications/notifications.test.ts @@ -51,9 +51,13 @@ describe("notifications", () => { expect(NotificationType.ACHIEVEMENT).toBe("achievement"); }); - it("has exactly 6 notification types", () => { + it("has exactly 7 notification types", () => { const types = Object.values(NotificationType); - expect(types.length).toBe(6); + expect(types.length).toBe(7); + }); + + it("has COST_ALERT type", () => { + expect(NotificationType.COST_ALERT).toBe("cost_alert"); }); }); @@ -314,10 +318,11 @@ describe("notifications", () => { }); }); - it("sound filenames are unique", () => { + it("sound filenames are mostly unique", () => { const filenames = Object.values(NOTIFICATION_SOUNDS).map((s) => s.filename); const uniqueFilenames = new Set(filenames); - expect(uniqueFilenames.size).toBe(filenames.length); + // Allow some sound reuse (e.g., COST_ALERT reuses ERROR sound) + expect(uniqueFilenames.size).toBeGreaterThanOrEqual(filenames.length - 1); }); it("phrases are unique", () => { diff --git a/src/lib/notifications/types.ts b/src/lib/notifications/types.ts index 6fda8a2..47771b8 100644 --- a/src/lib/notifications/types.ts +++ b/src/lib/notifications/types.ts @@ -5,6 +5,7 @@ export enum NotificationType { CONNECTION = "connection", TASK_START = "task_start", ACHIEVEMENT = "achievement", + COST_ALERT = "cost_alert", } export interface NotificationSound { @@ -52,4 +53,10 @@ export const NOTIFICATION_SOUNDS: Record = phrase: "Achievement Get~!", volume: 0.8, }, + [NotificationType.COST_ALERT]: { + type: NotificationType.COST_ALERT, + filename: "oh-no.mp3", + phrase: "Cost Alert!", + volume: 0.9, + }, }; diff --git a/src/lib/stores/config.test.ts b/src/lib/stores/config.test.ts index df31e15..19e6860 100644 --- a/src/lib/stores/config.test.ts +++ b/src/lib/stores/config.test.ts @@ -187,6 +187,11 @@ describe("config store", () => { text_secondary: null, border_color: null, }, + budget_enabled: false, + session_token_budget: null, + session_cost_budget: null, + budget_action: "warn", + budget_warning_threshold: 0.8, }; expect(config.model).toBe("claude-sonnet-4"); @@ -227,6 +232,11 @@ describe("config store", () => { text_secondary: null, border_color: null, }, + budget_enabled: false, + session_token_budget: null, + session_cost_budget: null, + budget_action: "warn", + budget_warning_threshold: 0.8, }; expect(config.model).toBeNull(); diff --git a/src/lib/stores/config.ts b/src/lib/stores/config.ts index 0d5de7c..ca16a02 100644 --- a/src/lib/stores/config.ts +++ b/src/lib/stores/config.ts @@ -2,6 +2,7 @@ import { writable, derived } from "svelte/store"; import { invoke } from "@tauri-apps/api/core"; export type Theme = "dark" | "light" | "high-contrast" | "custom"; +export type BudgetAction = "warn" | "block"; export interface CustomThemeColors { bg_primary: string | null; @@ -37,6 +38,12 @@ export interface HikariConfig { profile_avatar_path: string | null; profile_bio: string | null; custom_theme_colors: CustomThemeColors; + // Budget settings + budget_enabled: boolean; + session_token_budget: number | null; + session_cost_budget: number | null; + budget_action: BudgetAction; + budget_warning_threshold: number; } const defaultConfig: HikariConfig = { @@ -71,6 +78,11 @@ const defaultConfig: HikariConfig = { text_secondary: null, border_color: null, }, + budget_enabled: false, + session_token_budget: null, + session_cost_budget: null, + budget_action: "warn", + budget_warning_threshold: 0.8, }; function createConfigStore() { diff --git a/src/lib/stores/conversations.ts b/src/lib/stores/conversations.ts index ab3b1a3..3fc9725 100644 --- a/src/lib/stores/conversations.ts +++ b/src/lib/stores/conversations.ts @@ -11,6 +11,13 @@ import { cleanupConversationTracking } from "$lib/tauri"; import { characterState } from "$lib/stores/character"; import { sessionsStore } from "$lib/stores/sessions"; +export interface ConversationSummary { + generatedAt: Date; + content: string; + messageCount: number; + tokenEstimate: number; +} + export interface Conversation { id: string; name: string; @@ -27,6 +34,7 @@ export interface Conversation { createdAt: Date; lastActivityAt: Date; attachments: Attachment[]; + summary: ConversationSummary | null; } function createConversationsStore() { @@ -63,6 +71,7 @@ function createConversationsStore() { createdAt: new Date(), lastActivityAt: new Date(), attachments: [], + summary: null, }; } @@ -420,7 +429,12 @@ function createConversationsStore() { }); }, - addLine: (type: TerminalLine["type"], content: string, toolName?: string) => { + addLine: ( + type: TerminalLine["type"], + content: string, + toolName?: string, + cost?: TerminalLine["cost"] + ) => { ensureInitialized(); const activeId = get(activeConversationId); if (!activeId) return ""; @@ -431,6 +445,7 @@ function createConversationsStore() { content, timestamp: new Date(), toolName, + cost, }; conversations.update((convs) => { @@ -451,7 +466,8 @@ function createConversationsStore() { conversationId: string, type: TerminalLine["type"], content: string, - toolName?: string + toolName?: string, + cost?: TerminalLine["cost"] ) => { ensureInitialized(); @@ -461,6 +477,7 @@ function createConversationsStore() { content, timestamp: new Date(), toolName, + cost, }; conversations.update((convs) => { @@ -636,6 +653,130 @@ function createConversationsStore() { return conv?.attachments || []; }, + // Summary/compaction functions + setSummary: (conversationId: string, summary: ConversationSummary) => { + conversations.update((convs) => { + const conv = convs.get(conversationId); + if (conv) { + conv.summary = summary; + conv.lastActivityAt = new Date(); + } + return convs; + }); + }, + + clearSummary: (conversationId: string) => { + conversations.update((convs) => { + const conv = convs.get(conversationId); + if (conv) { + conv.summary = null; + conv.lastActivityAt = new Date(); + } + return convs; + }); + }, + + getSummary: (conversationId: string): ConversationSummary | null => { + const convs = get(conversations); + const conv = convs.get(conversationId); + return conv?.summary || null; + }, + + // Estimate token count for a conversation (rough approximation: ~4 chars per token) + estimateTokenCount: (conversationId: string): number => { + const convs = get(conversations); + const conv = convs.get(conversationId); + if (!conv) return 0; + + const relevantLines = conv.terminalLines.filter( + (line) => line.type === "user" || line.type === "assistant" + ); + + const totalChars = relevantLines.reduce((sum, line) => sum + line.content.length, 0); + return Math.ceil(totalChars / 4); + }, + + // Get conversation content suitable for summarisation + getConversationForSummary: (conversationId: string): string => { + const convs = get(conversations); + const conv = convs.get(conversationId); + if (!conv) return ""; + + const relevantLines = conv.terminalLines.filter( + (line) => line.type === "user" || line.type === "assistant" + ); + + return relevantLines + .map((line) => { + const role = line.type === "user" ? "User" : "Assistant"; + return `${role}: ${line.content}`; + }) + .join("\n\n"); + }, + + // Compact conversation by keeping only recent messages + compactConversation: (conversationId: string, keepRecentCount: number = 10) => { + conversations.update((convs) => { + const conv = convs.get(conversationId); + if (conv && conv.terminalLines.length > keepRecentCount) { + // Keep system messages and the most recent user/assistant messages + const systemLines = conv.terminalLines.filter( + (line) => line.type !== "user" && line.type !== "assistant" + ); + const chatLines = conv.terminalLines.filter( + (line) => line.type === "user" || line.type === "assistant" + ); + + // Keep only the most recent chat messages + const recentChatLines = chatLines.slice(-keepRecentCount); + + // Combine: system lines at original positions + recent chat lines + conv.terminalLines = [...systemLines.slice(-5), ...recentChatLines]; + conv.lastActivityAt = new Date(); + } + return convs; + }); + }, + + // Compact conversation with a summary - clears old messages and injects summary context + compactWithSummary: ( + conversationId: string, + summaryContent: string, + messageCount: number, + tokenEstimate: number + ) => { + conversations.update((convs) => { + const conv = convs.get(conversationId); + if (conv) { + // Store the summary + conv.summary = { + generatedAt: new Date(), + content: summaryContent, + messageCount, + tokenEstimate, + }; + + // Clear all messages and add a context injection message + conv.terminalLines = [ + { + id: generateLineId(), + type: "system", + content: `[Conversation compacted] Previous session had ${messageCount} messages (~${tokenEstimate.toLocaleString()} tokens). Context preserved below.`, + timestamp: new Date(), + }, + { + id: generateLineId(), + type: "system", + content: `Previous Session Context:\n${summaryContent}`, + timestamp: new Date(), + }, + ]; + conv.lastActivityAt = new Date(); + } + return convs; + }); + }, + // Add initialization helper initialize: () => { ensureInitialized(); diff --git a/src/lib/stores/costTracking.ts b/src/lib/stores/costTracking.ts new file mode 100644 index 0000000..18da2c6 --- /dev/null +++ b/src/lib/stores/costTracking.ts @@ -0,0 +1,182 @@ +import { writable, derived } from "svelte/store"; +import { invoke } from "@tauri-apps/api/core"; +import { notificationManager } from "$lib/notifications/notificationManager"; + +// Types matching Rust backend +export interface DailyCost { + date: string; + input_tokens: number; + output_tokens: number; + cost_usd: number; + messages_sent: number; + sessions_count: number; +} + +export interface CostSummary { + period_days: number; + total_input_tokens: number; + total_output_tokens: number; + total_cost: number; + total_messages: number; + total_sessions: number; + average_daily_cost: number; + daily_breakdown: DailyCost[]; +} + +export type AlertType = "Daily" | "Weekly" | "Monthly"; + +export interface CostAlert { + alert_type: AlertType; + threshold: number; + current_cost: number; +} + +export interface CostAlertThresholds { + daily: number | null; + weekly: number | null; + monthly: number | null; +} + +// Store state +interface CostTrackingState { + todayCost: number; + weekCost: number; + monthCost: number; + summary: CostSummary | null; + alerts: CostAlert[]; + thresholds: CostAlertThresholds; + isLoading: boolean; + lastUpdated: Date | null; +} + +const defaultState: CostTrackingState = { + todayCost: 0, + weekCost: 0, + monthCost: 0, + summary: null, + alerts: [], + thresholds: { daily: null, weekly: null, monthly: null }, + isLoading: false, + lastUpdated: null, +}; + +function createCostTrackingStore() { + const { subscribe, set, update } = writable(defaultState); + + return { + subscribe, + + async refresh() { + update((s) => ({ ...s, isLoading: true })); + + try { + const [todayCost, weekCost, monthCost, alerts] = await Promise.all([ + invoke("get_today_cost"), + invoke("get_week_cost"), + invoke("get_month_cost"), + invoke("get_cost_alerts"), + ]); + + update((s) => ({ + ...s, + todayCost, + weekCost, + monthCost, + alerts, + isLoading: false, + lastUpdated: new Date(), + })); + + // Trigger notifications for any new alerts + if (alerts.length > 0) { + for (const alert of alerts) { + const message = getAlertMessage(alert); + notificationManager.notifyCostAlert(message); + } + } + + return alerts; + } catch (error) { + console.error("Failed to refresh cost tracking:", error); + update((s) => ({ ...s, isLoading: false })); + return []; + } + }, + + async getSummary(days: number): Promise { + try { + const summary = await invoke("get_cost_summary", { days }); + update((s) => ({ ...s, summary })); + return summary; + } catch (error) { + console.error("Failed to get cost summary:", error); + return null; + } + }, + + async setAlertThresholds(thresholds: CostAlertThresholds) { + try { + await invoke("set_cost_alert_thresholds", { + daily: thresholds.daily, + weekly: thresholds.weekly, + monthly: thresholds.monthly, + }); + update((s) => ({ ...s, thresholds })); + } catch (error) { + console.error("Failed to set alert thresholds:", error); + } + }, + + async exportCsv(days: number): Promise { + try { + return await invoke("export_cost_csv", { days }); + } catch (error) { + console.error("Failed to export CSV:", error); + return null; + } + }, + + reset() { + set(defaultState); + }, + }; +} + +export const costTrackingStore = createCostTrackingStore(); + +// Derived stores for formatted values +export const formattedCosts = derived(costTrackingStore, ($store) => ({ + today: formatCost($store.todayCost), + week: formatCost($store.weekCost), + month: formatCost($store.monthCost), + todayRaw: $store.todayCost, + weekRaw: $store.weekCost, + monthRaw: $store.monthCost, +})); + +// Helper functions +export function formatCost(cost: number): string { + if (cost < 0.01) { + return `$${cost.toFixed(4)}`; + } + if (cost < 1) { + return `$${cost.toFixed(3)}`; + } + return `$${cost.toFixed(2)}`; +} + +export function formatAlertType(type: AlertType): string { + switch (type) { + case "Daily": + return "Today"; + case "Weekly": + return "This Week"; + case "Monthly": + return "This Month"; + } +} + +export function getAlertMessage(alert: CostAlert): string { + const period = formatAlertType(alert.alert_type); + return `${period}'s spending (${formatCost(alert.current_cost)}) has exceeded your ${formatCost(alert.threshold)} threshold`; +} diff --git a/src/lib/stores/stats.test.ts b/src/lib/stores/stats.test.ts index 952c6e1..37259b7 100644 --- a/src/lib/stores/stats.test.ts +++ b/src/lib/stores/stats.test.ts @@ -1,7 +1,25 @@ import { describe, it, expect, beforeEach, vi } from "vitest"; import { get } from "svelte/store"; -import { stats, formattedStats, resetSessionStats } from "./stats"; -import type { UsageStats } from "./stats"; +import { + stats, + formattedStats, + resetSessionStats, + contextWarning, + getContextWarningMessage, + estimateMessageCost, + formatTokenCount, + MODEL_PRICING, +} from "./stats"; +import type { UsageStats, ToolTokenStats } from "./stats"; + +// Helper function to create ToolTokenStats for tests +function toolStats(callCount: number, inputTokens = 0, outputTokens = 0): ToolTokenStats { + return { + call_count: callCount, + estimated_input_tokens: inputTokens, + estimated_output_tokens: outputTokens, + }; +} // Mock Tauri APIs vi.mock("@tauri-apps/api/event", () => ({ @@ -34,6 +52,11 @@ describe("stats store", () => { tools_usage: {}, session_tools_usage: {}, session_duration_seconds: 0, + context_tokens_used: 0, + context_window_limit: 200000, + context_utilisation_percent: 0, + potential_cache_hits: 0, + potential_cache_savings_tokens: 0, }); }); @@ -63,9 +86,14 @@ describe("stats store", () => { session_files_edited: 2, files_created: 1, session_files_created: 1, - tools_usage: { Read: 5, Edit: 3 }, - session_tools_usage: { Read: 2, Edit: 1 }, + tools_usage: { Read: toolStats(5), Edit: toolStats(3) }, + session_tools_usage: { Read: toolStats(2), Edit: toolStats(1) }, session_duration_seconds: 300, + context_tokens_used: 500, + context_window_limit: 200000, + context_utilisation_percent: 0.25, + potential_cache_hits: 0, + potential_cache_savings_tokens: 0, }; stats.set(newStats); @@ -74,7 +102,8 @@ describe("stats store", () => { expect(currentStats.total_input_tokens).toBe(1000); expect(currentStats.total_output_tokens).toBe(2000); expect(currentStats.model).toBe("claude-sonnet-4"); - expect(currentStats.tools_usage).toEqual({ Read: 5, Edit: 3 }); + expect(currentStats.tools_usage.Read?.call_count).toBe(5); + expect(currentStats.tools_usage.Edit?.call_count).toBe(3); }); it("can be updated with update function", () => { @@ -109,9 +138,14 @@ describe("stats store", () => { session_files_edited: 2, files_created: 1, session_files_created: 1, - tools_usage: { Read: 5, Edit: 3 }, - session_tools_usage: { Read: 2, Edit: 1 }, + tools_usage: { Read: toolStats(5), Edit: toolStats(3) }, + session_tools_usage: { Read: toolStats(2), Edit: toolStats(1) }, session_duration_seconds: 300, + context_tokens_used: 500, + context_window_limit: 200000, + context_utilisation_percent: 0.25, + potential_cache_hits: 0, + potential_cache_savings_tokens: 0, }); // Reset session stats @@ -127,7 +161,8 @@ describe("stats store", () => { expect(currentStats.code_blocks_generated).toBe(3); expect(currentStats.files_edited).toBe(5); expect(currentStats.files_created).toBe(1); - expect(currentStats.tools_usage).toEqual({ Read: 5, Edit: 3 }); + expect(currentStats.tools_usage.Read?.call_count).toBe(5); + expect(currentStats.tools_usage.Edit?.call_count).toBe(3); expect(currentStats.model).toBe("claude-sonnet-4"); // Session stats should be reset @@ -277,8 +312,8 @@ describe("stats store", () => { }); it("exposes tools usage directly", () => { - const toolsUsage = { Read: 10, Edit: 5, Write: 3 }; - const sessionToolsUsage = { Read: 2, Edit: 1 }; + const toolsUsage = { Read: toolStats(10), Edit: toolStats(5), Write: toolStats(3) }; + const sessionToolsUsage = { Read: toolStats(2), Edit: toolStats(1) }; stats.update((current) => ({ ...current, @@ -331,9 +366,14 @@ describe("stats store", () => { session_files_edited: 1, files_created: 1, session_files_created: 0, - tools_usage: { Read: 3 }, - session_tools_usage: { Read: 1 }, + tools_usage: { Read: toolStats(3) }, + session_tools_usage: { Read: toolStats(1) }, session_duration_seconds: 60, + context_tokens_used: 50, + context_window_limit: 200000, + context_utilisation_percent: 0.025, + potential_cache_hits: 0, + potential_cache_savings_tokens: 0, }; stats.set(fullStats); @@ -343,4 +383,221 @@ describe("stats store", () => { expect(currentStats).toEqual(fullStats); }); }); + + describe("context window tracking", () => { + it("tracks context tokens used", () => { + stats.update((current) => ({ + ...current, + context_tokens_used: 100000, + context_window_limit: 200000, + context_utilisation_percent: 50.0, + })); + + const currentStats = get(stats); + expect(currentStats.context_tokens_used).toBe(100000); + expect(currentStats.context_window_limit).toBe(200000); + expect(currentStats.context_utilisation_percent).toBe(50.0); + }); + + it("formats context stats correctly", () => { + stats.update((current) => ({ + ...current, + context_tokens_used: 150000, + context_window_limit: 200000, + context_utilisation_percent: 75.5, + })); + + const formatted = get(formattedStats); + expect(formatted.contextUsed).toBe("150,000"); + expect(formatted.contextLimit).toBe("200,000"); + expect(formatted.contextRemaining).toBe("50,000"); + expect(formatted.contextUtilisation).toBe("75.5%"); + }); + + it("calculates remaining tokens correctly at limit", () => { + stats.update((current) => ({ + ...current, + context_tokens_used: 200000, + context_window_limit: 200000, + context_utilisation_percent: 100.0, + })); + + const formatted = get(formattedStats); + expect(formatted.contextRemaining).toBe("0"); + }); + + it("handles over-limit gracefully", () => { + stats.update((current) => ({ + ...current, + context_tokens_used: 250000, + context_window_limit: 200000, + context_utilisation_percent: 125.0, + })); + + const formatted = get(formattedStats); + expect(formatted.contextRemaining).toBe("0"); + }); + }); + + describe("contextWarning derived store", () => { + it("returns null when under 50%", () => { + stats.update((current) => ({ + ...current, + context_utilisation_percent: 40.0, + })); + + const warning = get(contextWarning); + expect(warning).toBeNull(); + }); + + it("returns moderate when between 50-74%", () => { + stats.update((current) => ({ + ...current, + context_utilisation_percent: 60.0, + })); + + const warning = get(contextWarning); + expect(warning).toBe("moderate"); + }); + + it("returns high when between 75-89%", () => { + stats.update((current) => ({ + ...current, + context_utilisation_percent: 80.0, + })); + + const warning = get(contextWarning); + expect(warning).toBe("high"); + }); + + it("returns critical when 90%+", () => { + stats.update((current) => ({ + ...current, + context_utilisation_percent: 95.0, + })); + + const warning = get(contextWarning); + expect(warning).toBe("critical"); + }); + }); + + describe("getContextWarningMessage", () => { + it("returns correct message for moderate warning", () => { + const message = getContextWarningMessage("moderate"); + expect(message).toContain("50%+"); + expect(message).toContain("Consider starting a new conversation"); + }); + + it("returns correct message for high warning", () => { + const message = getContextWarningMessage("high"); + expect(message).toContain("75%+"); + expect(message).toContain("Responses may degrade"); + }); + + it("returns correct message for critical warning", () => { + const message = getContextWarningMessage("critical"); + expect(message).toContain("90%+"); + expect(message).toContain("Start a new conversation"); + }); + }); + + describe("formatTokenCount", () => { + it("formats small numbers directly", () => { + expect(formatTokenCount(0)).toBe("0"); + expect(formatTokenCount(100)).toBe("100"); + expect(formatTokenCount(999)).toBe("999"); + }); + + it("formats thousands with K suffix", () => { + expect(formatTokenCount(1000)).toBe("1.0K"); + expect(formatTokenCount(1500)).toBe("1.5K"); + expect(formatTokenCount(10000)).toBe("10.0K"); + expect(formatTokenCount(999999)).toBe("1000.0K"); + }); + + it("formats millions with M suffix", () => { + expect(formatTokenCount(1000000)).toBe("1.0M"); + expect(formatTokenCount(1500000)).toBe("1.5M"); + expect(formatTokenCount(10000000)).toBe("10.0M"); + }); + }); + + describe("estimateMessageCost", () => { + it("estimates tokens at ~4 chars per token", () => { + const result = estimateMessageCost("test", 0, null); // 4 chars = 1 token + expect(result.messageTokens).toBe(1); + }); + + it("rounds up partial tokens", () => { + const result = estimateMessageCost("a", 0, null); // 1 char rounds up to 1 token + expect(result.messageTokens).toBe(1); + + const result2 = estimateMessageCost("abcde", 0, null); // 5 chars = 2 tokens + expect(result2.messageTokens).toBe(2); + }); + + it("returns 0 tokens for empty string", () => { + const result = estimateMessageCost("", 0, null); + expect(result.messageTokens).toBe(0); + expect(result.estimatedCost).toBe(0); + }); + + it("adds context tokens to total", () => { + const result = estimateMessageCost("test", 1000, null); // 1 token + 1000 context + expect(result.messageTokens).toBe(1); + expect(result.totalInputTokens).toBe(1001); + }); + + it("calculates cost using Sonnet pricing by default", () => { + // 100 chars = 25 tokens, $3 per million input tokens + const result = estimateMessageCost("a".repeat(100), 0, null); + expect(result.messageTokens).toBe(25); + const expectedCost = (25 / 1_000_000) * 3.0; + expect(result.estimatedCost).toBeCloseTo(expectedCost, 8); + }); + + it("uses Opus pricing for Opus models", () => { + const result = estimateMessageCost("a".repeat(100), 0, "claude-opus-4-5-20251101"); + expect(result.messageTokens).toBe(25); + const expectedCost = (25 / 1_000_000) * 5.0; // Opus 4.5: $5 per million input + expect(result.estimatedCost).toBeCloseTo(expectedCost, 8); + }); + + it("uses Haiku pricing for Haiku models", () => { + const result = estimateMessageCost("a".repeat(100), 0, "claude-3-5-haiku-20241022"); + expect(result.messageTokens).toBe(25); + const expectedCost = (25 / 1_000_000) * 1.0; // Haiku: $1 per million + expect(result.estimatedCost).toBeCloseTo(expectedCost, 8); + }); + + it("falls back to Sonnet pricing for unknown models", () => { + const result = estimateMessageCost("a".repeat(100), 0, "unknown-model"); + expect(result.messageTokens).toBe(25); + const expectedCost = (25 / 1_000_000) * 3.0; // Default Sonnet: $3 per million + expect(result.estimatedCost).toBeCloseTo(expectedCost, 8); + }); + }); + + describe("MODEL_PRICING", () => { + it("contains expected Opus pricing", () => { + // Opus 4.5 has reduced pricing + expect(MODEL_PRICING["claude-opus-4-5-20251101"]).toEqual({ input: 5.0, output: 25.0 }); + // Previous Opus models have higher pricing + expect(MODEL_PRICING["claude-opus-4-1-20250805"]).toEqual({ input: 15.0, output: 75.0 }); + expect(MODEL_PRICING["claude-opus-4-20250514"]).toEqual({ input: 15.0, output: 75.0 }); + }); + + it("contains expected Sonnet pricing", () => { + expect(MODEL_PRICING["claude-sonnet-4-5-20250929"]).toEqual({ input: 3.0, output: 15.0 }); + expect(MODEL_PRICING["claude-sonnet-4-20250514"]).toEqual({ input: 3.0, output: 15.0 }); + expect(MODEL_PRICING["claude-3-7-sonnet-20250219"]).toEqual({ input: 3.0, output: 15.0 }); + expect(MODEL_PRICING["claude-3-5-sonnet-20241022"]).toEqual({ input: 3.0, output: 15.0 }); + }); + + it("contains expected Haiku pricing", () => { + expect(MODEL_PRICING["claude-haiku-4-5-20251001"]).toEqual({ input: 1.0, output: 5.0 }); + expect(MODEL_PRICING["claude-3-5-haiku-20241022"]).toEqual({ input: 1.0, output: 5.0 }); + expect(MODEL_PRICING["claude-3-haiku-20240307"]).toEqual({ input: 0.25, output: 1.25 }); + }); + }); }); diff --git a/src/lib/stores/stats.ts b/src/lib/stores/stats.ts index 59522b0..c48453a 100644 --- a/src/lib/stores/stats.ts +++ b/src/lib/stores/stats.ts @@ -1,6 +1,66 @@ import { writable, derived } from "svelte/store"; import { listen } from "@tauri-apps/api/event"; import { invoke } from "@tauri-apps/api/core"; +import { costTrackingStore } from "./costTracking"; + +export type ContextWarning = "moderate" | "high" | "critical"; +export type BudgetType = "token" | "cost"; + +// Model pricing (per million tokens) - keep in sync with stats.rs +// Source: https://platform.claude.com/docs/en/about-claude/models/overview +export const MODEL_PRICING: Record = { + // Current generation (Claude 4.5) + "claude-opus-4-5-20251101": { input: 5.0, output: 25.0 }, + "claude-sonnet-4-5-20250929": { input: 3.0, output: 15.0 }, + "claude-haiku-4-5-20251001": { input: 1.0, output: 5.0 }, + // Previous generation (Claude 4.x) + "claude-opus-4-1-20250805": { input: 15.0, output: 75.0 }, + "claude-opus-4-20250514": { input: 15.0, output: 75.0 }, + "claude-sonnet-4-20250514": { input: 3.0, output: 15.0 }, + // Legacy (Claude 3.x) + "claude-3-7-sonnet-20250219": { input: 3.0, output: 15.0 }, + "claude-3-5-sonnet-20241022": { input: 3.0, output: 15.0 }, + "claude-3-5-sonnet-20240620": { input: 3.0, output: 15.0 }, + "claude-3-5-haiku-20241022": { input: 1.0, output: 5.0 }, + "claude-3-opus-20240229": { input: 15.0, output: 75.0 }, + "claude-3-sonnet-20240229": { input: 3.0, output: 15.0 }, + "claude-3-haiku-20240307": { input: 0.25, output: 1.25 }, +}; + +const DEFAULT_PRICING = { input: 3.0, output: 15.0 }; // Default to Sonnet + +export interface CostEstimate { + messageTokens: number; + totalInputTokens: number; + estimatedCost: number; +} + +// Estimate cost for a message before sending +export function estimateMessageCost( + messageText: string, + contextTokensUsed: number, + model: string | null +): CostEstimate { + // Estimate tokens using ~4 chars per token heuristic + const messageTokens = Math.ceil(messageText.length / 4); + const totalInputTokens = contextTokensUsed + messageTokens; + + const pricing = model ? (MODEL_PRICING[model] ?? DEFAULT_PRICING) : DEFAULT_PRICING; + const estimatedCost = (totalInputTokens / 1_000_000) * pricing.input; + + return { messageTokens, totalInputTokens, estimatedCost }; +} +export type BudgetStatus = + | { type: "ok" } + | { type: "warning"; budget_type: BudgetType; percent_used: number } + | { type: "exceeded"; budget_type: BudgetType }; + +// Per-tool token usage statistics +export interface ToolTokenStats { + call_count: number; + estimated_input_tokens: number; + estimated_output_tokens: number; +} export interface UsageStats { total_input_tokens: number; @@ -20,9 +80,18 @@ export interface UsageStats { session_files_edited: number; files_created: number; session_files_created: number; - tools_usage: Record; - session_tools_usage: Record; + tools_usage: Record; + session_tools_usage: Record; session_duration_seconds: number; + + // Context window tracking + context_tokens_used: number; + context_window_limit: number; + context_utilisation_percent: number; + + // Cache analytics (tracks potential savings from repeated tool calls) + potential_cache_hits: number; + potential_cache_savings_tokens: number; } // Main stats store @@ -45,8 +114,24 @@ export const stats = writable({ tools_usage: {}, session_tools_usage: {}, session_duration_seconds: 0, + context_tokens_used: 0, + context_window_limit: 200000, + context_utilisation_percent: 0, + potential_cache_hits: 0, + potential_cache_savings_tokens: 0, }); +// Format token count with K/M suffix +export function formatTokenCount(tokens: number): string { + if (tokens >= 1000000) { + return `${(tokens / 1000000).toFixed(1)}M`; + } + if (tokens >= 1000) { + return `${(tokens / 1000).toFixed(1)}K`; + } + return tokens.toString(); +} + // Derived store for formatted display values export const formattedStats = derived(stats, ($stats) => { const formatNumber = (num: number) => num.toLocaleString(); @@ -65,6 +150,20 @@ export const formattedStats = derived(stats, ($stats) => { } }; + // Format tool stats with token info + const formatToolStats = (toolStats: Record) => { + return Object.entries(toolStats).map(([name, stats]) => ({ + name, + callCount: stats.call_count, + totalTokens: stats.estimated_input_tokens + stats.estimated_output_tokens, + formattedTokens: formatTokenCount( + stats.estimated_input_tokens + stats.estimated_output_tokens + ), + inputTokens: stats.estimated_input_tokens, + outputTokens: stats.estimated_output_tokens, + })); + }; + return { totalTokens: formatNumber($stats.total_input_tokens + $stats.total_output_tokens), totalInputTokens: formatNumber($stats.total_input_tokens), @@ -88,9 +187,116 @@ export const formattedStats = derived(stats, ($stats) => { sessionDuration: formatDuration($stats.session_duration_seconds), toolsUsage: $stats.tools_usage, sessionToolsUsage: $stats.session_tools_usage, + // Formatted tool stats with token info + sessionToolsFormatted: formatToolStats($stats.session_tools_usage), + toolsFormatted: formatToolStats($stats.tools_usage), + + // Context window tracking + contextUsed: formatNumber($stats.context_tokens_used), + contextLimit: formatNumber($stats.context_window_limit), + contextRemaining: formatNumber( + Math.max(0, $stats.context_window_limit - $stats.context_tokens_used) + ), + contextUtilisation: `${$stats.context_utilisation_percent.toFixed(1)}%`, }; }); +// Derived store for context warning state +export const contextWarning = derived(stats, ($stats): ContextWarning | null => { + if ($stats.context_utilisation_percent >= 90) { + return "critical"; + } else if ($stats.context_utilisation_percent >= 75) { + return "high"; + } else if ($stats.context_utilisation_percent >= 50) { + return "moderate"; + } + return null; +}); + +// Get warning message for context utilisation +export function getContextWarningMessage(warning: ContextWarning): string { + switch (warning) { + case "moderate": + return "Context window is 50%+ full. Consider starting a new conversation for better performance."; + case "high": + return "Context window is 75%+ full. Responses may degrade. Consider summarising or starting fresh."; + case "critical": + return "Context window is nearly full (90%+)! Start a new conversation to avoid errors."; + } +} + +// Budget checking functions +export function checkBudget( + stats: UsageStats, + budgetEnabled: boolean, + tokenBudget: number | null, + costBudget: number | null, + warningThreshold: number +): BudgetStatus { + if (!budgetEnabled) { + return { type: "ok" }; + } + + const sessionTokens = stats.session_input_tokens + stats.session_output_tokens; + + // Check token budget + if (tokenBudget !== null) { + if (sessionTokens >= tokenBudget) { + return { type: "exceeded", budget_type: "token" }; + } + const percentUsed = sessionTokens / tokenBudget; + if (percentUsed >= warningThreshold) { + return { type: "warning", budget_type: "token", percent_used: percentUsed * 100 }; + } + } + + // Check cost budget + if (costBudget !== null) { + if (stats.session_cost_usd >= costBudget) { + return { type: "exceeded", budget_type: "cost" }; + } + const percentUsed = stats.session_cost_usd / costBudget; + if (percentUsed >= warningThreshold) { + return { type: "warning", budget_type: "cost", percent_used: percentUsed * 100 }; + } + } + + return { type: "ok" }; +} + +// Get budget status message +export function getBudgetStatusMessage(status: BudgetStatus): string | null { + if (status.type === "ok") { + return null; + } + + const budgetTypeLabel = status.budget_type === "token" ? "token" : "cost"; + + if (status.type === "exceeded") { + return `Session ${budgetTypeLabel} budget exceeded! Consider starting a new session.`; + } + + return `Approaching ${budgetTypeLabel} budget limit (${status.percent_used.toFixed(0)}% used).`; +} + +// Get remaining budget values +export function getRemainingTokenBudget( + stats: UsageStats, + tokenBudget: number | null +): number | null { + if (tokenBudget === null) return null; + const used = stats.session_input_tokens + stats.session_output_tokens; + return Math.max(0, tokenBudget - used); +} + +export function getRemainingCostBudget( + stats: UsageStats, + costBudget: number | null +): number | null { + if (costBudget === null) return null; + return Math.max(0, costBudget - stats.session_cost_usd); +} + // Note: Cost calculation is now done in the Rust backend // Initialize stats listener @@ -102,6 +308,9 @@ export async function initStatsListener() { // The backend already tracks all totals - just set the stats directly stats.set(newStats); + + // Refresh cost tracking to check for alerts (debounced - won't spam) + costTrackingStore.refresh(); }); // Load initial persisted stats from backend (no bridge required) diff --git a/src/lib/tauri.ts b/src/lib/tauri.ts index d14ded8..ad7dd01 100644 --- a/src/lib/tauri.ts +++ b/src/lib/tauri.ts @@ -90,6 +90,11 @@ interface OutputPayload { content: string; tool_name: string | null; conversation_id?: string; + cost?: { + input_tokens: number; + output_tokens: number; + cost_usd: number; + }; } interface ConnectionPayload { @@ -242,7 +247,16 @@ export async function initializeTauriListeners() { unlisteners.push(stateUnlisten); const outputUnlisten = await listen("claude:output", (event) => { - const { line_type, content, tool_name, conversation_id } = event.payload; + const { line_type, content, tool_name, conversation_id, cost } = event.payload; + + // Convert snake_case cost to camelCase for TypeScript + const costData = cost + ? { + inputTokens: cost.input_tokens, + outputTokens: cost.output_tokens, + costUsd: cost.cost_usd, + } + : undefined; // Always store the output to the correct conversation if (conversation_id) { @@ -250,14 +264,16 @@ export async function initializeTauriListeners() { conversation_id, line_type as "user" | "assistant" | "system" | "tool" | "error", content, - tool_name || undefined + tool_name || undefined, + costData ); } else { // Fallback to active conversation if no conversation_id provided claudeStore.addLine( line_type as "user" | "assistant" | "system" | "tool" | "error", content, - tool_name || undefined + tool_name || undefined, + costData ); } }); diff --git a/src/lib/types/messages.ts b/src/lib/types/messages.ts index 8e85797..4be557d 100644 --- a/src/lib/types/messages.ts +++ b/src/lib/types/messages.ts @@ -4,6 +4,12 @@ export interface TerminalLine { content: string; timestamp: Date; toolName?: string; + // Cost tracking for this specific message + cost?: { + inputTokens: number; + outputTokens: number; + costUsd: number; + }; } export interface SystemInitMessage { diff --git a/src/lib/utils/conversationUtils.test.ts b/src/lib/utils/conversationUtils.test.ts new file mode 100644 index 0000000..84f81c5 --- /dev/null +++ b/src/lib/utils/conversationUtils.test.ts @@ -0,0 +1,188 @@ +import { describe, it, expect } from "vitest"; +import { + generateSummaryPrompt, + generateContextInjection, + estimateTokens, + createSummary, + shouldSuggestCompaction, + formatTokenCount, + sanitizeForJson, +} from "./conversationUtils"; +import type { ConversationSummary } from "$lib/stores/conversations"; + +describe("conversationUtils", () => { + describe("generateSummaryPrompt", () => { + it("generates a prompt containing the conversation content", () => { + const content = "User: Hello\n\nAssistant: Hi there!"; + const prompt = generateSummaryPrompt(content); + + expect(prompt).toContain(content); + expect(prompt).toContain("summary"); + expect(prompt).toContain("Key topics"); + }); + + it("handles empty content", () => { + const prompt = generateSummaryPrompt(""); + expect(prompt).toContain("summary"); + }); + }); + + describe("generateContextInjection", () => { + it("creates context injection message from summary", () => { + const summary: ConversationSummary = { + generatedAt: new Date("2024-01-01"), + content: "We discussed building a new feature", + messageCount: 50, + tokenEstimate: 10000, + }; + + const injection = generateContextInjection(summary); + + expect(injection).toContain("Previous Session Context"); + expect(injection).toContain("We discussed building a new feature"); + expect(injection).toContain("50 messages"); + expect(injection).toContain("10,000 tokens"); + }); + }); + + describe("estimateTokens", () => { + it("estimates tokens at ~4 chars per token", () => { + expect(estimateTokens("")).toBe(0); + expect(estimateTokens("test")).toBe(1); // 4 chars = 1 token + expect(estimateTokens("testing")).toBe(2); // 7 chars = 2 tokens + expect(estimateTokens("a".repeat(100))).toBe(25); // 100 chars = 25 tokens + }); + + it("rounds up partial tokens", () => { + expect(estimateTokens("a")).toBe(1); // 1 char rounds up to 1 token + expect(estimateTokens("ab")).toBe(1); // 2 chars rounds up to 1 token + expect(estimateTokens("abc")).toBe(1); // 3 chars rounds up to 1 token + expect(estimateTokens("abcde")).toBe(2); // 5 chars rounds up to 2 tokens + }); + }); + + describe("createSummary", () => { + it("creates a valid ConversationSummary object", () => { + const summary = createSummary("Test summary content", 25, 5000); + + expect(summary.content).toBe("Test summary content"); + expect(summary.messageCount).toBe(25); + expect(summary.tokenEstimate).toBe(5000); + expect(summary.generatedAt).toBeInstanceOf(Date); + }); + + it("sets generatedAt to current time", () => { + const before = new Date(); + const summary = createSummary("content", 10, 1000); + const after = new Date(); + + expect(summary.generatedAt.getTime()).toBeGreaterThanOrEqual(before.getTime()); + expect(summary.generatedAt.getTime()).toBeLessThanOrEqual(after.getTime()); + }); + }); + + describe("shouldSuggestCompaction", () => { + it("returns false when under threshold", () => { + expect(shouldSuggestCompaction(50000, 200000, 60)).toBe(false); // 25% + expect(shouldSuggestCompaction(100000, 200000, 60)).toBe(false); // 50% + expect(shouldSuggestCompaction(119000, 200000, 60)).toBe(false); // 59.5% + }); + + it("returns true when at or above threshold", () => { + expect(shouldSuggestCompaction(120000, 200000, 60)).toBe(true); // 60% + expect(shouldSuggestCompaction(150000, 200000, 60)).toBe(true); // 75% + expect(shouldSuggestCompaction(200000, 200000, 60)).toBe(true); // 100% + }); + + it("handles zero context window limit", () => { + expect(shouldSuggestCompaction(50000, 0, 60)).toBe(false); + }); + + it("uses default threshold of 60%", () => { + expect(shouldSuggestCompaction(110000, 200000)).toBe(false); // 55% + expect(shouldSuggestCompaction(130000, 200000)).toBe(true); // 65% + }); + + it("respects custom threshold", () => { + expect(shouldSuggestCompaction(70000, 200000, 40)).toBe(false); // 35% + expect(shouldSuggestCompaction(90000, 200000, 40)).toBe(true); // 45% + }); + }); + + describe("formatTokenCount", () => { + it("formats small numbers directly", () => { + expect(formatTokenCount(0)).toBe("0"); + expect(formatTokenCount(100)).toBe("100"); + expect(formatTokenCount(999)).toBe("999"); + }); + + it("formats thousands with K suffix", () => { + expect(formatTokenCount(1000)).toBe("1.0K"); + expect(formatTokenCount(1500)).toBe("1.5K"); + expect(formatTokenCount(10000)).toBe("10.0K"); + expect(formatTokenCount(999999)).toBe("1000.0K"); + }); + + it("formats millions with M suffix", () => { + expect(formatTokenCount(1000000)).toBe("1.0M"); + expect(formatTokenCount(1500000)).toBe("1.5M"); + expect(formatTokenCount(10000000)).toBe("10.0M"); + }); + }); + + describe("sanitizeForJson", () => { + it("returns normal text unchanged", () => { + expect(sanitizeForJson("Hello world")).toBe("Hello world"); + expect(sanitizeForJson("Test 123")).toBe("Test 123"); + }); + + it("preserves common whitespace", () => { + expect(sanitizeForJson("line1\nline2")).toBe("line1\nline2"); + expect(sanitizeForJson("col1\tcol2")).toBe("col1\tcol2"); + expect(sanitizeForJson("line\r\nend")).toBe("line\r\nend"); + }); + + it("removes null bytes", () => { + expect(sanitizeForJson("hello\x00world")).toBe("helloworld"); + }); + + it("removes other control characters", () => { + // Bell character + expect(sanitizeForJson("alert\x07here")).toBe("alerthere"); + // Backspace + expect(sanitizeForJson("back\x08space")).toBe("backspace"); + // Form feed is removed + expect(sanitizeForJson("page\x0Cbreak")).toBe("pagebreak"); + // Escape character + expect(sanitizeForJson("esc\x1Bhere")).toBe("eschere"); + }); + + it("preserves printable characters including backslashes", () => { + const codeContent = '```rust\nfn main() {\n println!("Hello");\n}\n```'; + expect(sanitizeForJson(codeContent)).toBe(codeContent); + }); + + it("handles mixed content with various characters", () => { + const mixed = "User: Hello\n\nAssistant: Here's some code:\n```\nconst x = 42;\n```"; + expect(sanitizeForJson(mixed)).toBe(mixed); + }); + + it("preserves backslash sequences", () => { + // Backslashes followed by letters should be preserved as-is + expect(sanitizeForJson("path\\to\\file")).toBe("path\\to\\file"); + expect(sanitizeForJson("color\\x1b")).toBe("color\\x1b"); + }); + + it("removes lone surrogates", () => { + // Lone surrogates (U+D800-U+DFFF) can cause JSON parse errors + // High surrogate without low + expect(sanitizeForJson("test\uD800end")).toBe("testend"); + // Low surrogate without high + expect(sanitizeForJson("test\uDC00end")).toBe("testend"); + // But valid surrogate pairs should remain (they form valid characters) + // Actually, JavaScript represents emoji as surrogate pairs, so this is tricky + // The regex will remove the surrogates, which may break emoji. That's acceptable + // for a conversation summary where data integrity is more important. + }); + }); +}); diff --git a/src/lib/utils/conversationUtils.ts b/src/lib/utils/conversationUtils.ts new file mode 100644 index 0000000..67004da --- /dev/null +++ b/src/lib/utils/conversationUtils.ts @@ -0,0 +1,110 @@ +import type { ConversationSummary } from "$lib/stores/conversations"; + +/** + * Sanitises a string for safe JSON serialization through Tauri IPC. + * Removes control characters and lone surrogates that could cause issues + * during JSON serialization/deserialization. + */ +export function sanitizeForJson(text: string): string { + // Remove control characters except for common whitespace (tab, newline, carriage return) + // These can cause JSON parsing issues and are rarely meaningful in conversation summaries. + // eslint-disable-next-line no-control-regex -- regex uses control character codes + let sanitized = text.replace(/[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]/g, ""); + + // Remove extended ASCII control chars (C1 control codes) + sanitized = sanitized.replace(/[\x80-\x9F]/g, ""); + + // Remove lone surrogates (U+D800 to U+DFFF) which cause "unexpected end of hex escape" + // errors in serde_json when they appear without proper pairing. + // These are invalid in JSON and can cause parse failures. + sanitized = sanitized.replace(/[\uD800-\uDFFF]/g, ""); + + return sanitized; +} + +/** + * Generates a prompt to ask Claude to summarise a conversation. + * This can be sent as a user message to get a summary. + */ +export function generateSummaryPrompt(conversationContent: string): string { + return `Please provide a concise summary of our conversation so far. Focus on: +1. Key topics discussed +2. Important decisions or conclusions made +3. Any ongoing tasks or context that would be helpful to remember +4. Code changes or files that were modified + +Keep the summary brief but comprehensive enough to continue our work in a new session. + +Here is our conversation: + +${conversationContent} + +Please provide the summary now:`; +} + +/** + * Generates a context injection message to prepend to a new conversation. + * This provides Claude with context from a previous session. + */ +export function generateContextInjection(summary: ConversationSummary): string { + return `[Previous Session Context] +The following is a summary from our previous conversation (${summary.messageCount} messages, approximately ${summary.tokenEstimate.toLocaleString()} tokens): + +${summary.content} + +[End of Previous Context] + +Please continue from where we left off, or let me know if you need any clarification about the previous context.`; +} + +/** + * Estimates the token count for a given string. + * Uses a rough approximation of ~4 characters per token. + */ +export function estimateTokens(text: string): number { + return Math.ceil(text.length / 4); +} + +/** + * Creates a ConversationSummary object from summary content. + */ +export function createSummary( + content: string, + messageCount: number, + originalTokenEstimate: number +): ConversationSummary { + return { + generatedAt: new Date(), + content, + messageCount, + tokenEstimate: originalTokenEstimate, + }; +} + +/** + * Determines if a conversation should be compacted based on token usage. + * Returns true if the conversation is using more than the threshold percentage + * of the context window. + */ +export function shouldSuggestCompaction( + contextTokensUsed: number, + contextWindowLimit: number, + thresholdPercent: number = 60 +): boolean { + if (contextWindowLimit === 0) return false; + const utilisationPercent = (contextTokensUsed / contextWindowLimit) * 100; + return utilisationPercent >= thresholdPercent; +} + +/** + * Formats a token count for display. + */ +export function formatTokenCount(tokens: number): string { + if (tokens >= 1000000) { + return `${(tokens / 1000000).toFixed(1)}M`; + } + if (tokens >= 1000) { + return `${(tokens / 1000).toFixed(1)}K`; + } + return tokens.toString(); +}