use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};

use async_trait::async_trait;
use fred::prelude::*;

use crate::socket::message_bus::redis::RedisMessageBus;
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};

/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_millis() -> u64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    u64::try_from(millis).unwrap_or(u64::MAX)
}

/// Session TTL (seconds) used when [`RedisSessionStore::new`] gets `None` or `Some(0)`.
const DEFAULT_TTL_SECS: u64 = 60;
/// Prefix of every session key; the full key is `socket.io:session:<sid>`.
const KEY_PREFIX: &str = "socket.io:session";

/// Converts a Redis client error into [`SessionError::Redis`], keeping only its message.
fn redis_err<E: std::fmt::Display>(err: E) -> SessionError {
    SessionError::Redis(err.to_string())
}

/// Socket session store backed by one Redis hash per session.
///
/// Each session is stored at `socket.io:session:<sid>` with the fields `sid`, `transport`,
/// `state`, `server_id`, `created_at` and `last_ping` (timestamps in epoch milliseconds).
/// Every write refreshes the key's TTL, so sessions that stop being touched expire on their own.
pub struct RedisSessionStore {
    client: Client,
    ttl_secs: u64,
}

impl RedisSessionStore {
    /// Creates a store that shares the Redis client of `bus`.
    ///
    /// `ttl_secs` is the expiry, in seconds, applied after each write. `None` or `Some(0)`
    /// selects [`DEFAULT_TTL_SECS`]. Zero is not honoured because Redis deletes a key that is
    /// given a non-positive `EXPIRE`, which would drop every session right after it is written.
    pub fn new(bus: &RedisMessageBus, ttl_secs: Option<u64>) -> Self {
        Self {
            client: bus.client().clone(),
            ttl_secs: ttl_secs.filter(|&secs| secs > 0).unwrap_or(DEFAULT_TTL_SECS),
        }
    }

    /// Redis key of the hash holding session `sid`.
    fn key(&self, sid: &str) -> String {
        format!("{KEY_PREFIX}:{sid}")
    }

    /// Resets the TTL of `key` to the configured number of seconds.
    async fn refresh_ttl(&self, key: &str) -> Result<(), SessionError> {
        // Saturate rather than `as i64`: a u64 above i64::MAX would wrap to a negative value,
        // and a non-positive EXPIRE deletes the key instead of expiring it.
        let ttl = i64::try_from(self.ttl_secs).unwrap_or(i64::MAX);
        self.client
            .expire::<(), _>(key, ttl, None)
            .await
            .map_err(redis_err)
    }

    /// Overwrites one hash field of session `sid`, then refreshes the session TTL.
    ///
    /// NOTE: HSET recreates the hash if the session has already expired or been removed.
    /// Such a fragment has no `sid` field, still receives the TTL below, and is treated as
    /// absent by `get` and `exists`.
    async fn set_field(&self, sid: &str, field: &str, value: &str) -> Result<(), SessionError> {
        let key = self.key(sid);
        // HSET (not HSETNX) so the existing value is overwritten.
        self.client
            .hset::<(), _, _>(&key, (field, value))
            .await
            .map_err(redis_err)?;
        self.refresh_ttl(&key).await
    }
}

#[async_trait]
impl SessionStoreTrait for RedisSessionStore {
    /// Writes a new session hash in state `"connecting"` with both timestamps set to now.
    async fn create(
        &self,
        sid: &str,
        transport: &str,
        server_id: &str,
    ) -> Result<(), SessionError> {
        let key = self.key(sid);
        let now = now_millis().to_string();
        // All fields go out in one HSET, so the hash is never observed half-written.
        let fields: Vec<(&str, String)> = vec![
            ("sid", sid.to_string()),
            ("transport", transport.to_string()),
            ("state", "connecting".to_string()),
            ("server_id", server_id.to_string()),
            ("created_at", now.clone()),
            ("last_ping", now),
        ];
        self.client
            .hset::<(), _, _>(&key, fields)
            .await
            .map_err(redis_err)?;
        // NOTE(review): HSET and EXPIRE are separate round trips. If the process dies in
        // between, the hash keeps no TTL; a MULTI transaction would close that gap.
        self.refresh_ttl(&key).await
    }

    /// Loads session `sid`.
    ///
    /// Returns `Ok(None)` when the key is missing, or when the hash lacks a `sid` field.
    /// Missing text fields default to `""`; missing or unparsable timestamps default to 0.
    async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
        let key = self.key(sid);
        // A single HGETALL (empty map when the key is missing) avoids the TOCTOU race of a
        // separate EXISTS check.
        let mut values: HashMap<String, String> = self
            .client
            .hgetall::<HashMap<String, String>, _>(&key)
            .await
            .map_err(redis_err)?;

        // `create` always writes `sid`. A hash without it is either missing or a fragment
        // recreated by a late field update after the session expired or was removed.
        let Some(stored_sid) = values.remove("sid") else {
            return Ok(None);
        };

        let mut take = |field: &str| values.remove(field).unwrap_or_default();
        Ok(Some(SessionInfo {
            sid: stored_sid,
            transport: take("transport"),
            state: take("state"),
            server_id: take("server_id"),
            created_at: take("created_at").parse().unwrap_or(0),
            last_ping: take("last_ping").parse().unwrap_or(0),
        }))
    }

    /// Sets the `state` field and refreshes the TTL.
    async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
        self.set_field(sid, "state", state).await
    }

    /// Sets the `transport` field and refreshes the TTL.
    async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
        self.set_field(sid, "transport", transport).await
    }

    /// Sets `last_ping` to the current time (epoch ms) and refreshes the TTL.
    async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
        self.set_field(sid, "last_ping", &now_millis().to_string())
            .await
    }

    /// Deletes the session hash; succeeds whether or not it existed.
    async fn remove(&self, sid: &str) -> Result<(), SessionError> {
        let key = self.key(sid);
        self.client
            .del::<(), _>(&key)
            .await
            .map_err(redis_err)
    }

    /// Reports whether a live session hash exists for `sid`.
    async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
        let key = self.key(sid);
        // Check the `sid` field rather than the key itself, so fragments left by late field
        // updates are not reported as live sessions (consistent with `get`).
        self.client
            .hexists::<bool, _, _>(&key, "sid")
            .await
            .map_err(redis_err)
    }
}