feat: init
This commit is contained in:
@@ -0,0 +1,301 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::{AppError, AppResult};
|
||||
use crate::queue::NatsQueue;
|
||||
use ::redis::Cmd;
|
||||
|
||||
use super::redis_keys::*;
|
||||
use super::session_redis::{heartbeat_redis, register_redis_online, unregister_redis_online};
|
||||
use super::typing;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum WsSessionState {
|
||||
Connecting,
|
||||
Authenticated,
|
||||
Replaced,
|
||||
Closing,
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl WsSessionState {
|
||||
pub fn is_deliverable(self) -> bool {
|
||||
matches!(self, Self::Authenticated)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WsSession {
|
||||
pub user_id: Uuid,
|
||||
pub device_id: String,
|
||||
pub connection_id: Uuid,
|
||||
pub workspace_name: String,
|
||||
pub connected_at: i64,
|
||||
pub authenticated_at: Option<i64>,
|
||||
pub state: WsSessionState,
|
||||
pub superseded_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsSessionManager {
|
||||
redis: AppRedis,
|
||||
#[allow(dead_code)]
|
||||
nats: Arc<NatsQueue>,
|
||||
user_devices: Arc<DashMap<Uuid, HashMap<String, Uuid>>>,
|
||||
sessions: Arc<DashMap<Uuid, WsSession>>,
|
||||
channel_routes: Arc<DashMap<Uuid, HashSet<Uuid>>>,
|
||||
session_channels: Arc<DashMap<Uuid, HashSet<Uuid>>>,
|
||||
}
|
||||
|
||||
impl WsSessionManager {
|
||||
pub fn new(redis: AppRedis, nats: Arc<NatsQueue>) -> Self {
|
||||
Self {
|
||||
redis,
|
||||
nats,
|
||||
user_devices: Arc::new(DashMap::new()),
|
||||
sessions: Arc::new(DashMap::new()),
|
||||
channel_routes: Arc::new(DashMap::new()),
|
||||
session_channels: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn issue_token(&self, user_id: Uuid, workspace_name: &str) -> AppResult<String> {
|
||||
self.issue_token_for_device(user_id, workspace_name, "default")
|
||||
}
|
||||
|
||||
pub fn issue_token_for_device(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
workspace_name: &str,
|
||||
device_id: &str,
|
||||
) -> AppResult<String> {
|
||||
let token = format!("ws_{}", Uuid::now_v7());
|
||||
let session = WsSession {
|
||||
user_id,
|
||||
device_id: device_id.to_string(),
|
||||
connection_id: Uuid::nil(),
|
||||
workspace_name: workspace_name.to_string(),
|
||||
connected_at: 0,
|
||||
authenticated_at: None,
|
||||
state: WsSessionState::Connecting,
|
||||
superseded_by: None,
|
||||
};
|
||||
let json = serde_json::to_string(&session)?;
|
||||
let key = format!("{WS_TOKEN_PREFIX}{token}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&key)
|
||||
.arg(WS_TOKEN_TTL_SECS)
|
||||
.arg(&json)
|
||||
.query::<()>(&mut *conn.inner_mut())?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub fn redeem_token(&self, token: &str) -> AppResult<WsSession> {
|
||||
let key = format!("{WS_TOKEN_PREFIX}{token}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let json: Option<String> = Cmd::new()
|
||||
.arg("GETDEL")
|
||||
.arg(&key)
|
||||
.query::<Option<String>>(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
let json = json.ok_or(AppError::Unauthorized)?;
|
||||
let mut session: WsSession = serde_json::from_str(&json)
|
||||
.map_err(|e| AppError::Config(format!("invalid ws session: {e}")))?;
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
session.connection_id = Uuid::now_v7();
|
||||
session.connected_at = now;
|
||||
session.authenticated_at = Some(now);
|
||||
session.state = WsSessionState::Authenticated;
|
||||
session.superseded_by = None;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub fn register_connection(&self, session: &WsSession) -> AppResult<()> {
|
||||
let _ = self.register_connection_with_replacement(session)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn register_connection_with_replacement(
|
||||
&self,
|
||||
session: &WsSession,
|
||||
) -> AppResult<Option<Uuid>> {
|
||||
let mut current = session.clone();
|
||||
current.state = WsSessionState::Authenticated;
|
||||
current.superseded_by = None;
|
||||
self.sessions.insert(current.connection_id, current.clone());
|
||||
|
||||
let replaced = {
|
||||
let mut entry = self.user_devices.entry(current.user_id).or_default();
|
||||
entry.insert(current.device_id.clone(), current.connection_id)
|
||||
};
|
||||
|
||||
if let Some(old_id) = replaced
|
||||
&& old_id != current.connection_id
|
||||
{
|
||||
if let Some(mut old) = self.sessions.get_mut(&old_id) {
|
||||
old.state = WsSessionState::Replaced;
|
||||
old.superseded_by = Some(current.connection_id);
|
||||
}
|
||||
self.unsubscribe_all(old_id);
|
||||
}
|
||||
|
||||
register_redis_online(&self.redis, ¤t)?;
|
||||
Ok(replaced.filter(|old| *old != current.connection_id))
|
||||
}
|
||||
|
||||
pub fn unregister_connection(&self, session: &WsSession) -> AppResult<()> {
|
||||
let removed = self.sessions.remove(&session.connection_id).map(|(_, s)| s);
|
||||
let current = removed.as_ref().unwrap_or(session);
|
||||
self.unsubscribe_all(current.connection_id);
|
||||
|
||||
if let Some(mut devices) = self.user_devices.get_mut(¤t.user_id)
|
||||
&& devices.get(¤t.device_id).copied() == Some(current.connection_id)
|
||||
{
|
||||
devices.remove(¤t.device_id);
|
||||
}
|
||||
self.user_devices
|
||||
.remove_if(¤t.user_id, |_, devices| devices.is_empty());
|
||||
unregister_redis_online(&self.redis, current)
|
||||
}
|
||||
|
||||
pub fn heartbeat(&self, session: &WsSession) -> AppResult<()> {
|
||||
if !self.is_deliverable(session.connection_id) {
|
||||
return Err(AppError::Unauthorized);
|
||||
}
|
||||
heartbeat_redis(&self.redis, session)
|
||||
}
|
||||
|
||||
pub fn subscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
|
||||
self.channel_routes
|
||||
.entry(channel_id)
|
||||
.or_default()
|
||||
.insert(connection_id);
|
||||
self.session_channels
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(channel_id);
|
||||
}
|
||||
|
||||
pub fn unsubscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
|
||||
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
|
||||
sessions.remove(&connection_id);
|
||||
}
|
||||
self.channel_routes
|
||||
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
|
||||
if let Some(mut channels) = self.session_channels.get_mut(&connection_id) {
|
||||
channels.remove(&channel_id);
|
||||
}
|
||||
self.session_channels
|
||||
.remove_if(&connection_id, |_, channels| channels.is_empty());
|
||||
}
|
||||
|
||||
pub fn unsubscribe_all(&self, connection_id: Uuid) {
|
||||
let channels = self
|
||||
.session_channels
|
||||
.remove(&connection_id)
|
||||
.map(|(_, channels)| channels)
|
||||
.unwrap_or_default();
|
||||
for channel_id in channels {
|
||||
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
|
||||
sessions.remove(&connection_id);
|
||||
}
|
||||
self.channel_routes
|
||||
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subscribers(&self, channel_id: Uuid) -> Vec<Uuid> {
|
||||
self.channel_routes
|
||||
.get(&channel_id)
|
||||
.map(|sessions| sessions.iter().copied().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn user_connections(&self, user_id: Uuid) -> Vec<Uuid> {
|
||||
self.user_devices
|
||||
.get(&user_id)
|
||||
.map(|devices| devices.values().copied().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn workspace_connections(&self, workspace_name: &str) -> Vec<Uuid> {
|
||||
self.sessions
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let session = entry.value();
|
||||
(session.workspace_name == workspace_name && session.state.is_deliverable())
|
||||
.then_some(session.connection_id)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_session(&self, connection_id: Uuid) -> Option<WsSession> {
|
||||
self.sessions
|
||||
.get(&connection_id)
|
||||
.map(|session| session.clone())
|
||||
}
|
||||
|
||||
pub fn is_deliverable(&self, connection_id: Uuid) -> bool {
|
||||
self.sessions
|
||||
.get(&connection_id)
|
||||
.map(|session| session.state.is_deliverable() && session.superseded_by.is_none())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn is_user_online(&self, user_id: Uuid) -> AppResult<bool> {
|
||||
Ok(self
|
||||
.user_devices
|
||||
.get(&user_id)
|
||||
.map(|devices| !devices.is_empty())
|
||||
.unwrap_or(false))
|
||||
}
|
||||
|
||||
pub fn get_connection_count(&self, user_id: Uuid) -> AppResult<u32> {
|
||||
Ok(self
|
||||
.user_devices
|
||||
.get(&user_id)
|
||||
.map(|devices| devices.len() as u32)
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
pub fn set_typing(
|
||||
&self,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
user_id: Uuid,
|
||||
) -> AppResult<()> {
|
||||
typing::set_typing(&self.redis, channel_id, thread_id, user_id)
|
||||
}
|
||||
|
||||
pub fn clear_typing(
|
||||
&self,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
user_id: Uuid,
|
||||
) -> AppResult<()> {
|
||||
typing::clear_typing(&self.redis, channel_id, thread_id, user_id)
|
||||
}
|
||||
|
||||
pub fn get_typing_users(
|
||||
&self,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
) -> AppResult<Vec<Uuid>> {
|
||||
typing::get_typing_users(&self.redis, channel_id, thread_id)
|
||||
}
|
||||
|
||||
pub fn heartbeat_interval(&self) -> Duration {
|
||||
Duration::from_secs(WS_HEARTBEAT_INTERVAL_SECS)
|
||||
}
|
||||
pub fn heartbeat_interval_secs(&self) -> u64 {
|
||||
WS_HEARTBEAT_INTERVAL_SECS
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user