use std::time::Instant; use uuid::Uuid; use crate::cache::redis::AppRedis; use crate::error::AppResult; use ::redis::Cmd; use super::redis_keys::*; pub struct RateLimiter { redis: AppRedis, max_per_sec: u32, } impl RateLimiter { pub fn new(redis: AppRedis) -> Self { Self { redis, max_per_sec: WS_MAX_MESSAGES_PER_SEC, } } pub fn with_limit(redis: AppRedis, max_per_sec: u32) -> Self { Self { redis, max_per_sec } } pub fn check(&self, connection_id: Uuid) -> AppResult { let key = format!("{WS_RATE_PREFIX}{connection_id}"); let mut conn = self.redis.get_connection()?; let count: i64 = Cmd::new() .arg("INCR") .arg(&key) .query(&mut *conn.inner_mut()) .map_err(crate::error::AppError::Redis)?; if count == 1 { let _ = Cmd::new() .arg("EXPIRE") .arg(&key) .arg(1_u64) .query::<()>(&mut *conn.inner_mut()); } Ok(count <= self.max_per_sec as i64) } pub fn check_sliding(&self, connection_id: Uuid) -> AppResult { let key = format!("{WS_RATE_PREFIX}{connection_id}"); let mut conn = self.redis.get_connection()?; let count: i64 = Cmd::new() .arg("INCR") .arg(&key) .query(&mut *conn.inner_mut()) .map_err(crate::error::AppError::Redis)?; if count == 1 { let _ = Cmd::new() .arg("EXPIRE") .arg(&key) .arg(2_u64) .query::<()>(&mut *conn.inner_mut()); } Ok(count <= self.max_per_sec as i64) } pub fn remaining(&self, connection_id: Uuid) -> AppResult { let key = format!("{WS_RATE_PREFIX}{connection_id}"); let mut conn = self.redis.get_connection()?; let count: Option = Cmd::new() .arg("GET") .arg(&key) .query(&mut *conn.inner_mut()) .map_err(crate::error::AppError::Redis)?; Ok(self.max_per_sec.saturating_sub(count.unwrap_or(0) as u32)) } } pub struct LocalRateLimiter { count: std::sync::atomic::AtomicU32, start: std::sync::Mutex, max_per_sec: u32, } impl LocalRateLimiter { pub fn new(max_per_sec: u32) -> Self { Self { count: std::sync::atomic::AtomicU32::new(0), start: std::sync::Mutex::new(Instant::now()), max_per_sec, } } pub fn check(&self) -> bool { let mut start = self.start.lock().unwrap(); if start.elapsed().as_secs() >= 1 { self.count.store(0, std::sync::atomic::Ordering::Relaxed); *start = Instant::now(); } drop(start); self.count .fetch_add(1, std::sync::atomic::Ordering::Relaxed) < self.max_per_sec } }