use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Duration; use dashmap::DashMap; use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::cache::redis::AppRedis; use crate::error::{AppError, AppResult}; use crate::queue::NatsQueue; use ::redis::Cmd; use super::redis_keys::*; use super::session_redis::{heartbeat_redis, register_redis_online, unregister_redis_online}; use super::typing; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum WsSessionState { Connecting, Authenticated, Replaced, Closing, Closed, } impl WsSessionState { pub fn is_deliverable(self) -> bool { matches!(self, Self::Authenticated) } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WsSession { pub user_id: Uuid, pub device_id: String, pub connection_id: Uuid, pub workspace_name: String, pub connected_at: i64, pub authenticated_at: Option, pub state: WsSessionState, pub superseded_by: Option, } #[derive(Clone)] pub struct WsSessionManager { redis: AppRedis, #[allow(dead_code)] nats: Arc, user_devices: Arc>>, sessions: Arc>, channel_routes: Arc>>, session_channels: Arc>>, } impl WsSessionManager { pub fn new(redis: AppRedis, nats: Arc) -> Self { Self { redis, nats, user_devices: Arc::new(DashMap::new()), sessions: Arc::new(DashMap::new()), channel_routes: Arc::new(DashMap::new()), session_channels: Arc::new(DashMap::new()), } } pub fn issue_token(&self, user_id: Uuid, workspace_name: &str) -> AppResult { self.issue_token_for_device(user_id, workspace_name, "default") } pub fn issue_token_for_device( &self, user_id: Uuid, workspace_name: &str, device_id: &str, ) -> AppResult { let token = format!("ws_{}", Uuid::now_v7()); let session = WsSession { user_id, device_id: device_id.to_string(), connection_id: Uuid::nil(), workspace_name: workspace_name.to_string(), connected_at: 0, authenticated_at: None, state: WsSessionState::Connecting, superseded_by: None, }; let json = serde_json::to_string(&session)?; let key = format!("{WS_TOKEN_PREFIX}{token}"); let mut conn = self.redis.get_connection()?; Cmd::new() .arg("SETEX") .arg(&key) .arg(WS_TOKEN_TTL_SECS) .arg(&json) .query::<()>(&mut *conn.inner_mut())?; Ok(token) } pub fn redeem_token(&self, token: &str) -> AppResult { let key = format!("{WS_TOKEN_PREFIX}{token}"); let mut conn = self.redis.get_connection()?; let json: Option = Cmd::new() .arg("GETDEL") .arg(&key) .query::>(&mut *conn.inner_mut()) .map_err(AppError::Redis)?; let json = json.ok_or(AppError::Unauthorized)?; let mut session: WsSession = serde_json::from_str(&json) .map_err(|e| AppError::Config(format!("invalid ws session: {e}")))?; let now = chrono::Utc::now().timestamp_millis(); session.connection_id = Uuid::now_v7(); session.connected_at = now; session.authenticated_at = Some(now); session.state = WsSessionState::Authenticated; session.superseded_by = None; Ok(session) } pub fn register_connection(&self, session: &WsSession) -> AppResult<()> { let _ = self.register_connection_with_replacement(session)?; Ok(()) } pub fn register_connection_with_replacement( &self, session: &WsSession, ) -> AppResult> { let mut current = session.clone(); current.state = WsSessionState::Authenticated; current.superseded_by = None; self.sessions.insert(current.connection_id, current.clone()); let replaced = { let mut entry = self.user_devices.entry(current.user_id).or_default(); entry.insert(current.device_id.clone(), current.connection_id) }; if let Some(old_id) = replaced && old_id != current.connection_id { if let Some(mut old) = self.sessions.get_mut(&old_id) { old.state = WsSessionState::Replaced; old.superseded_by = Some(current.connection_id); } self.unsubscribe_all(old_id); } register_redis_online(&self.redis, ¤t)?; Ok(replaced.filter(|old| *old != current.connection_id)) } pub fn unregister_connection(&self, session: &WsSession) -> AppResult<()> { let removed = self.sessions.remove(&session.connection_id).map(|(_, s)| s); let current = removed.as_ref().unwrap_or(session); self.unsubscribe_all(current.connection_id); if let Some(mut devices) = self.user_devices.get_mut(¤t.user_id) && devices.get(¤t.device_id).copied() == Some(current.connection_id) { devices.remove(¤t.device_id); } self.user_devices .remove_if(¤t.user_id, |_, devices| devices.is_empty()); unregister_redis_online(&self.redis, current) } pub fn heartbeat(&self, session: &WsSession) -> AppResult<()> { if !self.is_deliverable(session.connection_id) { return Err(AppError::Unauthorized); } heartbeat_redis(&self.redis, session) } pub fn subscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) { self.channel_routes .entry(channel_id) .or_default() .insert(connection_id); self.session_channels .entry(connection_id) .or_default() .insert(channel_id); } pub fn unsubscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) { if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) { sessions.remove(&connection_id); } self.channel_routes .remove_if(&channel_id, |_, sessions| sessions.is_empty()); if let Some(mut channels) = self.session_channels.get_mut(&connection_id) { channels.remove(&channel_id); } self.session_channels .remove_if(&connection_id, |_, channels| channels.is_empty()); } pub fn unsubscribe_all(&self, connection_id: Uuid) { let channels = self .session_channels .remove(&connection_id) .map(|(_, channels)| channels) .unwrap_or_default(); for channel_id in channels { if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) { sessions.remove(&connection_id); } self.channel_routes .remove_if(&channel_id, |_, sessions| sessions.is_empty()); } } pub fn subscribers(&self, channel_id: Uuid) -> Vec { self.channel_routes .get(&channel_id) .map(|sessions| sessions.iter().copied().collect()) .unwrap_or_default() } pub fn user_connections(&self, user_id: Uuid) -> Vec { self.user_devices .get(&user_id) .map(|devices| devices.values().copied().collect()) .unwrap_or_default() } pub fn workspace_connections(&self, workspace_name: &str) -> Vec { self.sessions .iter() .filter_map(|entry| { let session = entry.value(); (session.workspace_name == workspace_name && session.state.is_deliverable()) .then_some(session.connection_id) }) .collect() } pub fn get_session(&self, connection_id: Uuid) -> Option { self.sessions .get(&connection_id) .map(|session| session.clone()) } pub fn is_deliverable(&self, connection_id: Uuid) -> bool { self.sessions .get(&connection_id) .map(|session| session.state.is_deliverable() && session.superseded_by.is_none()) .unwrap_or(false) } pub fn is_user_online(&self, user_id: Uuid) -> AppResult { Ok(self .user_devices .get(&user_id) .map(|devices| !devices.is_empty()) .unwrap_or(false)) } pub fn get_connection_count(&self, user_id: Uuid) -> AppResult { Ok(self .user_devices .get(&user_id) .map(|devices| devices.len() as u32) .unwrap_or(0)) } pub fn set_typing( &self, channel_id: Uuid, thread_id: Option, user_id: Uuid, ) -> AppResult<()> { typing::set_typing(&self.redis, channel_id, thread_id, user_id) } pub fn clear_typing( &self, channel_id: Uuid, thread_id: Option, user_id: Uuid, ) -> AppResult<()> { typing::clear_typing(&self.redis, channel_id, thread_id, user_id) } pub fn get_typing_users( &self, channel_id: Uuid, thread_id: Option, ) -> AppResult> { typing::get_typing_users(&self.redis, channel_id, thread_id) } pub fn heartbeat_interval(&self) -> Duration { Duration::from_secs(WS_HEARTBEAT_INTERVAL_SECS) } pub fn heartbeat_interval_secs(&self) -> u64 { WS_HEARTBEAT_INTERVAL_SECS } }