302 lines
9.6 KiB
Rust
302 lines
9.6 KiB
Rust
use std::collections::{HashMap, HashSet};
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use dashmap::DashMap;
|
|
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
|
|
use crate::cache::redis::AppRedis;
|
|
use crate::error::{AppError, AppResult};
|
|
use crate::queue::NatsQueue;
|
|
use ::redis::Cmd;
|
|
|
|
use super::redis_keys::*;
|
|
use super::session_redis::{heartbeat_redis, register_redis_online, unregister_redis_online};
|
|
use super::typing;
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub enum WsSessionState {
|
|
Connecting,
|
|
Authenticated,
|
|
Replaced,
|
|
Closing,
|
|
Closed,
|
|
}
|
|
|
|
impl WsSessionState {
|
|
pub fn is_deliverable(self) -> bool {
|
|
matches!(self, Self::Authenticated)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct WsSession {
|
|
pub user_id: Uuid,
|
|
pub device_id: String,
|
|
pub connection_id: Uuid,
|
|
pub workspace_name: String,
|
|
pub connected_at: i64,
|
|
pub authenticated_at: Option<i64>,
|
|
pub state: WsSessionState,
|
|
pub superseded_by: Option<Uuid>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct WsSessionManager {
|
|
redis: AppRedis,
|
|
#[allow(dead_code)]
|
|
nats: Arc<NatsQueue>,
|
|
user_devices: Arc<DashMap<Uuid, HashMap<String, Uuid>>>,
|
|
sessions: Arc<DashMap<Uuid, WsSession>>,
|
|
channel_routes: Arc<DashMap<Uuid, HashSet<Uuid>>>,
|
|
session_channels: Arc<DashMap<Uuid, HashSet<Uuid>>>,
|
|
}
|
|
|
|
impl WsSessionManager {
|
|
pub fn new(redis: AppRedis, nats: Arc<NatsQueue>) -> Self {
|
|
Self {
|
|
redis,
|
|
nats,
|
|
user_devices: Arc::new(DashMap::new()),
|
|
sessions: Arc::new(DashMap::new()),
|
|
channel_routes: Arc::new(DashMap::new()),
|
|
session_channels: Arc::new(DashMap::new()),
|
|
}
|
|
}
|
|
|
|
pub fn issue_token(&self, user_id: Uuid, workspace_name: &str) -> AppResult<String> {
|
|
self.issue_token_for_device(user_id, workspace_name, "default")
|
|
}
|
|
|
|
pub fn issue_token_for_device(
|
|
&self,
|
|
user_id: Uuid,
|
|
workspace_name: &str,
|
|
device_id: &str,
|
|
) -> AppResult<String> {
|
|
let token = format!("ws_{}", Uuid::now_v7());
|
|
let session = WsSession {
|
|
user_id,
|
|
device_id: device_id.to_string(),
|
|
connection_id: Uuid::nil(),
|
|
workspace_name: workspace_name.to_string(),
|
|
connected_at: 0,
|
|
authenticated_at: None,
|
|
state: WsSessionState::Connecting,
|
|
superseded_by: None,
|
|
};
|
|
let json = serde_json::to_string(&session)?;
|
|
let key = format!("{WS_TOKEN_PREFIX}{token}");
|
|
let mut conn = self.redis.get_connection()?;
|
|
Cmd::new()
|
|
.arg("SETEX")
|
|
.arg(&key)
|
|
.arg(WS_TOKEN_TTL_SECS)
|
|
.arg(&json)
|
|
.query::<()>(&mut *conn.inner_mut())?;
|
|
Ok(token)
|
|
}
|
|
|
|
pub fn redeem_token(&self, token: &str) -> AppResult<WsSession> {
|
|
let key = format!("{WS_TOKEN_PREFIX}{token}");
|
|
let mut conn = self.redis.get_connection()?;
|
|
let json: Option<String> = Cmd::new()
|
|
.arg("GETDEL")
|
|
.arg(&key)
|
|
.query::<Option<String>>(&mut *conn.inner_mut())
|
|
.map_err(AppError::Redis)?;
|
|
let json = json.ok_or(AppError::Unauthorized)?;
|
|
let mut session: WsSession = serde_json::from_str(&json)
|
|
.map_err(|e| AppError::Config(format!("invalid ws session: {e}")))?;
|
|
let now = chrono::Utc::now().timestamp_millis();
|
|
session.connection_id = Uuid::now_v7();
|
|
session.connected_at = now;
|
|
session.authenticated_at = Some(now);
|
|
session.state = WsSessionState::Authenticated;
|
|
session.superseded_by = None;
|
|
Ok(session)
|
|
}
|
|
|
|
pub fn register_connection(&self, session: &WsSession) -> AppResult<()> {
|
|
let _ = self.register_connection_with_replacement(session)?;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn register_connection_with_replacement(
|
|
&self,
|
|
session: &WsSession,
|
|
) -> AppResult<Option<Uuid>> {
|
|
let mut current = session.clone();
|
|
current.state = WsSessionState::Authenticated;
|
|
current.superseded_by = None;
|
|
self.sessions.insert(current.connection_id, current.clone());
|
|
|
|
let replaced = {
|
|
let mut entry = self.user_devices.entry(current.user_id).or_default();
|
|
entry.insert(current.device_id.clone(), current.connection_id)
|
|
};
|
|
|
|
if let Some(old_id) = replaced
|
|
&& old_id != current.connection_id
|
|
{
|
|
if let Some(mut old) = self.sessions.get_mut(&old_id) {
|
|
old.state = WsSessionState::Replaced;
|
|
old.superseded_by = Some(current.connection_id);
|
|
}
|
|
self.unsubscribe_all(old_id);
|
|
}
|
|
|
|
register_redis_online(&self.redis, ¤t)?;
|
|
Ok(replaced.filter(|old| *old != current.connection_id))
|
|
}
|
|
|
|
pub fn unregister_connection(&self, session: &WsSession) -> AppResult<()> {
|
|
let removed = self.sessions.remove(&session.connection_id).map(|(_, s)| s);
|
|
let current = removed.as_ref().unwrap_or(session);
|
|
self.unsubscribe_all(current.connection_id);
|
|
|
|
if let Some(mut devices) = self.user_devices.get_mut(¤t.user_id)
|
|
&& devices.get(¤t.device_id).copied() == Some(current.connection_id)
|
|
{
|
|
devices.remove(¤t.device_id);
|
|
}
|
|
self.user_devices
|
|
.remove_if(¤t.user_id, |_, devices| devices.is_empty());
|
|
unregister_redis_online(&self.redis, current)
|
|
}
|
|
|
|
pub fn heartbeat(&self, session: &WsSession) -> AppResult<()> {
|
|
if !self.is_deliverable(session.connection_id) {
|
|
return Err(AppError::Unauthorized);
|
|
}
|
|
heartbeat_redis(&self.redis, session)
|
|
}
|
|
|
|
pub fn subscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
|
|
self.channel_routes
|
|
.entry(channel_id)
|
|
.or_default()
|
|
.insert(connection_id);
|
|
self.session_channels
|
|
.entry(connection_id)
|
|
.or_default()
|
|
.insert(channel_id);
|
|
}
|
|
|
|
pub fn unsubscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
|
|
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
|
|
sessions.remove(&connection_id);
|
|
}
|
|
self.channel_routes
|
|
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
|
|
if let Some(mut channels) = self.session_channels.get_mut(&connection_id) {
|
|
channels.remove(&channel_id);
|
|
}
|
|
self.session_channels
|
|
.remove_if(&connection_id, |_, channels| channels.is_empty());
|
|
}
|
|
|
|
pub fn unsubscribe_all(&self, connection_id: Uuid) {
|
|
let channels = self
|
|
.session_channels
|
|
.remove(&connection_id)
|
|
.map(|(_, channels)| channels)
|
|
.unwrap_or_default();
|
|
for channel_id in channels {
|
|
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
|
|
sessions.remove(&connection_id);
|
|
}
|
|
self.channel_routes
|
|
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
|
|
}
|
|
}
|
|
|
|
pub fn subscribers(&self, channel_id: Uuid) -> Vec<Uuid> {
|
|
self.channel_routes
|
|
.get(&channel_id)
|
|
.map(|sessions| sessions.iter().copied().collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
pub fn user_connections(&self, user_id: Uuid) -> Vec<Uuid> {
|
|
self.user_devices
|
|
.get(&user_id)
|
|
.map(|devices| devices.values().copied().collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
pub fn workspace_connections(&self, workspace_name: &str) -> Vec<Uuid> {
|
|
self.sessions
|
|
.iter()
|
|
.filter_map(|entry| {
|
|
let session = entry.value();
|
|
(session.workspace_name == workspace_name && session.state.is_deliverable())
|
|
.then_some(session.connection_id)
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
pub fn get_session(&self, connection_id: Uuid) -> Option<WsSession> {
|
|
self.sessions
|
|
.get(&connection_id)
|
|
.map(|session| session.clone())
|
|
}
|
|
|
|
pub fn is_deliverable(&self, connection_id: Uuid) -> bool {
|
|
self.sessions
|
|
.get(&connection_id)
|
|
.map(|session| session.state.is_deliverable() && session.superseded_by.is_none())
|
|
.unwrap_or(false)
|
|
}
|
|
|
|
pub fn is_user_online(&self, user_id: Uuid) -> AppResult<bool> {
|
|
Ok(self
|
|
.user_devices
|
|
.get(&user_id)
|
|
.map(|devices| !devices.is_empty())
|
|
.unwrap_or(false))
|
|
}
|
|
|
|
pub fn get_connection_count(&self, user_id: Uuid) -> AppResult<u32> {
|
|
Ok(self
|
|
.user_devices
|
|
.get(&user_id)
|
|
.map(|devices| devices.len() as u32)
|
|
.unwrap_or(0))
|
|
}
|
|
|
|
pub fn set_typing(
|
|
&self,
|
|
channel_id: Uuid,
|
|
thread_id: Option<Uuid>,
|
|
user_id: Uuid,
|
|
) -> AppResult<()> {
|
|
typing::set_typing(&self.redis, channel_id, thread_id, user_id)
|
|
}
|
|
|
|
pub fn clear_typing(
|
|
&self,
|
|
channel_id: Uuid,
|
|
thread_id: Option<Uuid>,
|
|
user_id: Uuid,
|
|
) -> AppResult<()> {
|
|
typing::clear_typing(&self.redis, channel_id, thread_id, user_id)
|
|
}
|
|
|
|
pub fn get_typing_users(
|
|
&self,
|
|
channel_id: Uuid,
|
|
thread_id: Option<Uuid>,
|
|
) -> AppResult<Vec<Uuid>> {
|
|
typing::get_typing_users(&self.redis, channel_id, thread_id)
|
|
}
|
|
|
|
pub fn heartbeat_interval(&self) -> Duration {
|
|
Duration::from_secs(WS_HEARTBEAT_INTERVAL_SECS)
|
|
}
|
|
pub fn heartbeat_interval_secs(&self) -> u64 {
|
|
WS_HEARTBEAT_INTERVAL_SECS
|
|
}
|
|
}
|