// Generated from nhcarrigan/template.
use serde::{Deserialize, Serialize};
|
|
use std::collections::hash_map::DefaultHasher;
|
|
use std::collections::HashMap;
|
|
use std::hash::{Hash, Hasher};
|
|
|
|
/// Tools that could benefit from caching
///
/// Variants correspond to the tool name strings recognized by
/// [`CacheableTool::from_name`]; any other tool name is treated as
/// non-cacheable.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheableTool {
    /// The `"Read"` tool.
    Read,
    /// The `"Glob"` tool.
    Glob,
    /// The `"Grep"` tool.
    Grep,
}
|
|
|
|
impl CacheableTool {
|
|
#[allow(dead_code)]
|
|
pub fn from_name(name: &str) -> Option<Self> {
|
|
match name {
|
|
"Read" => Some(Self::Read),
|
|
"Glob" => Some(Self::Glob),
|
|
"Grep" => Some(Self::Grep),
|
|
_ => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Statistics about potential cache savings
///
/// Only the aggregate counters are serialized; the per-invocation map is
/// session-local state and is skipped, so it deserializes as empty.
#[allow(dead_code)]
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct CacheAnalytics {
    /// Number of tool calls that could have been cache hits
    pub potential_cache_hits: u64,
    /// Estimated tokens that could have been saved
    pub potential_savings_tokens: u64,
    /// Tracks unique tool invocations: hash -> (tool_name, call_count)
    // Not serialized (`#[serde(skip)]`). Grows with each unique
    // invocation until `clear_session` or `reset` is called.
    #[serde(skip)]
    recent_invocations: HashMap<u64, (String, u64)>,
}
|
|
|
|
#[allow(dead_code)]
|
|
impl CacheAnalytics {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
/// Compute a hash key from tool name and input
|
|
fn compute_key(tool_name: &str, input: &serde_json::Value) -> u64 {
|
|
let mut hasher = DefaultHasher::new();
|
|
tool_name.hash(&mut hasher);
|
|
input.to_string().hash(&mut hasher);
|
|
hasher.finish()
|
|
}
|
|
|
|
/// Track a tool invocation for analytics
|
|
/// Returns true if this was a repeated invocation (potential cache hit)
|
|
pub fn track_invocation(
|
|
&mut self,
|
|
tool_name: &str,
|
|
input: &serde_json::Value,
|
|
estimated_tokens: u64,
|
|
) -> bool {
|
|
// Only track cacheable tools
|
|
if CacheableTool::from_name(tool_name).is_none() {
|
|
return false;
|
|
}
|
|
|
|
let key = Self::compute_key(tool_name, input);
|
|
|
|
if let Some((_, count)) = self.recent_invocations.get_mut(&key) {
|
|
*count += 1;
|
|
// This is a repeat - could have been a cache hit
|
|
self.potential_cache_hits += 1;
|
|
self.potential_savings_tokens += estimated_tokens;
|
|
true
|
|
} else {
|
|
self.recent_invocations
|
|
.insert(key, (tool_name.to_string(), 1));
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Get the number of unique tool invocations being tracked
|
|
pub fn unique_invocations(&self) -> usize {
|
|
self.recent_invocations.len()
|
|
}
|
|
|
|
/// Get invocations that were called more than once
|
|
pub fn repeated_invocations(&self) -> Vec<(&str, u64)> {
|
|
self.recent_invocations
|
|
.values()
|
|
.filter(|(_, count)| *count > 1)
|
|
.map(|(name, count)| (name.as_str(), *count))
|
|
.collect()
|
|
}
|
|
|
|
/// Clear session analytics (keep totals)
|
|
pub fn clear_session(&mut self) {
|
|
self.recent_invocations.clear();
|
|
}
|
|
|
|
/// Fully reset all analytics
|
|
pub fn reset(&mut self) {
|
|
self.potential_cache_hits = 0;
|
|
self.potential_savings_tokens = 0;
|
|
self.recent_invocations.clear();
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a typical `Read` tool input for the given path.
    fn read_args(path: &str) -> serde_json::Value {
        json!({"file_path": path})
    }

    #[test]
    fn test_cacheable_tool_from_name() {
        let cases = [
            ("Read", Some(CacheableTool::Read)),
            ("Glob", Some(CacheableTool::Glob)),
            ("Grep", Some(CacheableTool::Grep)),
            ("Bash", None),
            ("Edit", None),
            ("Write", None),
        ];
        for (name, expected) in cases {
            assert_eq!(CacheableTool::from_name(name), expected);
        }
    }

    #[test]
    fn test_first_invocation_not_cache_hit() {
        let mut stats = CacheAnalytics::new();

        let hit = stats.track_invocation("Read", &read_args("/home/test/file.txt"), 100);

        assert!(!hit);
        assert_eq!(stats.potential_cache_hits, 0);
        assert_eq!(stats.potential_savings_tokens, 0);
    }

    #[test]
    fn test_second_invocation_is_cache_hit() {
        let mut stats = CacheAnalytics::new();
        let args = read_args("/home/test/file.txt");

        stats.track_invocation("Read", &args, 100);
        let hit = stats.track_invocation("Read", &args, 100);

        assert!(hit);
        assert_eq!(stats.potential_cache_hits, 1);
        assert_eq!(stats.potential_savings_tokens, 100);
    }

    #[test]
    fn test_different_inputs_not_cache_hit() {
        let mut stats = CacheAnalytics::new();

        stats.track_invocation("Read", &read_args("/home/test/file1.txt"), 100);
        let hit = stats.track_invocation("Read", &read_args("/home/test/file2.txt"), 100);

        assert!(!hit);
        assert_eq!(stats.potential_cache_hits, 0);
    }

    #[test]
    fn test_non_cacheable_tool_ignored() {
        let mut stats = CacheAnalytics::new();
        let args = json!({"command": "ls -la"});

        let hit = stats.track_invocation("Bash", &args, 100);
        stats.track_invocation("Bash", &args, 100);

        assert!(!hit);
        assert_eq!(stats.potential_cache_hits, 0);
        assert_eq!(stats.unique_invocations(), 0);
    }

    #[test]
    fn test_multiple_repeated_invocations() {
        let mut stats = CacheAnalytics::new();
        let args = read_args("/home/test/file.txt");

        for _ in 0..3 {
            stats.track_invocation("Read", &args, 100);
        }

        assert_eq!(stats.potential_cache_hits, 2);
        assert_eq!(stats.potential_savings_tokens, 200);
    }

    #[test]
    fn test_unique_invocations_count() {
        let mut stats = CacheAnalytics::new();

        stats.track_invocation("Read", &read_args("/file1.txt"), 100);
        stats.track_invocation("Read", &read_args("/file2.txt"), 100);
        stats.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50);

        assert_eq!(stats.unique_invocations(), 3);
    }

    #[test]
    fn test_repeated_invocations_list() {
        let mut stats = CacheAnalytics::new();

        // file1 read twice
        stats.track_invocation("Read", &read_args("/file1.txt"), 100);
        stats.track_invocation("Read", &read_args("/file1.txt"), 100);

        // file2 read once
        stats.track_invocation("Read", &read_args("/file2.txt"), 100);

        // glob run 3 times
        for _ in 0..3 {
            stats.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50);
        }

        assert_eq!(stats.repeated_invocations().len(), 2); // file1 and glob pattern
    }

    #[test]
    fn test_clear_session() {
        let mut stats = CacheAnalytics::new();
        let args = read_args("/file.txt");

        stats.track_invocation("Read", &args, 100);
        stats.track_invocation("Read", &args, 100);

        assert_eq!(stats.potential_cache_hits, 1);
        assert_eq!(stats.unique_invocations(), 1);

        stats.clear_session();

        assert_eq!(stats.potential_cache_hits, 1); // Preserved
        assert_eq!(stats.unique_invocations(), 0); // Cleared
    }

    #[test]
    fn test_reset() {
        let mut stats = CacheAnalytics::new();
        let args = read_args("/file.txt");

        stats.track_invocation("Read", &args, 100);
        stats.track_invocation("Read", &args, 100);

        stats.reset();

        assert_eq!(stats.potential_cache_hits, 0);
        assert_eq!(stats.potential_savings_tokens, 0);
        assert_eq!(stats.unique_invocations(), 0);
    }

    #[test]
    fn test_serialization() {
        let mut stats = CacheAnalytics::new();
        stats.potential_cache_hits = 10;
        stats.potential_savings_tokens = 500;

        let encoded = serde_json::to_string(&stats).expect("Failed to serialize");
        let decoded: CacheAnalytics =
            serde_json::from_str(&encoded).expect("Failed to deserialize");

        assert_eq!(decoded.potential_cache_hits, 10);
        assert_eq!(decoded.potential_savings_tokens, 500);
        // recent_invocations is skipped in serialization
        assert_eq!(decoded.unique_invocations(), 0);
    }
}
|