generated from nhcarrigan/template
feat: massive overhaul to manage costs (#103)
### Explanation _No response_ ### Issue Closes #102 ### Attestations - [ ] I have read and agree to the [Code of Conduct](https://docs.nhcarrigan.com/community/coc/) - [ ] I have read and agree to the [Community Guidelines](https://docs.nhcarrigan.com/community/guide/). - [ ] My contribution complies with the [Contributor Covenant](https://docs.nhcarrigan.com/dev/covenant/). ### Dependencies - [ ] I have pinned the dependencies to a specific patch version. ### Style - [ ] I have run the linter and resolved any errors. - [ ] My pull request uses an appropriate title, matching the conventional commit standards. - [ ] My scope of feat/fix/chore/etc. correctly matches the nature of changes in my pull request. ### Tests - [ ] My contribution adds new code, and I have added tests to cover it. - [ ] My contribution modifies existing code, and I have updated the tests to reflect these changes. - [ ] All new and existing tests pass locally with my changes. - [ ] Code coverage remains at or above the configured threshold. ### Documentation _No response_ ### Versioning _No response_ Reviewed-on: #103 Co-authored-by: Naomi Carrigan <commits@nhcarrigan.com> Co-committed-by: Naomi Carrigan <commits@nhcarrigan.com>
This commit was merged in pull request #103.
This commit is contained in:
@@ -0,0 +1,266 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
/// Tools that could benefit from caching.
///
/// Each variant corresponds to a tool name string recognized by
/// [`CacheableTool::from_name`]; invocations of these tools are the only
/// ones tracked by the cache analytics.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheableTool {
    /// The "Read" tool.
    Read,
    /// The "Glob" tool.
    Glob,
    /// The "Grep" tool.
    Grep,
}
|
||||
|
||||
impl CacheableTool {
|
||||
#[allow(dead_code)]
|
||||
pub fn from_name(name: &str) -> Option<Self> {
|
||||
match name {
|
||||
"Read" => Some(Self::Read),
|
||||
"Glob" => Some(Self::Glob),
|
||||
"Grep" => Some(Self::Grep),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics about potential cache savings.
///
/// Tracks repeated invocations of cacheable tools so callers can estimate
/// how many calls (and roughly how many tokens) a cache would have saved.
/// The hit/savings counters are serialized; the per-invocation map is
/// session-local and skipped.
#[allow(dead_code)]
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct CacheAnalytics {
    /// Number of tool calls that could have been cache hits.
    pub potential_cache_hits: u64,
    /// Estimated tokens that could have been saved.
    pub potential_savings_tokens: u64,
    /// Tracks unique tool invocations: hash -> (tool_name, call_count).
    // NOTE(review): despite the name, nothing evicts entries here — the map
    // grows until `clear_session`/`reset` is called. Confirm callers clear it
    // periodically for long-running sessions.
    #[serde(skip)]
    recent_invocations: HashMap<u64, (String, u64)>,
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl CacheAnalytics {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Compute a hash key from tool name and input
|
||||
fn compute_key(tool_name: &str, input: &serde_json::Value) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
tool_name.hash(&mut hasher);
|
||||
input.to_string().hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
/// Track a tool invocation for analytics
|
||||
/// Returns true if this was a repeated invocation (potential cache hit)
|
||||
pub fn track_invocation(
|
||||
&mut self,
|
||||
tool_name: &str,
|
||||
input: &serde_json::Value,
|
||||
estimated_tokens: u64,
|
||||
) -> bool {
|
||||
// Only track cacheable tools
|
||||
if CacheableTool::from_name(tool_name).is_none() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let key = Self::compute_key(tool_name, input);
|
||||
|
||||
if let Some((_, count)) = self.recent_invocations.get_mut(&key) {
|
||||
*count += 1;
|
||||
// This is a repeat - could have been a cache hit
|
||||
self.potential_cache_hits += 1;
|
||||
self.potential_savings_tokens += estimated_tokens;
|
||||
true
|
||||
} else {
|
||||
self.recent_invocations
|
||||
.insert(key, (tool_name.to_string(), 1));
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of unique tool invocations being tracked
|
||||
pub fn unique_invocations(&self) -> usize {
|
||||
self.recent_invocations.len()
|
||||
}
|
||||
|
||||
/// Get invocations that were called more than once
|
||||
pub fn repeated_invocations(&self) -> Vec<(&str, u64)> {
|
||||
self.recent_invocations
|
||||
.values()
|
||||
.filter(|(_, count)| *count > 1)
|
||||
.map(|(name, count)| (name.as_str(), *count))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Clear session analytics (keep totals)
|
||||
pub fn clear_session(&mut self) {
|
||||
self.recent_invocations.clear();
|
||||
}
|
||||
|
||||
/// Fully reset all analytics
|
||||
pub fn reset(&mut self) {
|
||||
self.potential_cache_hits = 0;
|
||||
self.potential_savings_tokens = 0;
|
||||
self.recent_invocations.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_cacheable_tool_from_name() {
        assert_eq!(CacheableTool::from_name("Read"), Some(CacheableTool::Read));
        assert_eq!(CacheableTool::from_name("Glob"), Some(CacheableTool::Glob));
        assert_eq!(CacheableTool::from_name("Grep"), Some(CacheableTool::Grep));
        assert_eq!(CacheableTool::from_name("Bash"), None);
        assert_eq!(CacheableTool::from_name("Edit"), None);
        assert_eq!(CacheableTool::from_name("Write"), None);
        // Matching is exact and case-sensitive.
        assert_eq!(CacheableTool::from_name("read"), None);
        assert_eq!(CacheableTool::from_name(""), None);
    }

    #[test]
    fn test_first_invocation_not_cache_hit() {
        let mut analytics = CacheAnalytics::new();
        let input = json!({"file_path": "/home/test/file.txt"});

        let is_repeat = analytics.track_invocation("Read", &input, 100);

        assert!(!is_repeat);
        assert_eq!(analytics.potential_cache_hits, 0);
        assert_eq!(analytics.potential_savings_tokens, 0);
    }

    #[test]
    fn test_second_invocation_is_cache_hit() {
        let mut analytics = CacheAnalytics::new();
        let input = json!({"file_path": "/home/test/file.txt"});

        analytics.track_invocation("Read", &input, 100);
        let is_repeat = analytics.track_invocation("Read", &input, 100);

        assert!(is_repeat);
        assert_eq!(analytics.potential_cache_hits, 1);
        assert_eq!(analytics.potential_savings_tokens, 100);
    }

    #[test]
    fn test_different_inputs_not_cache_hit() {
        let mut analytics = CacheAnalytics::new();
        let input1 = json!({"file_path": "/home/test/file1.txt"});
        let input2 = json!({"file_path": "/home/test/file2.txt"});

        analytics.track_invocation("Read", &input1, 100);
        let is_repeat = analytics.track_invocation("Read", &input2, 100);

        assert!(!is_repeat);
        assert_eq!(analytics.potential_cache_hits, 0);
    }

    #[test]
    fn test_non_cacheable_tool_ignored() {
        let mut analytics = CacheAnalytics::new();
        let input = json!({"command": "ls -la"});

        let is_repeat = analytics.track_invocation("Bash", &input, 100);
        analytics.track_invocation("Bash", &input, 100);

        assert!(!is_repeat);
        assert_eq!(analytics.potential_cache_hits, 0);
        assert_eq!(analytics.unique_invocations(), 0);
    }

    #[test]
    fn test_multiple_repeated_invocations() {
        let mut analytics = CacheAnalytics::new();
        let input = json!({"file_path": "/home/test/file.txt"});

        // First call is a miss; the next two are potential hits.
        assert!(!analytics.track_invocation("Read", &input, 100));
        assert!(analytics.track_invocation("Read", &input, 100));
        assert!(analytics.track_invocation("Read", &input, 100));

        assert_eq!(analytics.potential_cache_hits, 2);
        assert_eq!(analytics.potential_savings_tokens, 200);
    }

    #[test]
    fn test_unique_invocations_count() {
        let mut analytics = CacheAnalytics::new();

        analytics.track_invocation("Read", &json!({"file_path": "/file1.txt"}), 100);
        analytics.track_invocation("Read", &json!({"file_path": "/file2.txt"}), 100);
        analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50);

        assert_eq!(analytics.unique_invocations(), 3);
    }

    #[test]
    fn test_repeated_invocations_list() {
        let mut analytics = CacheAnalytics::new();

        // file1 read twice
        analytics.track_invocation("Read", &json!({"file_path": "/file1.txt"}), 100);
        analytics.track_invocation("Read", &json!({"file_path": "/file1.txt"}), 100);

        // file2 read once
        analytics.track_invocation("Read", &json!({"file_path": "/file2.txt"}), 100);

        // glob run 3 times
        analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50);
        analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50);
        analytics.track_invocation("Glob", &json!({"pattern": "*.rs"}), 50);

        // Sort before comparing: HashMap iteration order is arbitrary.
        let mut repeated = analytics.repeated_invocations();
        repeated.sort();
        assert_eq!(repeated, vec![("Glob", 3), ("Read", 2)]);
    }

    #[test]
    fn test_clear_session() {
        let mut analytics = CacheAnalytics::new();
        let input = json!({"file_path": "/file.txt"});

        analytics.track_invocation("Read", &input, 100);
        analytics.track_invocation("Read", &input, 100);

        assert_eq!(analytics.potential_cache_hits, 1);
        assert_eq!(analytics.unique_invocations(), 1);

        analytics.clear_session();

        assert_eq!(analytics.potential_cache_hits, 1); // Preserved
        assert_eq!(analytics.unique_invocations(), 0); // Cleared
    }

    #[test]
    fn test_reset() {
        let mut analytics = CacheAnalytics::new();
        let input = json!({"file_path": "/file.txt"});

        analytics.track_invocation("Read", &input, 100);
        analytics.track_invocation("Read", &input, 100);

        analytics.reset();

        assert_eq!(analytics.potential_cache_hits, 0);
        assert_eq!(analytics.potential_savings_tokens, 0);
        assert_eq!(analytics.unique_invocations(), 0);
    }

    #[test]
    fn test_serialization() {
        let mut analytics = CacheAnalytics::new();
        analytics.potential_cache_hits = 10;
        analytics.potential_savings_tokens = 500;

        let json = serde_json::to_string(&analytics).expect("Failed to serialize");
        let deserialized: CacheAnalytics =
            serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(deserialized.potential_cache_hits, 10);
        assert_eq!(deserialized.potential_savings_tokens, 500);
        // recent_invocations is skipped in serialization
        assert_eq!(deserialized.unique_invocations(), 0);
    }
}
|
||||
Reference in New Issue
Block a user