103 lines
2.9 KiB
Rust
103 lines
2.9 KiB
Rust
use std::time::Instant;
|
|
|
|
use uuid::Uuid;
|
|
|
|
use crate::cache::redis::AppRedis;
|
|
use crate::error::AppResult;
|
|
use ::redis::Cmd;
|
|
|
|
use super::redis_keys::*;
|
|
|
|
pub struct RateLimiter {
|
|
redis: AppRedis,
|
|
max_per_sec: u32,
|
|
}
|
|
|
|
impl RateLimiter {
|
|
pub fn new(redis: AppRedis) -> Self {
|
|
Self {
|
|
redis,
|
|
max_per_sec: WS_MAX_MESSAGES_PER_SEC,
|
|
}
|
|
}
|
|
|
|
pub fn with_limit(redis: AppRedis, max_per_sec: u32) -> Self {
|
|
Self { redis, max_per_sec }
|
|
}
|
|
|
|
pub fn check(&self, connection_id: Uuid) -> AppResult<bool> {
|
|
let key = format!("{WS_RATE_PREFIX}{connection_id}");
|
|
let mut conn = self.redis.get_connection()?;
|
|
let count: i64 = Cmd::new()
|
|
.arg("INCR")
|
|
.arg(&key)
|
|
.query(&mut *conn.inner_mut())
|
|
.map_err(crate::error::AppError::Redis)?;
|
|
if count == 1 {
|
|
let _ = Cmd::new()
|
|
.arg("EXPIRE")
|
|
.arg(&key)
|
|
.arg(1_u64)
|
|
.query::<()>(&mut *conn.inner_mut());
|
|
}
|
|
Ok(count <= self.max_per_sec as i64)
|
|
}
|
|
|
|
pub fn check_sliding(&self, connection_id: Uuid) -> AppResult<bool> {
|
|
let key = format!("{WS_RATE_PREFIX}{connection_id}");
|
|
let mut conn = self.redis.get_connection()?;
|
|
let count: i64 = Cmd::new()
|
|
.arg("INCR")
|
|
.arg(&key)
|
|
.query(&mut *conn.inner_mut())
|
|
.map_err(crate::error::AppError::Redis)?;
|
|
if count == 1 {
|
|
let _ = Cmd::new()
|
|
.arg("EXPIRE")
|
|
.arg(&key)
|
|
.arg(2_u64)
|
|
.query::<()>(&mut *conn.inner_mut());
|
|
}
|
|
Ok(count <= self.max_per_sec as i64)
|
|
}
|
|
|
|
pub fn remaining(&self, connection_id: Uuid) -> AppResult<u32> {
|
|
let key = format!("{WS_RATE_PREFIX}{connection_id}");
|
|
let mut conn = self.redis.get_connection()?;
|
|
let count: Option<i64> = Cmd::new()
|
|
.arg("GET")
|
|
.arg(&key)
|
|
.query(&mut *conn.inner_mut())
|
|
.map_err(crate::error::AppError::Redis)?;
|
|
Ok(self.max_per_sec.saturating_sub(count.unwrap_or(0) as u32))
|
|
}
|
|
}
|
|
|
|
/// In-process, fixed-window rate limiter.
///
/// Admits at most `max_per_sec` calls to [`LocalRateLimiter::check`] per window. A new
/// window starts on the first call made at least one second after the current one began.
pub struct LocalRateLimiter {
    /// Calls recorded in the current window, including rejected ones.
    count: std::sync::atomic::AtomicU32,
    /// Instant at which the current window began.
    start: std::sync::Mutex<Instant>,
    /// Number of calls admitted per window.
    max_per_sec: u32,
}

impl LocalRateLimiter {
    /// Creates a limiter whose first window begins now.
    pub fn new(max_per_sec: u32) -> Self {
        let count = std::sync::atomic::AtomicU32::new(0);
        let start = std::sync::Mutex::new(Instant::now());
        Self { count, start, max_per_sec }
    }

    /// Records one call and reports whether it fits in the current window's budget.
    ///
    /// # Panics
    /// Panics if the internal window mutex is poisoned.
    pub fn check(&self) -> bool {
        use std::sync::atomic::Ordering::Relaxed;

        // Roll the window over while holding the lock. The guard goes out of scope before
        // the counter is bumped.
        {
            let mut window_start = self.start.lock().unwrap();
            if window_start.elapsed() >= std::time::Duration::from_secs(1) {
                self.count.store(0, Relaxed);
                *window_start = Instant::now();
            }
        }

        let seen_before = self.count.fetch_add(1, Relaxed);
        seen_before < self.max_per_sec
    }
}
|