305 lines
6.9 KiB
Rust
305 lines
6.9 KiB
Rust
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant};

use dashmap::DashMap;
|
|
|
|
/// A cached payload paired with the deadline after which readers treat it as stale.
struct CacheEntry<V> {
    /// Expiry deadline; an entry whose deadline is at or before "now" is expired.
    expires_at: Instant,
    /// The stored payload.
    value: V,
}
|
|
|
|
/// One slot of the index-linked recency list kept by [`LruTracker`].
///
/// Slot `0` is a sentinel, so a `prev`/`next` value of `0` means "no neighbour".
struct LruNode<K> {
    /// Tracked key; `None` for the sentinel and for vacated slots awaiting reuse.
    key: Option<K>,
    /// Index of the more recently used neighbour (`0` if this node is the head).
    prev: usize,
    /// Index of the less recently used neighbour (`0` if this node is the tail).
    next: usize,
}

/// Recency order of keys: most recently used at `head`, least recently used at `tail`.
///
/// Nodes are stored in a `Vec` and linked by index. Slots vacated by
/// [`pop_back`](Self::pop_back) and [`remove`](Self::remove) are recorded in `free`
/// and reused by [`push_front`](Self::push_front). As a result `nodes` stays bounded
/// by the peak number of simultaneously tracked keys (plus the sentinel), instead of
/// growing with every insertion.
struct LruTracker<K> {
    nodes: Vec<LruNode<K>>,
    key_to_idx: HashMap<K, usize>,
    /// Indices of vacated slots available for reuse.
    free: Vec<usize>,
    head: usize,
    tail: usize,
}

impl<K: Eq + Hash + Clone> LruTracker<K> {
    /// Creates an empty tracker containing only the sentinel slot.
    fn new() -> Self {
        Self {
            nodes: vec![LruNode {
                key: None,
                prev: 0,
                next: 0,
            }],
            key_to_idx: HashMap::new(),
            free: Vec::new(),
            head: 0,
            tail: 0,
        }
    }

    /// Marks `key` as most recently used; unknown keys are ignored.
    fn touch(&mut self, key: &K) {
        if let Some(&idx) = self.key_to_idx.get(key) {
            self.move_to_front(idx);
        }
    }

    /// Starts tracking `key` as the most recently used entry and returns its slot index.
    ///
    /// If `key` is already tracked it is only moved to the front. Creating a second
    /// node would leave a stale duplicate in the list. Evicting that duplicate later
    /// would delete the live key's index entry and desynchronise `len()` from the list.
    fn push_front(&mut self, key: K) -> usize {
        if let Some(&idx) = self.key_to_idx.get(&key) {
            self.move_to_front(idx);
            return idx;
        }

        let node = LruNode {
            key: Some(key.clone()),
            prev: 0,
            next: 0,
        };
        // Prefer a recycled slot so `nodes` does not grow without bound.
        let idx = match self.free.pop() {
            Some(slot) => {
                self.nodes[slot] = node;
                slot
            }
            None => {
                self.nodes.push(node);
                self.nodes.len() - 1
            }
        };
        self.key_to_idx.insert(key, idx);
        self.attach_front(idx);
        idx
    }

    /// Stops tracking the least recently used key and returns it, or `None` when empty.
    fn pop_back(&mut self) -> Option<K> {
        if self.tail == 0 {
            return None;
        }
        let lru = self.tail;
        let key = self.release(lru);
        if let Some(ref k) = key {
            self.key_to_idx.remove(k);
        }
        key
    }

    /// Stops tracking `key`; unknown keys are ignored.
    fn remove(&mut self, key: &K) {
        if let Some(idx) = self.key_to_idx.remove(key) {
            self.release(idx);
        }
    }

    /// Forgets every key and drops all slots except the sentinel.
    fn clear(&mut self) {
        self.key_to_idx.clear();
        // Indices past the sentinel become invalid once `nodes` is truncated.
        self.free.clear();
        self.nodes.truncate(1);
        self.head = 0;
        self.tail = 0;
    }

    /// Number of keys currently tracked.
    fn len(&self) -> usize {
        self.key_to_idx.len()
    }

    /// Unlinks slot `idx`, takes its key out, and puts the slot on the free list.
    fn release(&mut self, idx: usize) -> Option<K> {
        self.detach(idx);
        self.free.push(idx);
        self.nodes[idx].key.take()
    }

    /// Moves an already-linked slot to the head of the list.
    fn move_to_front(&mut self, idx: usize) {
        self.detach(idx);
        self.attach_front(idx);
    }

    /// Unlinks slot `idx` from its neighbours, updating `head`/`tail` when it was an end.
    fn detach(&mut self, idx: usize) {
        let prev = self.nodes[idx].prev;
        let next = self.nodes[idx].next;

        if prev != 0 {
            self.nodes[prev].next = next;
        } else {
            self.head = next;
        }

        if next != 0 {
            self.nodes[next].prev = prev;
        } else {
            self.tail = prev;
        }
    }

