feat: init

This commit is contained in:
zhenyi
2026-06-07 11:30:56 +08:00
commit 563381c1ca
361 changed files with 41327 additions and 0 deletions
+200
View File
@@ -0,0 +1,200 @@
use std::sync::Arc;
use futures_util::StreamExt;
use uuid::Uuid;
use crate::queue::NatsQueue;
use super::{
ArticleEvent, CategoryEvent, DraftEvent, FollowEvent, MemberEvent, MessageEvent, PollEvent,
PresenceEvent, ReactionEvent, ThreadEvent, TypingEvent, WsOutbound, WsSessionManager,
WsSinkManager,
};
#[derive(Clone)]
pub struct NatsWsBridge {
queue: Arc<NatsQueue>,
sessions: Arc<WsSessionManager>,
sinks: Arc<WsSinkManager>,
}
impl NatsWsBridge {
pub fn new(
queue: Arc<NatsQueue>,
sessions: Arc<WsSessionManager>,
sinks: Arc<WsSinkManager>,
) -> Self {
Self {
queue,
sessions,
sinks,
}
}
pub async fn run_ephemeral(self, subject: &str) {
let Ok(mut sub) = self.queue.subscribe_ephemeral(subject.to_string()).await else {
tracing::warn!(subject, "nats ws bridge subscribe failed");
return;
};
while let Some(msg) = sub.next().await {
self.dispatch(msg.subject.as_str(), msg.payload.as_ref(), request_id(&msg))
.await;
}
}
async fn dispatch(&self, subject: &str, payload: &[u8], request_id: Uuid) {
if subject.starts_with("im.message.") {
self.channel_event(payload, |data| WsOutbound::Message { request_id, data });
} else if subject.starts_with("im.thread.") {
self.channel_event(payload, |data| WsOutbound::Thread { request_id, data });
} else if subject.starts_with("im.member.") {
self.channel_event(payload, |data| WsOutbound::Member { request_id, data });
} else if subject.starts_with("im.reaction.") {
self.channel_event(payload, |data| WsOutbound::Reaction { request_id, data });
} else if subject.starts_with("im.poll.") {
self.channel_event(payload, |data| WsOutbound::Poll { request_id, data });
} else if subject.starts_with("im.article.") {
self.channel_event(payload, |data| WsOutbound::Article { request_id, data });
} else if subject.starts_with("im.typing.") {
self.channel_event(payload, |data| WsOutbound::Typing { request_id, data });
} else if subject.starts_with("im.presence.") {
self.presence_event(payload, request_id);
} else if subject.starts_with("im.channel.") {
self.channel_meta_event(subject, payload, request_id);
} else if subject.starts_with("im.category.") {
self.category_event(payload, request_id);
} else if subject.starts_with("im.draft.") {
self.draft_event(payload, request_id);
} else if subject.starts_with("im.follow.") {
self.channel_event(payload, |data| WsOutbound::Follow { request_id, data });
}
}
fn channel_event<T, F>(&self, payload: &[u8], build: F)
where
T: serde::de::DeserializeOwned + ChannelScoped,
F: Fn(T) -> WsOutbound,
{
let Ok(data) = serde_json::from_slice::<T>(payload) else {
tracing::warn!("nats ws bridge decode channel event failed");
return;
};
let channel_id = data.channel_id();
let subscribers = self.sessions.subscribers(channel_id);
let delivered = self.sinks.send_many(subscribers, build(data));
tracing::debug!(%channel_id, delivered, "nats event forwarded to ws subscribers");
}
fn presence_event(&self, payload: &[u8], request_id: Uuid) {
let Ok(data) = serde_json::from_slice::<PresenceEvent>(payload) else {
tracing::warn!("nats ws bridge decode presence event failed");
return;
};
let ids = self.sessions.user_connections(data.user_id);
let delivered = self
.sinks
.send_many(ids, WsOutbound::Presence { request_id, data });
tracing::debug!(delivered, "nats presence forwarded to ws subscribers");
}
fn category_event(&self, payload: &[u8], request_id: Uuid) {
let Ok(data) = serde_json::from_slice::<CategoryEvent>(payload) else {
tracing::warn!("nats ws bridge decode category event failed");
return;
};
let targets = self.sessions.workspace_connections(&data.workspace_name);
let delivered = self
.sinks
.send_many(targets, WsOutbound::Category { request_id, data });
tracing::debug!(delivered, "nats category event forwarded to ws subscribers");
}
fn draft_event(&self, payload: &[u8], request_id: Uuid) {
let Ok(data) = serde_json::from_slice::<DraftEvent>(payload) else {
tracing::warn!("nats ws bridge decode draft event failed");
return;
};
let targets = self.sessions.user_connections(data.user_id);
let delivered = self
.sinks
.send_many(targets, WsOutbound::Draft { request_id, data });
tracing::debug!(delivered, "nats draft event forwarded to ws subscribers");
}
fn channel_meta_event(&self, subject: &str, payload: &[u8], request_id: Uuid) {
let Ok(data) = serde_json::from_slice::<super::ChannelEvent>(payload) else {
tracing::warn!("nats ws bridge decode channel event failed");
return;
};
let mut targets = data
.workspace_name
.as_deref()
.map(|workspace| self.sessions.workspace_connections(workspace))
.unwrap_or_else(|| self.sessions.subscribers(data.channel_id));
if targets.is_empty()
&& let Some(id) = subject
.rsplit('.')
.next()
.and_then(|v| v.parse::<Uuid>().ok())
{
targets = self.sessions.subscribers(id);
}
let delivered = self
.sinks
.send_many(targets, WsOutbound::Channel { request_id, data });
tracing::debug!(delivered, "nats channel event forwarded to ws subscribers");
}
}
pub trait ChannelScoped {
fn channel_id(&self) -> Uuid;
}
impl ChannelScoped for MessageEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
impl ChannelScoped for ThreadEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
impl ChannelScoped for MemberEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
impl ChannelScoped for ReactionEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
impl ChannelScoped for PollEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
impl ChannelScoped for ArticleEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
impl ChannelScoped for TypingEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
impl ChannelScoped for FollowEvent {
fn channel_id(&self) -> Uuid {
self.channel_id
}
}
fn request_id(msg: &async_nats::Message) -> Uuid {
msg.headers
.as_ref()
.and_then(|h| h.get("X-Request-Id"))
.and_then(|v| v.as_str().parse().ok())
.unwrap_or_else(Uuid::nil)
}
+58
View File
@@ -0,0 +1,58 @@
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::AppResult;
use ::redis::Cmd;
use super::redis_keys::*;
pub struct DedupManager {
redis: AppRedis,
window_secs: u64,
}
impl DedupManager {
pub fn new(redis: AppRedis) -> Self {
Self {
redis,
window_secs: WS_DEDUP_WINDOW_SECS,
}
}
pub fn check_and_mark(&self, message_id: Uuid, channel_id: Uuid) -> AppResult<bool> {
let key = format!("{WS_DEDUP_PREFIX}{channel_id}:{message_id}");
let mut conn = self.redis.get_connection()?;
let result: Option<String> = Cmd::new()
.arg("SET")
.arg(&key)
.arg("1")
.arg("NX")
.arg("EX")
.arg(self.window_secs)
.query(&mut *conn.inner_mut())
.map_err(crate::error::AppError::Redis)?;
Ok(result.is_some())
}
pub fn is_duplicate(&self, message_id: Uuid, channel_id: Uuid) -> AppResult<bool> {
let key = format!("{WS_DEDUP_PREFIX}{channel_id}:{message_id}");
let mut conn = self.redis.get_connection()?;
let exists: bool = Cmd::new()
.arg("EXISTS")
.arg(&key)
.query(&mut *conn.inner_mut())
.map_err(crate::error::AppError::Redis)?;
Ok(exists)
}
pub fn clear(&self, message_id: Uuid, channel_id: Uuid) -> AppResult<()> {
let key = format!("{WS_DEDUP_PREFIX}{channel_id}:{message_id}");
let mut conn = self.redis.get_connection()?;
Cmd::new()
.arg("DEL")
.arg(&key)
.query::<()>(&mut *conn.inner_mut())
.map_err(crate::error::AppError::Redis)?;
Ok(())
}
}
+39
View File
@@ -0,0 +1,39 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportEnvelope<T> {
#[serde(default = "Uuid::now_v7")]
pub message_id: Uuid,
pub request_id: Uuid,
pub user_id: Uuid,
pub payload: T,
#[serde(default = "default_timestamp")]
pub created_at: chrono::DateTime<chrono::Utc>,
#[serde(default)]
pub attempt: u8,
}
fn default_timestamp() -> chrono::DateTime<chrono::Utc> {
chrono::Utc::now()
}
impl<T> TransportEnvelope<T> {
pub fn new(request_id: Uuid, user_id: Uuid, payload: T) -> Self {
Self {
message_id: Uuid::now_v7(),
request_id,
user_id,
payload,
created_at: chrono::Utc::now(),
attempt: 1,
}
}
pub fn retry(self) -> Self {
Self {
attempt: self.attempt + 1,
..self
}
}
}
+447
View File
@@ -0,0 +1,447 @@
use std::sync::Arc;
use uuid::Uuid;
use crate::immediate::dedup::DedupManager;
use crate::immediate::limiter::HandlerLimiter;
use crate::immediate::nats::ImNats;
use crate::immediate::outbound::*;
use crate::immediate::rate_limit::{LocalRateLimiter, RateLimiter};
use crate::immediate::reconnect::ReconnectManager;
use crate::immediate::session::{WsSession, WsSessionManager};
use crate::service::ImService;
use crate::service::im::messages::EditMessageParams;
use crate::service::im::messages::SendMessageParams;
use crate::service::im::presence::UpdatePresenceParams;
use crate::service::im::session::ImSession;
use super::inbound::WsInbound;
use super::redis_keys::*;
use super::sink::WsSinkManager;
#[allow(dead_code)]
#[derive(Clone)]
pub struct WsHandler {
nats: Arc<ImNats>,
manager: Arc<WsSessionManager>,
sinks: Arc<WsSinkManager>,
service: ImService,
dedup: Arc<DedupManager>,
rate_limiter: Arc<RateLimiter>,
local_limiter: Arc<LocalRateLimiter>,
handler_limiter: Arc<HandlerLimiter>,
reconnect: Arc<ReconnectManager>,
session: Option<WsSession>,
}
#[allow(dead_code)]
impl WsHandler {
pub fn new(
manager: Arc<WsSessionManager>,
sinks: Arc<WsSinkManager>,
service: ImService,
nats: Arc<ImNats>,
dedup: Arc<DedupManager>,
rate_limiter: Arc<RateLimiter>,
reconnect: Arc<ReconnectManager>,
) -> Self {
Self {
nats,
manager,
sinks,
service,
dedup,
rate_limiter,
local_limiter: Arc::new(LocalRateLimiter::new(WS_MAX_MESSAGES_PER_SEC)),
handler_limiter: Arc::new(HandlerLimiter::new(1024)),
reconnect,
session: None,
}
}
pub fn session(&self) -> Option<&WsSession> {
self.session.as_ref()
}
pub fn is_authenticated(&self) -> bool {
self.session.is_some()
}
pub fn handle_disconnect(&self) {
if let Some(s) = &self.session
&& let Err(e) = self.manager.unregister_connection(s)
{
tracing::warn!(conn = %s.connection_id, error = %e, "unregister failed");
}
}
pub async fn handle(&mut self, msg: WsInbound) -> Vec<WsOutbound> {
match msg {
WsInbound::Auth { request_id, token } => self.handle_auth(request_id, token).await,
m => {
let Some(s) = &self.session else {
return vec![WsOutbound::Error {
request_id: request_id_of(&m),
code: "not_authenticated".into(),
message: "authenticate first".into(),
}];
};
if !self.manager.is_deliverable(s.connection_id) {
return vec![WsOutbound::Error {
request_id: request_id_of(&m),
code: "session_not_active".into(),
message: "session is not active".into(),
}];
}
let Ok(_permit) = self.handler_limiter.try_acquire() else {
return vec![WsOutbound::Error {
request_id: request_id_of(&m),
code: "overloaded".into(),
message: "too many inflight messages".into(),
}];
};
if !self.local_limiter.check() {
return vec![WsOutbound::Error {
request_id: request_id_of(&m),
code: "rate_limit_exceeded".into(),
message: "too many messages".into(),
}];
}
match self.rate_limiter.check(s.connection_id) {
Ok(true) => {}
Ok(false) => {
return vec![WsOutbound::Error {
request_id: request_id_of(&m),
code: "rate_limit_exceeded".into(),
message: "too many messages".into(),
}];
}
Err(e) => tracing::warn!(error = %e, "rate limit check failed"),
}
self.dispatch(s, m).await
}
}
}
async fn dispatch(&self, session: &WsSession, msg: WsInbound) -> Vec<WsOutbound> {
match msg {
WsInbound::Heartbeat { request_id } => {
if let Err(e) = self.manager.heartbeat(session) {
tracing::warn!(user = %session.user_id, error = %e, "heartbeat failed");
}
vec![WsOutbound::HeartbeatAck {
request_id,
timestamp_ms: chrono::Utc::now().timestamp_millis(),
}]
}
WsInbound::JoinChannel {
request_id,
channel_id,
} => match self.service.resolve_channel(channel_id).await {
Ok(channel) => match self
.service
.ensure_channel_readable(session.user_id, &channel)
.await
{
Ok(()) => {
self.manager
.subscribe_channel(session.connection_id, channel_id);
vec![]
}
Err(e) => vec![WsOutbound::Error {
request_id,
code: "join_channel_failed".into(),
message: e.to_string(),
}],
},
Err(e) => vec![WsOutbound::Error {
request_id,
code: "join_channel_failed".into(),
message: e.to_string(),
}],
},
WsInbound::LeaveChannel {
request_id: _,
channel_id,
} => {
self.manager
.unsubscribe_channel(session.connection_id, channel_id);
vec![]
}
WsInbound::TypingStart {
request_id,
channel_id,
thread_id,
} => {
let _ = self
.manager
.set_typing(channel_id, thread_id, session.user_id);
self.nats
.emit(
&ImNats::typing_subject(channel_id),
request_id,
&TypingEvent {
channel_id,
thread_id,
user_id: session.user_id,
},
)
.await;
vec![]
}
WsInbound::TypingStop {
request_id: _,
channel_id,
thread_id,
} => {
let _ = self
.manager
.clear_typing(channel_id, thread_id, session.user_id);
vec![]
}
WsInbound::MessageSend {
request_id,
channel_id,
body,
thread_id,
reply_to,
message_type,
} => {
if body.len() > WS_MAX_MESSAGE_BYTES {
return vec![WsOutbound::Error {
request_id,
code: "message_too_large".into(),
message: "message body too large".into(),
}];
}
match self.dedup.check_and_mark(request_id, channel_id) {
Ok(true) => {}
Ok(false) => {
return vec![WsOutbound::Error {
request_id,
code: "duplicate".into(),
message: "duplicate message".into(),
}];
}
Err(e) => tracing::warn!(error = %e, "dedup check failed"),
}
let ctx = ImSession::new(session.user_id);
let params = SendMessageParams {
body,
message_type,
thread_id,
reply_to_message_id: reply_to,
pinned: None,
attachments: None,
embeds: None,
};
match self
.service
.message_send(
&ctx,
&session.workspace_name,
channel_id,
params,
request_id,
)
.await
{
Ok(msg) => vec![WsOutbound::SeqAck {
request_id,
channel_id,
seq: msg.seq,
}],
Err(e) => {
if let Err(clear_err) = self.dedup.clear(request_id, channel_id) {
tracing::warn!(error = %clear_err, "dedup clear failed after message send error");
}
vec![WsOutbound::Error {
request_id,
code: "message_send_failed".into(),
message: e.to_string(),
}]
}
}
}
WsInbound::MessageEdit {
request_id,
channel_id,
message_id,
body,
} => {
if body.len() > WS_MAX_MESSAGE_BYTES {
return vec![WsOutbound::Error {
request_id,
code: "message_too_large".into(),
message: "message body too large".into(),
}];
}
let ctx = ImSession::new(session.user_id);
let params = EditMessageParams { body };
match self
.service
.message_edit(
&ctx,
&session.workspace_name,
channel_id,
message_id,
params,
request_id,
)
.await
{
Ok(_) => vec![],
Err(e) => vec![WsOutbound::Error {
request_id,
code: "message_edit_failed".into(),
message: e.to_string(),
}],
}
}
WsInbound::MessageDelete {
request_id,
channel_id,
message_id,
} => {
let ctx = ImSession::new(session.user_id);
match self
.service
.message_delete(
&ctx,
&session.workspace_name,
channel_id,
message_id,
request_id,
)
.await
{
Ok(()) => vec![],
Err(e) => vec![WsOutbound::Error {
request_id,
code: "message_delete_failed".into(),
message: e.to_string(),
}],
}
}
WsInbound::PresenceUpdate {
request_id,
status,
custom_status_text,
custom_status_emoji,
} => {
let ctx = ImSession::new(session.user_id);
let params = UpdatePresenceParams {
status,
custom_status_text: custom_status_text.clone(),
custom_status_emoji: custom_status_emoji.clone(),
};
match self
.service
.presence_update(&ctx, &session.workspace_name, params)
.await
{
Ok(p) => {
self.nats
.emit(
&ImNats::presence_subject(session.user_id),
request_id,
&PresenceEvent {
user_id: session.user_id,
status: p.status.to_string(),
custom_status_text,
custom_status_emoji,
},
)
.await;
vec![]
}
Err(e) => vec![WsOutbound::Error {
request_id,
code: "presence_update_failed".into(),
message: e.to_string(),
}],
}
}
WsInbound::ReadReceipt {
request_id,
channel_id,
last_read_message_id,
last_seq,
} => {
if let Some(seq) = last_seq
&& let Err(e) =
self.reconnect
.save_read_position(session.user_id, channel_id, seq)
{
tracing::warn!(error = %e, "save read position failed");
}
vec![WsOutbound::ReadReceiptAck {
request_id,
channel_id,
last_read_message_id,
last_seq,
}]
}
WsInbound::Auth { .. } => unreachable!(),
}
}
fn close_replaced_connection(&self, old_id: Uuid, new_id: Uuid) {
let _ = self.sinks.send(
old_id,
WsOutbound::Error {
request_id: Uuid::nil(),
code: "session_replaced".into(),
message: format!("session replaced by {new_id}"),
},
);
self.sinks.detach(old_id);
if let Some(old) = self.manager.get_session(old_id)
&& let Err(e) = self.manager.unregister_connection(&old)
{
tracing::warn!(conn = %old_id, error = %e, "unregister replaced connection failed");
}
}
async fn handle_auth(&mut self, request_id: Uuid, token: String) -> Vec<WsOutbound> {
match self.manager.redeem_token(&token) {
Ok(session) => {
match self.manager.register_connection_with_replacement(&session) {
Ok(Some(old_id)) => {
self.close_replaced_connection(old_id, session.connection_id)
}
Ok(None) => {}
Err(e) => tracing::warn!(error = %e, "register connection failed"),
}
let cid = session.connection_id;
let interval = self.manager.heartbeat_interval_secs();
self.session = Some(session);
vec![WsOutbound::AuthOk {
request_id,
connection_id: cid,
heartbeat_interval_secs: interval,
}]
}
Err(e) => vec![WsOutbound::AuthError {
request_id,
message: e.to_string(),
}],
}
}
}
#[allow(dead_code)]
fn request_id_of(msg: &WsInbound) -> Uuid {
match msg {
WsInbound::Auth { request_id, .. } => *request_id,
WsInbound::Heartbeat { request_id } => *request_id,
WsInbound::JoinChannel { request_id, .. } => *request_id,
WsInbound::LeaveChannel { request_id, .. } => *request_id,
WsInbound::TypingStart { request_id, .. } => *request_id,
WsInbound::TypingStop { request_id, .. } => *request_id,
WsInbound::MessageSend { request_id, .. } => *request_id,
WsInbound::MessageEdit { request_id, .. } => *request_id,
WsInbound::MessageDelete { request_id, .. } => *request_id,
WsInbound::PresenceUpdate { request_id, .. } => *request_id,
WsInbound::ReadReceipt { request_id, .. } => *request_id,
}
}
+68
View File
@@ -0,0 +1,68 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WsInbound {
Auth {
request_id: Uuid,
token: String,
},
Heartbeat {
request_id: Uuid,
},
JoinChannel {
request_id: Uuid,
channel_id: Uuid,
},
LeaveChannel {
request_id: Uuid,
channel_id: Uuid,
},
TypingStart {
request_id: Uuid,
channel_id: Uuid,
thread_id: Option<Uuid>,
},
TypingStop {
request_id: Uuid,
channel_id: Uuid,
thread_id: Option<Uuid>,
},
MessageSend {
request_id: Uuid,
channel_id: Uuid,
body: String,
#[serde(skip_serializing_if = "Option::is_none")]
thread_id: Option<Uuid>,
#[serde(skip_serializing_if = "Option::is_none")]
reply_to: Option<Uuid>,
#[serde(skip_serializing_if = "Option::is_none")]
message_type: Option<String>,
},
MessageEdit {
request_id: Uuid,
channel_id: Uuid,
message_id: Uuid,
body: String,
},
MessageDelete {
request_id: Uuid,
channel_id: Uuid,
message_id: Uuid,
},
PresenceUpdate {
request_id: Uuid,
status: String,
#[serde(skip_serializing_if = "Option::is_none")]
custom_status_text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
custom_status_emoji: Option<String>,
},
ReadReceipt {
request_id: Uuid,
channel_id: Uuid,
last_read_message_id: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
last_seq: Option<i64>,
},
}
+46
View File
@@ -0,0 +1,46 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HandlerLimitError;
#[derive(Clone)]
pub struct HandlerLimiter {
sem: Arc<Semaphore>,
max_inflight: usize,
rejected: Arc<AtomicU64>,
}
impl HandlerLimiter {
pub fn new(max_inflight: usize) -> Self {
Self {
sem: Arc::new(Semaphore::new(max_inflight)),
max_inflight,
rejected: Arc::new(AtomicU64::new(0)),
}
}
pub fn try_acquire(&self) -> Result<OwnedSemaphorePermit, HandlerLimitError> {
match self.sem.clone().try_acquire_owned() {
Ok(permit) => Ok(permit),
Err(_) => {
self.rejected.fetch_add(1, Ordering::Relaxed);
Err(HandlerLimitError)
}
}
}
pub fn inflight(&self) -> usize {
self.max_inflight - self.sem.available_permits()
}
pub fn available(&self) -> usize {
self.sem.available_permits()
}
pub fn rejected_total(&self) -> u64 {
self.rejected.load(Ordering::Relaxed)
}
}
+36
View File
@@ -0,0 +1,36 @@
mod bridge;
mod dedup;
mod envelope;
mod handler;
mod inbound;
mod limiter;
mod nats;
mod outbound;
mod rate_limit;
mod reconnect;
mod redis_keys;
mod runtime;
mod seq;
mod session;
mod session_redis;
mod sink;
mod typing;
pub use bridge::NatsWsBridge;
pub use dedup::DedupManager;
pub use envelope::TransportEnvelope;
pub use inbound::WsInbound;
pub use limiter::HandlerLimiter;
pub use nats::ImNats;
pub use outbound::{
ArticleAction, ArticleEvent, CategoryAction, CategoryEvent, ChannelAction, ChannelEvent,
DraftAction, DraftEvent, FollowAction, FollowEvent, MemberAction, MemberEvent, MessageAction,
MessageEvent, PollAction, PollEvent, PresenceEvent, ReactionAction, ReactionEvent,
ThreadAction, ThreadEvent, TypingEvent, WsOutbound,
};
pub use rate_limit::{LocalRateLimiter, RateLimiter};
pub use reconnect::ReconnectManager;
pub use runtime::WsRuntime;
pub use seq::SeqAllocator;
pub use session::{WsSession, WsSessionManager, WsSessionState};
pub use sink::{WsReceiver, WsSender, WsSinkManager};
+81
View File
@@ -0,0 +1,81 @@
use std::sync::Arc;
use serde::Serialize;
use uuid::Uuid;
use crate::queue::NatsQueue;
#[derive(Clone)]
pub struct ImNats {
inner: Arc<NatsQueue>,
}
impl ImNats {
pub fn new(nats: Arc<NatsQueue>) -> Self {
Self { inner: nats }
}
pub async fn emit<T: Serialize>(&self, subject: &str, request_id: Uuid, event: &T) {
if let Err(e) = self
.inner
.publish_with_headers(
subject,
&serde_json::to_vec(event).unwrap_or_default(),
vec![("X-Request-Id".into(), request_id.to_string())],
)
.await
{
tracing::warn!(subject, error = %e, "nats emit failed");
}
}
#[inline]
pub fn channel_subject(channel_id: Uuid) -> String {
format!("im.channel.{channel_id}")
}
#[inline]
pub fn message_subject(channel_id: Uuid) -> String {
format!("im.message.{channel_id}")
}
#[inline]
pub fn thread_subject(channel_id: Uuid, thread_id: Uuid) -> String {
format!("im.thread.{channel_id}.{thread_id}")
}
#[inline]
pub fn member_subject(channel_id: Uuid) -> String {
format!("im.member.{channel_id}")
}
#[inline]
pub fn reaction_subject(channel_id: Uuid) -> String {
format!("im.reaction.{channel_id}")
}
#[inline]
pub fn typing_subject(channel_id: Uuid) -> String {
format!("im.typing.{channel_id}")
}
#[inline]
pub fn presence_subject(user_id: Uuid) -> String {
format!("im.presence.{user_id}")
}
#[inline]
pub fn poll_subject(channel_id: Uuid) -> String {
format!("im.poll.{channel_id}")
}
#[inline]
pub fn article_subject(channel_id: Uuid) -> String {
format!("im.article.{channel_id}")
}
#[inline]
pub fn workspace_channels_subject(workspace_name: &str) -> String {
format!("im.ws_channels.{workspace_name}")
}
}
+256
View File
@@ -0,0 +1,256 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WsOutbound {
AuthOk {
request_id: Uuid,
connection_id: Uuid,
heartbeat_interval_secs: u64,
},
AuthError {
request_id: Uuid,
message: String,
},
HeartbeatAck {
request_id: Uuid,
timestamp_ms: i64,
},
Error {
request_id: Uuid,
code: String,
message: String,
},
Typing {
request_id: Uuid,
data: TypingEvent,
},
Presence {
request_id: Uuid,
data: PresenceEvent,
},
Message {
request_id: Uuid,
data: MessageEvent,
},
Channel {
request_id: Uuid,
data: ChannelEvent,
},
Thread {
request_id: Uuid,
data: ThreadEvent,
},
Member {
request_id: Uuid,
data: MemberEvent,
},
Reaction {
request_id: Uuid,
data: ReactionEvent,
},
Poll {
request_id: Uuid,
data: PollEvent,
},
Article {
request_id: Uuid,
data: ArticleEvent,
},
Category {
request_id: Uuid,
data: CategoryEvent,
},
Draft {
request_id: Uuid,
data: DraftEvent,
},
Follow {
request_id: Uuid,
data: FollowEvent,
},
ReadReceiptAck {
request_id: Uuid,
channel_id: Uuid,
last_read_message_id: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
last_seq: Option<i64>,
},
SeqAck {
request_id: Uuid,
channel_id: Uuid,
seq: i64,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypingEvent {
pub channel_id: Uuid,
pub thread_id: Option<Uuid>,
pub user_id: Uuid,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresenceEvent {
pub user_id: Uuid,
pub status: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_status_text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_status_emoji: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageEvent {
pub channel_id: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
pub thread_id: Option<Uuid>,
pub message_id: Uuid,
pub author_id: Uuid,
pub action: MessageAction,
#[serde(skip_serializing_if = "Option::is_none")]
pub body: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seq: Option<i64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageAction {
Created,
Edited,
Deleted,
Pinned,
Unpinned,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelEvent {
pub channel_id: Uuid,
pub action: ChannelAction,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub workspace_name: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChannelAction {
Created,
Updated,
Deleted,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadEvent {
pub channel_id: Uuid,
pub thread_id: Uuid,
pub action: ThreadAction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThreadAction {
Created,
Updated,
Deleted,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemberEvent {
pub channel_id: Uuid,
pub user_id: Uuid,
pub action: MemberAction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemberAction {
Joined,
Left,
Kicked,
Updated,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReactionEvent {
pub channel_id: Uuid,
pub message_id: Uuid,
pub user_id: Uuid,
pub action: ReactionAction,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReactionAction {
Added,
Removed,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PollEvent {
pub channel_id: Uuid,
pub poll_id: Uuid,
pub action: PollAction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PollAction {
Created,
Voted,
Deleted,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArticleEvent {
pub channel_id: Uuid,
pub article_id: Uuid,
pub action: ArticleAction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArticleAction {
Created,
Updated,
Published,
Unpublished,
Deleted,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryEvent {
pub workspace_name: String,
pub category_id: Uuid,
pub action: CategoryAction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CategoryAction {
Created,
Updated,
Deleted,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DraftEvent {
pub channel_id: Uuid,
pub user_id: Uuid,
pub thread_id: Option<Uuid>,
pub action: DraftAction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DraftAction {
Saved,
Deleted,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FollowEvent {
pub channel_id: Uuid,
pub follow_id: Uuid,
pub action: FollowAction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FollowAction {
Created,
Deleted,
Retried,
}
+102
View File
@@ -0,0 +1,102 @@
use std::time::Instant;
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::AppResult;
use ::redis::Cmd;
use super::redis_keys::*;
pub struct RateLimiter {
redis: AppRedis,
max_per_sec: u32,
}
impl RateLimiter {
pub fn new(redis: AppRedis) -> Self {
Self {
redis,
max_per_sec: WS_MAX_MESSAGES_PER_SEC,
}
}
pub fn with_limit(redis: AppRedis, max_per_sec: u32) -> Self {
Self { redis, max_per_sec }
}
pub fn check(&self, connection_id: Uuid) -> AppResult<bool> {
let key = format!("{WS_RATE_PREFIX}{connection_id}");
let mut conn = self.redis.get_connection()?;
let count: i64 = Cmd::new()
.arg("INCR")
.arg(&key)
.query(&mut *conn.inner_mut())
.map_err(crate::error::AppError::Redis)?;
if count == 1 {
let _ = Cmd::new()
.arg("EXPIRE")
.arg(&key)
.arg(1_u64)
.query::<()>(&mut *conn.inner_mut());
}
Ok(count <= self.max_per_sec as i64)
}
pub fn check_sliding(&self, connection_id: Uuid) -> AppResult<bool> {
let key = format!("{WS_RATE_PREFIX}{connection_id}");
let mut conn = self.redis.get_connection()?;
let count: i64 = Cmd::new()
.arg("INCR")
.arg(&key)
.query(&mut *conn.inner_mut())
.map_err(crate::error::AppError::Redis)?;
if count == 1 {
let _ = Cmd::new()
.arg("EXPIRE")
.arg(&key)
.arg(2_u64)
.query::<()>(&mut *conn.inner_mut());
}
Ok(count <= self.max_per_sec as i64)
}
pub fn remaining(&self, connection_id: Uuid) -> AppResult<u32> {
let key = format!("{WS_RATE_PREFIX}{connection_id}");
let mut conn = self.redis.get_connection()?;
let count: Option<i64> = Cmd::new()
.arg("GET")
.arg(&key)
.query(&mut *conn.inner_mut())
.map_err(crate::error::AppError::Redis)?;
Ok(self.max_per_sec.saturating_sub(count.unwrap_or(0) as u32))
}
}
pub struct LocalRateLimiter {
count: std::sync::atomic::AtomicU32,
start: std::sync::Mutex<Instant>,
max_per_sec: u32,
}
impl LocalRateLimiter {
pub fn new(max_per_sec: u32) -> Self {
Self {
count: std::sync::atomic::AtomicU32::new(0),
start: std::sync::Mutex::new(Instant::now()),
max_per_sec,
}
}
pub fn check(&self) -> bool {
let mut start = self.start.lock().unwrap();
if start.elapsed().as_secs() >= 1 {
self.count.store(0, std::sync::atomic::Ordering::Relaxed);
*start = Instant::now();
}
drop(start);
self.count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
< self.max_per_sec
}
}
+101
View File
@@ -0,0 +1,101 @@
use std::collections::HashMap;
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::{AppError, AppResult};
use ::redis::Cmd;
use super::redis_keys::*;
pub struct ReconnectManager {
redis: AppRedis,
}
impl ReconnectManager {
pub fn new(redis: AppRedis) -> Self {
Self { redis }
}
pub fn save_read_position(&self, user_id: Uuid, channel_id: Uuid, seq: i64) -> AppResult<()> {
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
let mut conn = self.redis.get_connection()?;
Cmd::new()
.arg("SETEX")
.arg(&key)
.arg(WS_RECONNECT_STATE_TTL_SECS)
.arg(seq.to_string())
.query::<()>(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
Ok(())
}
pub fn save_read_positions(
&self,
user_id: Uuid,
positions: &HashMap<Uuid, i64>,
) -> AppResult<()> {
let mut conn = self.redis.get_connection()?;
for (channel_id, seq) in positions {
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
Cmd::new()
.arg("SETEX")
.arg(&key)
.arg(WS_RECONNECT_STATE_TTL_SECS)
.arg(seq.to_string())
.query::<()>(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
}
Ok(())
}
pub fn get_last_seq(&self, user_id: Uuid, channel_id: Uuid) -> AppResult<Option<i64>> {
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
let mut conn = self.redis.get_connection()?;
let val: Option<String> = Cmd::new()
.arg("GET")
.arg(&key)
.query(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
Ok(val.and_then(|v| v.parse().ok()))
}
pub fn get_all_positions(&self, user_id: Uuid) -> AppResult<HashMap<Uuid, i64>> {
let pattern = format!("{WS_RECONNECT_PREFIX}{user_id}:*");
let mut conn = self.redis.get_connection()?;
let keys: Vec<String> = Cmd::new()
.arg("KEYS")
.arg(&pattern)
.query(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
let mut result = HashMap::new();
let prefix_len = format!("{WS_RECONNECT_PREFIX}{user_id}:").len();
for key in &keys {
if let Some(channel_str) = key.get(prefix_len..)
&& let Ok(channel_id) = channel_str.parse::<Uuid>()
{
let val: Option<String> = Cmd::new()
.arg("GET")
.arg(key)
.query(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
if let Some(v) = val
&& let Ok(seq) = v.parse::<i64>()
{
result.insert(channel_id, seq);
}
}
}
Ok(result)
}
pub fn cleanup_channel(&self, user_id: Uuid, channel_id: Uuid) -> AppResult<()> {
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
let mut conn = self.redis.get_connection()?;
let _ = Cmd::new()
.arg("DEL")
.arg(&key)
.query::<()>(&mut *conn.inner_mut());
Ok(())
}
}
+19
View File
@@ -0,0 +1,19 @@
#![allow(dead_code)]
pub const WS_TOKEN_PREFIX: &str = "im:ws:token:";
pub const WS_ONLINE_PREFIX: &str = "im:ws:online:";
pub const WS_CONNS_PREFIX: &str = "im:ws:conns:";
pub const WS_SEQ_PREFIX: &str = "im:seq:";
pub const WS_DEDUP_PREFIX: &str = "im:dedup:";
pub const WS_RATE_PREFIX: &str = "im:rate:";
pub const WS_RECONNECT_PREFIX: &str = "im:reconnect:";
pub const WS_TOKEN_TTL_SECS: u64 = 30;
pub const WS_ONLINE_TTL_SECS: u64 = 60;
pub const WS_HEARTBEAT_INTERVAL_SECS: u64 = 30;
pub const WS_HEARTBEAT_TIMEOUT_SECS: u64 = 60;
pub const WS_MAX_IDLE_SECS: u64 = 300;
pub const WS_MAX_MESSAGE_BYTES: usize = 64 * 1024;
pub const WS_MAX_MESSAGES_PER_SEC: u32 = 100;
pub const WS_SEQ_SEGMENT_SIZE: u64 = 1024;
pub const WS_DEDUP_WINDOW_SECS: u64 = 300;
pub const WS_RECONNECT_STATE_TTL_SECS: u64 = 86400;
+52
View File
@@ -0,0 +1,52 @@
use std::sync::Arc;
use uuid::Uuid;
use crate::queue::NatsQueue;
use super::{NatsWsBridge, WsReceiver, WsSender, WsSessionManager, WsSinkManager};
#[derive(Clone)]
pub struct WsRuntime {
sessions: Arc<WsSessionManager>,
sinks: Arc<WsSinkManager>,
bridge: NatsWsBridge,
}
impl WsRuntime {
pub fn new(queue: Arc<NatsQueue>, sessions: Arc<WsSessionManager>) -> Self {
let sinks = Arc::new(WsSinkManager::new());
let bridge = NatsWsBridge::new(queue, sessions.clone(), sinks.clone());
Self {
sessions,
sinks,
bridge,
}
}
pub fn sinks(&self) -> Arc<WsSinkManager> {
self.sinks.clone()
}
pub fn sessions(&self) -> Arc<WsSessionManager> {
self.sessions.clone()
}
pub fn attach(&self, connection_id: Uuid) -> WsReceiver {
let (tx, rx): (WsSender, WsReceiver) = WsSinkManager::channel();
self.sinks.attach(connection_id, tx);
rx
}
pub fn detach(&self, connection_id: Uuid) {
self.sinks.detach(connection_id);
self.sessions.unsubscribe_all(connection_id);
}
pub fn start_nats_bridge(&self) {
let bridge = self.bridge.clone();
tokio::spawn(async move {
bridge.run_ephemeral("im.>").await;
});
}
}
+117
View File
@@ -0,0 +1,117 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use dashmap::DashMap;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::{AppError, AppResult};
use ::redis::Cmd;
use super::redis_keys::*;
struct Segment {
end: i64,
next: AtomicI64,
}
pub struct SeqAllocator {
redis: AppRedis,
segments: DashMap<Uuid, Arc<Segment>>,
locks: DashMap<Uuid, Arc<Mutex<()>>>,
segment_size: u64,
}
const MAX_RETRIES: u32 = 3;
impl SeqAllocator {
pub fn new(redis: AppRedis) -> Self {
Self {
redis,
segments: DashMap::new(),
locks: DashMap::new(),
segment_size: WS_SEQ_SEGMENT_SIZE,
}
}
pub async fn next(&self, channel_id: Uuid) -> AppResult<i64> {
for _ in 0..MAX_RETRIES {
if let Some(seq) = self.try_allocate(&channel_id) {
return Ok(seq);
}
let lock = self
.locks
.entry(channel_id)
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone();
let _guard = lock.lock().await;
if let Some(seq) = self.try_allocate(&channel_id) {
return Ok(seq);
}
self.refresh(channel_id).await?;
}
Err(AppError::InternalServerError(
"seq allocation exhausted retries".into(),
))
}
pub async fn bootstrap(&self, channel_id: Uuid, db_max: i64) -> AppResult<i64> {
let key = format!("{WS_SEQ_PREFIX}{channel_id}");
let mut conn = self.redis.get_connection()?;
let current: i64 = Cmd::new()
.arg("SET")
.arg(&key)
.arg(db_max)
.arg("NX")
.arg("EX")
.arg(86400)
.query::<Option<String>>(&mut *conn.inner_mut())
.map_err(AppError::Redis)?
.and_then(|v| v.parse().ok())
.unwrap_or_else(|| {
let existing: i64 = Cmd::new()
.arg("GET")
.arg(&key)
.query(&mut *conn.inner_mut())
.map_err(AppError::Redis)
.unwrap_or(db_max);
if existing < db_max { db_max } else { existing }
});
self.segments.remove(&channel_id);
Ok(current)
}
fn try_allocate(&self, channel_id: &Uuid) -> Option<i64> {
let state = self.segments.get(channel_id)?;
let next = state.next.fetch_add(1, Ordering::Relaxed);
if next < state.end { Some(next) } else { None }
}
async fn refresh(&self, channel_id: Uuid) -> AppResult<()> {
let key = format!("{WS_SEQ_PREFIX}{channel_id}");
let mut conn = self.redis.get_connection()?;
let counter: i64 = Cmd::new()
.arg("INCRBY")
.arg(&key)
.arg(self.segment_size as i64)
.query(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
let start = counter - self.segment_size as i64 + 1;
let end = counter + 1;
self.segments.insert(
channel_id,
Arc::new(Segment {
end,
next: AtomicI64::new(start),
}),
);
let _ = Cmd::new()
.arg("EXPIRE")
.arg(&key)
.arg(86400_u64)
.query::<()>(&mut *conn.inner_mut());
Ok(())
}
}
+301
View File
@@ -0,0 +1,301 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::{AppError, AppResult};
use crate::queue::NatsQueue;
use ::redis::Cmd;
use super::redis_keys::*;
use super::session_redis::{heartbeat_redis, register_redis_online, unregister_redis_online};
use super::typing;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WsSessionState {
Connecting,
Authenticated,
Replaced,
Closing,
Closed,
}
impl WsSessionState {
pub fn is_deliverable(self) -> bool {
matches!(self, Self::Authenticated)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsSession {
pub user_id: Uuid,
pub device_id: String,
pub connection_id: Uuid,
pub workspace_name: String,
pub connected_at: i64,
pub authenticated_at: Option<i64>,
pub state: WsSessionState,
pub superseded_by: Option<Uuid>,
}
#[derive(Clone)]
pub struct WsSessionManager {
redis: AppRedis,
#[allow(dead_code)]
nats: Arc<NatsQueue>,
user_devices: Arc<DashMap<Uuid, HashMap<String, Uuid>>>,
sessions: Arc<DashMap<Uuid, WsSession>>,
channel_routes: Arc<DashMap<Uuid, HashSet<Uuid>>>,
session_channels: Arc<DashMap<Uuid, HashSet<Uuid>>>,
}
impl WsSessionManager {
pub fn new(redis: AppRedis, nats: Arc<NatsQueue>) -> Self {
Self {
redis,
nats,
user_devices: Arc::new(DashMap::new()),
sessions: Arc::new(DashMap::new()),
channel_routes: Arc::new(DashMap::new()),
session_channels: Arc::new(DashMap::new()),
}
}
pub fn issue_token(&self, user_id: Uuid, workspace_name: &str) -> AppResult<String> {
self.issue_token_for_device(user_id, workspace_name, "default")
}
pub fn issue_token_for_device(
&self,
user_id: Uuid,
workspace_name: &str,
device_id: &str,
) -> AppResult<String> {
let token = format!("ws_{}", Uuid::now_v7());
let session = WsSession {
user_id,
device_id: device_id.to_string(),
connection_id: Uuid::nil(),
workspace_name: workspace_name.to_string(),
connected_at: 0,
authenticated_at: None,
state: WsSessionState::Connecting,
superseded_by: None,
};
let json = serde_json::to_string(&session)?;
let key = format!("{WS_TOKEN_PREFIX}{token}");
let mut conn = self.redis.get_connection()?;
Cmd::new()
.arg("SETEX")
.arg(&key)
.arg(WS_TOKEN_TTL_SECS)
.arg(&json)
.query::<()>(&mut *conn.inner_mut())?;
Ok(token)
}
pub fn redeem_token(&self, token: &str) -> AppResult<WsSession> {
let key = format!("{WS_TOKEN_PREFIX}{token}");
let mut conn = self.redis.get_connection()?;
let json: Option<String> = Cmd::new()
.arg("GETDEL")
.arg(&key)
.query::<Option<String>>(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
let json = json.ok_or(AppError::Unauthorized)?;
let mut session: WsSession = serde_json::from_str(&json)
.map_err(|e| AppError::Config(format!("invalid ws session: {e}")))?;
let now = chrono::Utc::now().timestamp_millis();
session.connection_id = Uuid::now_v7();
session.connected_at = now;
session.authenticated_at = Some(now);
session.state = WsSessionState::Authenticated;
session.superseded_by = None;
Ok(session)
}
pub fn register_connection(&self, session: &WsSession) -> AppResult<()> {
let _ = self.register_connection_with_replacement(session)?;
Ok(())
}
pub fn register_connection_with_replacement(
&self,
session: &WsSession,
) -> AppResult<Option<Uuid>> {
let mut current = session.clone();
current.state = WsSessionState::Authenticated;
current.superseded_by = None;
self.sessions.insert(current.connection_id, current.clone());
let replaced = {
let mut entry = self.user_devices.entry(current.user_id).or_default();
entry.insert(current.device_id.clone(), current.connection_id)
};
if let Some(old_id) = replaced
&& old_id != current.connection_id
{
if let Some(mut old) = self.sessions.get_mut(&old_id) {
old.state = WsSessionState::Replaced;
old.superseded_by = Some(current.connection_id);
}
self.unsubscribe_all(old_id);
}
register_redis_online(&self.redis, &current)?;
Ok(replaced.filter(|old| *old != current.connection_id))
}
pub fn unregister_connection(&self, session: &WsSession) -> AppResult<()> {
let removed = self.sessions.remove(&session.connection_id).map(|(_, s)| s);
let current = removed.as_ref().unwrap_or(session);
self.unsubscribe_all(current.connection_id);
if let Some(mut devices) = self.user_devices.get_mut(&current.user_id)
&& devices.get(&current.device_id).copied() == Some(current.connection_id)
{
devices.remove(&current.device_id);
}
self.user_devices
.remove_if(&current.user_id, |_, devices| devices.is_empty());
unregister_redis_online(&self.redis, current)
}
pub fn heartbeat(&self, session: &WsSession) -> AppResult<()> {
if !self.is_deliverable(session.connection_id) {
return Err(AppError::Unauthorized);
}
heartbeat_redis(&self.redis, session)
}
pub fn subscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
self.channel_routes
.entry(channel_id)
.or_default()
.insert(connection_id);
self.session_channels
.entry(connection_id)
.or_default()
.insert(channel_id);
}
pub fn unsubscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
sessions.remove(&connection_id);
}
self.channel_routes
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
if let Some(mut channels) = self.session_channels.get_mut(&connection_id) {
channels.remove(&channel_id);
}
self.session_channels
.remove_if(&connection_id, |_, channels| channels.is_empty());
}
pub fn unsubscribe_all(&self, connection_id: Uuid) {
let channels = self
.session_channels
.remove(&connection_id)
.map(|(_, channels)| channels)
.unwrap_or_default();
for channel_id in channels {
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
sessions.remove(&connection_id);
}
self.channel_routes
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
}
}
pub fn subscribers(&self, channel_id: Uuid) -> Vec<Uuid> {
self.channel_routes
.get(&channel_id)
.map(|sessions| sessions.iter().copied().collect())
.unwrap_or_default()
}
pub fn user_connections(&self, user_id: Uuid) -> Vec<Uuid> {
self.user_devices
.get(&user_id)
.map(|devices| devices.values().copied().collect())
.unwrap_or_default()
}
pub fn workspace_connections(&self, workspace_name: &str) -> Vec<Uuid> {
self.sessions
.iter()
.filter_map(|entry| {
let session = entry.value();
(session.workspace_name == workspace_name && session.state.is_deliverable())
.then_some(session.connection_id)
})
.collect()
}
pub fn get_session(&self, connection_id: Uuid) -> Option<WsSession> {
self.sessions
.get(&connection_id)
.map(|session| session.clone())
}
pub fn is_deliverable(&self, connection_id: Uuid) -> bool {
self.sessions
.get(&connection_id)
.map(|session| session.state.is_deliverable() && session.superseded_by.is_none())
.unwrap_or(false)
}
pub fn is_user_online(&self, user_id: Uuid) -> AppResult<bool> {
Ok(self
.user_devices
.get(&user_id)
.map(|devices| !devices.is_empty())
.unwrap_or(false))
}
pub fn get_connection_count(&self, user_id: Uuid) -> AppResult<u32> {
Ok(self
.user_devices
.get(&user_id)
.map(|devices| devices.len() as u32)
.unwrap_or(0))
}
pub fn set_typing(
&self,
channel_id: Uuid,
thread_id: Option<Uuid>,
user_id: Uuid,
) -> AppResult<()> {
typing::set_typing(&self.redis, channel_id, thread_id, user_id)
}
pub fn clear_typing(
&self,
channel_id: Uuid,
thread_id: Option<Uuid>,
user_id: Uuid,
) -> AppResult<()> {
typing::clear_typing(&self.redis, channel_id, thread_id, user_id)
}
pub fn get_typing_users(
&self,
channel_id: Uuid,
thread_id: Option<Uuid>,
) -> AppResult<Vec<Uuid>> {
typing::get_typing_users(&self.redis, channel_id, thread_id)
}
pub fn heartbeat_interval(&self) -> Duration {
Duration::from_secs(WS_HEARTBEAT_INTERVAL_SECS)
}
pub fn heartbeat_interval_secs(&self) -> u64 {
WS_HEARTBEAT_INTERVAL_SECS
}
}
+93
View File
@@ -0,0 +1,93 @@
use crate::cache::redis::AppRedis;
use crate::error::{AppError, AppResult};
use crate::service::im::util::PRESENCE_PREFIX;
use ::redis::Cmd;
use super::redis_keys::*;
use super::session::WsSession;
pub fn register_redis_online(redis: &AppRedis, session: &WsSession) -> AppResult<()> {
let set_key = format!("{WS_ONLINE_PREFIX}{}", session.user_id);
let conn_id = session.connection_id.to_string();
let meta_key = format!("{WS_CONNS_PREFIX}{}", session.connection_id);
let mut conn = redis.get_connection()?;
Cmd::new()
.arg("SADD")
.arg(&set_key)
.arg(&conn_id)
.query::<i32>(&mut *conn.inner_mut())?;
Cmd::new()
.arg("EXPIRE")
.arg(&set_key)
.arg(WS_ONLINE_TTL_SECS)
.query::<()>(&mut *conn.inner_mut())?;
Cmd::new()
.arg("SETEX")
.arg(&meta_key)
.arg(WS_ONLINE_TTL_SECS)
.arg(session.workspace_name.as_str())
.query::<()>(&mut *conn.inner_mut())?;
Ok(())
}
pub fn unregister_redis_online(redis: &AppRedis, session: &WsSession) -> AppResult<()> {
let set_key = format!("{WS_ONLINE_PREFIX}{}", session.user_id);
let conn_id = session.connection_id.to_string();
let meta_key = format!("{WS_CONNS_PREFIX}{}", session.connection_id);
let mut conn = redis.get_connection()?;
Cmd::new()
.arg("SREM")
.arg(&set_key)
.arg(&conn_id)
.query::<i32>(&mut *conn.inner_mut())?;
let remaining: i32 = Cmd::new()
.arg("SCARD")
.arg(&set_key)
.query(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
if remaining == 0 {
Cmd::new()
.arg("DEL")
.arg(&set_key)
.query::<()>(&mut *conn.inner_mut())?;
let pk = format!("{PRESENCE_PREFIX}{}", session.user_id);
let _ = Cmd::new()
.arg("DEL")
.arg(&pk)
.query::<()>(&mut *conn.inner_mut());
}
let _ = Cmd::new()
.arg("DEL")
.arg(&meta_key)
.query::<()>(&mut *conn.inner_mut());
Ok(())
}
pub fn heartbeat_redis(redis: &AppRedis, session: &WsSession) -> AppResult<()> {
let set_key = format!("{WS_ONLINE_PREFIX}{}", session.user_id);
let meta_key = format!("{WS_CONNS_PREFIX}{}", session.connection_id);
let pk = format!("{PRESENCE_PREFIX}{}", session.user_id);
let mut conn = redis.get_connection()?;
let _ = Cmd::new()
.arg("EXPIRE")
.arg(&set_key)
.arg(WS_ONLINE_TTL_SECS)
.query::<()>(&mut *conn.inner_mut());
let _ = Cmd::new()
.arg("SETEX")
.arg(&meta_key)
.arg(WS_ONLINE_TTL_SECS)
.arg(session.workspace_name.as_str())
.query::<()>(&mut *conn.inner_mut());
let _ = Cmd::new()
.arg("SETEX")
.arg(&pk)
.arg(WS_ONLINE_TTL_SECS)
.arg("online")
.query::<()>(&mut *conn.inner_mut());
Ok(())
}
+53
View File
@@ -0,0 +1,53 @@
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::mpsc;
use uuid::Uuid;
use super::WsOutbound;
pub type WsSender = mpsc::UnboundedSender<WsOutbound>;
pub type WsReceiver = mpsc::UnboundedReceiver<WsOutbound>;
#[derive(Clone, Default)]
pub struct WsSinkManager {
sinks: Arc<DashMap<Uuid, WsSender>>,
}
impl WsSinkManager {
pub fn new() -> Self {
Self::default()
}
pub fn channel() -> (WsSender, WsReceiver) {
mpsc::unbounded_channel()
}
pub fn attach(&self, connection_id: Uuid, sender: WsSender) {
self.sinks.insert(connection_id, sender);
}
pub fn detach(&self, connection_id: Uuid) {
self.sinks.remove(&connection_id);
}
pub fn send(&self, connection_id: Uuid, message: WsOutbound) -> bool {
self.sinks
.get(&connection_id)
.map(|sink| sink.send(message).is_ok())
.unwrap_or(false)
}
pub fn send_many<I>(&self, ids: I, message: WsOutbound) -> usize
where
I: IntoIterator<Item = Uuid>,
{
ids.into_iter()
.filter(|id| self.send(*id, message.clone()))
.count()
}
pub fn contains(&self, connection_id: Uuid) -> bool {
self.sinks.contains_key(&connection_id)
}
}
+71
View File
@@ -0,0 +1,71 @@
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::{AppError, AppResult};
use crate::service::im::util::{TYPING_PREFIX, TYPING_TTL_SECS};
use ::redis::Cmd;
pub fn set_typing(
redis: &AppRedis,
channel_id: Uuid,
thread_id: Option<Uuid>,
user_id: Uuid,
) -> AppResult<()> {
let key = typing_key(channel_id, thread_id, user_id);
let mut conn = redis.get_connection()?;
Cmd::new()
.arg("SETEX")
.arg(&key)
.arg(TYPING_TTL_SECS as u64)
.arg("1")
.query::<()>(&mut *conn.inner_mut())?;
Ok(())
}
pub fn clear_typing(
redis: &AppRedis,
channel_id: Uuid,
thread_id: Option<Uuid>,
user_id: Uuid,
) -> AppResult<()> {
let key = typing_key(channel_id, thread_id, user_id);
let mut conn = redis.get_connection()?;
Cmd::new()
.arg("DEL")
.arg(&key)
.query::<()>(&mut *conn.inner_mut())?;
Ok(())
}
pub fn get_typing_users(
redis: &AppRedis,
channel_id: Uuid,
thread_id: Option<Uuid>,
) -> AppResult<Vec<Uuid>> {
let pattern = match thread_id {
Some(tid) => format!("{TYPING_PREFIX}{channel_id}:{tid}:*"),
None => format!("{TYPING_PREFIX}{channel_id}:*"),
};
let mut conn = redis.get_connection()?;
let keys: Vec<String> = Cmd::new()
.arg("KEYS")
.arg(&pattern)
.query(&mut *conn.inner_mut())
.map_err(AppError::Redis)?;
let mut ids = Vec::with_capacity(keys.len());
for key in &keys {
if let Some(part) = key.rsplit(':').next()
&& let Ok(uid) = part.parse::<Uuid>()
{
ids.push(uid);
}
}
Ok(ids)
}
fn typing_key(channel_id: Uuid, thread_id: Option<Uuid>, user_id: Uuid) -> String {
match thread_id {
Some(tid) => format!("{TYPING_PREFIX}{channel_id}:{tid}:{user_id}"),
None => format!("{TYPING_PREFIX}{channel_id}:{user_id}"),
}
}