    /// Links slot `idx` in as the new head (and as the tail if the list was empty).
    fn attach_front(&mut self, idx: usize) {
        self.nodes[idx].prev = 0;
        self.nodes[idx].next = self.head;

        if self.head != 0 {
            self.nodes[self.head].prev = idx;
        } else {
            self.tail = idx;
        }

        self.head = idx;
    }
}
|
|
|
|
pub struct LruTtlCache<K, V> {
|
|
map: DashMap<K, CacheEntry<V>>,
|
|
lru: Mutex<LruTracker<K>>,
|
|
capacity: usize,
|
|
ttl: Duration,
|
|
}
|
|
|
|
impl<K: Eq + Hash + Clone, V: Clone> LruTtlCache<K, V> {
|
|
pub fn new(capacity: usize, ttl: Duration) -> Self {
|
|
Self {
|
|
map: DashMap::with_capacity(capacity),
|
|
lru: Mutex::new(LruTracker::new()),
|
|
capacity,
|
|
ttl,
|
|
}
|
|
}
|
|
|
|
pub fn get(&self, key: &K) -> Option<V> {
|
|
let entry = self.map.get(key)?;
|
|
let expired = entry.expires_at <= Instant::now();
|
|
let value = entry.value.clone();
|
|
drop(entry);
|
|
|
|
if expired {
|
|
self.remove(key);
|
|
return None;
|
|
}
|
|
|
|
if let Ok(mut lru) = self.lru.lock() {
|
|
lru.touch(key);
|
|
}
|
|
|
|
Some(value)
|
|
}
|
|
|
|
pub fn insert(&self, key: K, value: V) {
|
|
self.insert_with_ttl(key, value, self.ttl);
|
|
}
|
|
|
|
pub fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) {
|
|
let now = Instant::now();
|
|
|
|
if self.map.contains_key(&key) {
|
|
self.map.insert(
|
|
key.clone(),
|
|
CacheEntry {
|
|
value,
|
|
expires_at: now + ttl,
|
|
},
|
|
);
|
|
if let Ok(mut lru) = self.lru.lock() {
|
|
lru.touch(&key);
|
|
}
|
|
return;
|
|
}
|
|
|
|
let mut lru = self.lru.lock().unwrap();
|
|
|
|
if lru.len() >= self.capacity
|
|
&& let Some(evicted_key) = lru.pop_back()
|
|
{
|
|
self.map.remove(&evicted_key);
|
|
}
|
|
|
|
self.map.insert(
|
|
key.clone(),
|
|
CacheEntry {
|
|
value,
|
|
expires_at: now + ttl,
|
|
},
|
|
);
|
|
lru.push_front(key);
|
|
}
|
|
|
|
pub fn remove(&self, key: &K) -> Option<V> {
|
|
if let Ok(mut lru) = self.lru.lock() {
|
|
lru.remove(key);
|
|
}
|
|
self.map.remove(key).map(|(_, entry)| entry.value)
|
|
}
|
|
|
|
pub fn contains(&self, key: &K) -> bool {
|
|
self.map.contains_key(key)
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.map.len()
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.map.is_empty()
|
|
}
|
|
|
|
pub fn clear(&self) {
|
|
self.map.clear();
|
|
if let Ok(mut lru) = self.lru.lock() {
|
|
lru.clear();
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Long enough that nothing expires during a test.
    const MINUTE: Duration = Duration::from_secs(60);

    #[test]
    fn test_insert_and_get() {
        let cache = LruTtlCache::new(3, MINUTE);
        let pairs = [("a", 1), ("b", 2), ("c", 3)];

        for (k, v) in pairs {
            cache.insert(k, v);
        }
        for (k, v) in pairs {
            assert_eq!(cache.get(&k), Some(v));
        }
    }

    #[test]
    fn test_lru_eviction() {
        let cache = LruTtlCache::new(2, MINUTE);
        cache.insert("a", 1);
        cache.insert("b", 2);
        // Reading "a" leaves "b" as the least recently used entry.
        let _ = cache.get(&"a");
        cache.insert("c", 3);

        let observed: Vec<_> = ["a", "b", "c"].iter().map(|k| cache.get(k)).collect();
        assert_eq!(observed, vec![Some(1), None, Some(3)]);
    }

    #[test]
    fn test_ttl_expiry() {
        let ttl = Duration::from_millis(10);
        let cache = LruTtlCache::new(3, ttl);
        cache.insert("a", 1);

        thread::sleep(ttl * 2);

        assert!(cache.get(&"a").is_none());
    }

    #[test]
    fn test_update_existing() {
        let cache = LruTtlCache::new(3, MINUTE);
        for v in [1, 100] {
            cache.insert("a", v);
        }

        assert_eq!(cache.get(&"a"), Some(100));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_remove() {
        let cache = LruTtlCache::new(3, MINUTE);
        cache.insert("a", 1);
        cache.insert("b", 2);

        let removed = cache.remove(&"a");
        assert_eq!(removed, Some(1));
        assert!(cache.get(&"a").is_none());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_concurrent_access() {
        let cache = LruTtlCache::new(10, MINUTE);

        // `scope` joins both threads and re-raises any panic before returning.
        thread::scope(|s| {
            s.spawn(|| (0..100).for_each(|i| cache.insert(i, i * 2)));
            s.spawn(|| {
                for i in 0..100 {
                    let _ = cache.get(&i);
                }
            });
        });

        assert!(cache.len() <= 10);
    }
}
|