feat: init
This commit is contained in:
@@ -0,0 +1,200 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::queue::NatsQueue;
|
||||
|
||||
use super::{
|
||||
ArticleEvent, CategoryEvent, DraftEvent, FollowEvent, MemberEvent, MessageEvent, PollEvent,
|
||||
PresenceEvent, ReactionEvent, ThreadEvent, TypingEvent, WsOutbound, WsSessionManager,
|
||||
WsSinkManager,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NatsWsBridge {
|
||||
queue: Arc<NatsQueue>,
|
||||
sessions: Arc<WsSessionManager>,
|
||||
sinks: Arc<WsSinkManager>,
|
||||
}
|
||||
|
||||
impl NatsWsBridge {
|
||||
pub fn new(
|
||||
queue: Arc<NatsQueue>,
|
||||
sessions: Arc<WsSessionManager>,
|
||||
sinks: Arc<WsSinkManager>,
|
||||
) -> Self {
|
||||
Self {
|
||||
queue,
|
||||
sessions,
|
||||
sinks,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_ephemeral(self, subject: &str) {
|
||||
let Ok(mut sub) = self.queue.subscribe_ephemeral(subject.to_string()).await else {
|
||||
tracing::warn!(subject, "nats ws bridge subscribe failed");
|
||||
return;
|
||||
};
|
||||
while let Some(msg) = sub.next().await {
|
||||
self.dispatch(msg.subject.as_str(), msg.payload.as_ref(), request_id(&msg))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn dispatch(&self, subject: &str, payload: &[u8], request_id: Uuid) {
|
||||
if subject.starts_with("im.message.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Message { request_id, data });
|
||||
} else if subject.starts_with("im.thread.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Thread { request_id, data });
|
||||
} else if subject.starts_with("im.member.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Member { request_id, data });
|
||||
} else if subject.starts_with("im.reaction.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Reaction { request_id, data });
|
||||
} else if subject.starts_with("im.poll.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Poll { request_id, data });
|
||||
} else if subject.starts_with("im.article.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Article { request_id, data });
|
||||
} else if subject.starts_with("im.typing.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Typing { request_id, data });
|
||||
} else if subject.starts_with("im.presence.") {
|
||||
self.presence_event(payload, request_id);
|
||||
} else if subject.starts_with("im.channel.") {
|
||||
self.channel_meta_event(subject, payload, request_id);
|
||||
} else if subject.starts_with("im.category.") {
|
||||
self.category_event(payload, request_id);
|
||||
} else if subject.starts_with("im.draft.") {
|
||||
self.draft_event(payload, request_id);
|
||||
} else if subject.starts_with("im.follow.") {
|
||||
self.channel_event(payload, |data| WsOutbound::Follow { request_id, data });
|
||||
}
|
||||
}
|
||||
|
||||
fn channel_event<T, F>(&self, payload: &[u8], build: F)
|
||||
where
|
||||
T: serde::de::DeserializeOwned + ChannelScoped,
|
||||
F: Fn(T) -> WsOutbound,
|
||||
{
|
||||
let Ok(data) = serde_json::from_slice::<T>(payload) else {
|
||||
tracing::warn!("nats ws bridge decode channel event failed");
|
||||
return;
|
||||
};
|
||||
let channel_id = data.channel_id();
|
||||
let subscribers = self.sessions.subscribers(channel_id);
|
||||
let delivered = self.sinks.send_many(subscribers, build(data));
|
||||
tracing::debug!(%channel_id, delivered, "nats event forwarded to ws subscribers");
|
||||
}
|
||||
|
||||
fn presence_event(&self, payload: &[u8], request_id: Uuid) {
|
||||
let Ok(data) = serde_json::from_slice::<PresenceEvent>(payload) else {
|
||||
tracing::warn!("nats ws bridge decode presence event failed");
|
||||
return;
|
||||
};
|
||||
let ids = self.sessions.user_connections(data.user_id);
|
||||
let delivered = self
|
||||
.sinks
|
||||
.send_many(ids, WsOutbound::Presence { request_id, data });
|
||||
tracing::debug!(delivered, "nats presence forwarded to ws subscribers");
|
||||
}
|
||||
|
||||
fn category_event(&self, payload: &[u8], request_id: Uuid) {
|
||||
let Ok(data) = serde_json::from_slice::<CategoryEvent>(payload) else {
|
||||
tracing::warn!("nats ws bridge decode category event failed");
|
||||
return;
|
||||
};
|
||||
let targets = self.sessions.workspace_connections(&data.workspace_name);
|
||||
let delivered = self
|
||||
.sinks
|
||||
.send_many(targets, WsOutbound::Category { request_id, data });
|
||||
tracing::debug!(delivered, "nats category event forwarded to ws subscribers");
|
||||
}
|
||||
|
||||
fn draft_event(&self, payload: &[u8], request_id: Uuid) {
|
||||
let Ok(data) = serde_json::from_slice::<DraftEvent>(payload) else {
|
||||
tracing::warn!("nats ws bridge decode draft event failed");
|
||||
return;
|
||||
};
|
||||
let targets = self.sessions.user_connections(data.user_id);
|
||||
let delivered = self
|
||||
.sinks
|
||||
.send_many(targets, WsOutbound::Draft { request_id, data });
|
||||
tracing::debug!(delivered, "nats draft event forwarded to ws subscribers");
|
||||
}
|
||||
|
||||
fn channel_meta_event(&self, subject: &str, payload: &[u8], request_id: Uuid) {
|
||||
let Ok(data) = serde_json::from_slice::<super::ChannelEvent>(payload) else {
|
||||
tracing::warn!("nats ws bridge decode channel event failed");
|
||||
return;
|
||||
};
|
||||
let mut targets = data
|
||||
.workspace_name
|
||||
.as_deref()
|
||||
.map(|workspace| self.sessions.workspace_connections(workspace))
|
||||
.unwrap_or_else(|| self.sessions.subscribers(data.channel_id));
|
||||
if targets.is_empty()
|
||||
&& let Some(id) = subject
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.and_then(|v| v.parse::<Uuid>().ok())
|
||||
{
|
||||
targets = self.sessions.subscribers(id);
|
||||
}
|
||||
let delivered = self
|
||||
.sinks
|
||||
.send_many(targets, WsOutbound::Channel { request_id, data });
|
||||
tracing::debug!(delivered, "nats channel event forwarded to ws subscribers");
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ChannelScoped {
|
||||
fn channel_id(&self) -> Uuid;
|
||||
}
|
||||
|
||||
impl ChannelScoped for MessageEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
impl ChannelScoped for ThreadEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
impl ChannelScoped for MemberEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
impl ChannelScoped for ReactionEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
impl ChannelScoped for PollEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
impl ChannelScoped for ArticleEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
impl ChannelScoped for TypingEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
impl ChannelScoped for FollowEvent {
|
||||
fn channel_id(&self) -> Uuid {
|
||||
self.channel_id
|
||||
}
|
||||
}
|
||||
|
||||
fn request_id(msg: &async_nats::Message) -> Uuid {
|
||||
msg.headers
|
||||
.as_ref()
|
||||
.and_then(|h| h.get("X-Request-Id"))
|
||||
.and_then(|v| v.as_str().parse().ok())
|
||||
.unwrap_or_else(Uuid::nil)
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::AppResult;
|
||||
use ::redis::Cmd;
|
||||
|
||||
use super::redis_keys::*;
|
||||
|
||||
pub struct DedupManager {
|
||||
redis: AppRedis,
|
||||
window_secs: u64,
|
||||
}
|
||||
|
||||
impl DedupManager {
|
||||
pub fn new(redis: AppRedis) -> Self {
|
||||
Self {
|
||||
redis,
|
||||
window_secs: WS_DEDUP_WINDOW_SECS,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_and_mark(&self, message_id: Uuid, channel_id: Uuid) -> AppResult<bool> {
|
||||
let key = format!("{WS_DEDUP_PREFIX}{channel_id}:{message_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let result: Option<String> = Cmd::new()
|
||||
.arg("SET")
|
||||
.arg(&key)
|
||||
.arg("1")
|
||||
.arg("NX")
|
||||
.arg("EX")
|
||||
.arg(self.window_secs)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(crate::error::AppError::Redis)?;
|
||||
Ok(result.is_some())
|
||||
}
|
||||
|
||||
pub fn is_duplicate(&self, message_id: Uuid, channel_id: Uuid) -> AppResult<bool> {
|
||||
let key = format!("{WS_DEDUP_PREFIX}{channel_id}:{message_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let exists: bool = Cmd::new()
|
||||
.arg("EXISTS")
|
||||
.arg(&key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(crate::error::AppError::Redis)?;
|
||||
Ok(exists)
|
||||
}
|
||||
|
||||
pub fn clear(&self, message_id: Uuid, channel_id: Uuid) -> AppResult<()> {
|
||||
let key = format!("{WS_DEDUP_PREFIX}{channel_id}:{message_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
Cmd::new()
|
||||
.arg("DEL")
|
||||
.arg(&key)
|
||||
.query::<()>(&mut *conn.inner_mut())
|
||||
.map_err(crate::error::AppError::Redis)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TransportEnvelope<T> {
|
||||
#[serde(default = "Uuid::now_v7")]
|
||||
pub message_id: Uuid,
|
||||
pub request_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub payload: T,
|
||||
#[serde(default = "default_timestamp")]
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
#[serde(default)]
|
||||
pub attempt: u8,
|
||||
}
|
||||
|
||||
fn default_timestamp() -> chrono::DateTime<chrono::Utc> {
|
||||
chrono::Utc::now()
|
||||
}
|
||||
|
||||
impl<T> TransportEnvelope<T> {
|
||||
pub fn new(request_id: Uuid, user_id: Uuid, payload: T) -> Self {
|
||||
Self {
|
||||
message_id: Uuid::now_v7(),
|
||||
request_id,
|
||||
user_id,
|
||||
payload,
|
||||
created_at: chrono::Utc::now(),
|
||||
attempt: 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retry(self) -> Self {
|
||||
Self {
|
||||
attempt: self.attempt + 1,
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,447 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::immediate::dedup::DedupManager;
|
||||
use crate::immediate::limiter::HandlerLimiter;
|
||||
use crate::immediate::nats::ImNats;
|
||||
use crate::immediate::outbound::*;
|
||||
use crate::immediate::rate_limit::{LocalRateLimiter, RateLimiter};
|
||||
use crate::immediate::reconnect::ReconnectManager;
|
||||
use crate::immediate::session::{WsSession, WsSessionManager};
|
||||
use crate::service::ImService;
|
||||
use crate::service::im::messages::EditMessageParams;
|
||||
use crate::service::im::messages::SendMessageParams;
|
||||
use crate::service::im::presence::UpdatePresenceParams;
|
||||
use crate::service::im::session::ImSession;
|
||||
|
||||
use super::inbound::WsInbound;
|
||||
use super::redis_keys::*;
|
||||
use super::sink::WsSinkManager;
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Clone)]
|
||||
pub struct WsHandler {
|
||||
nats: Arc<ImNats>,
|
||||
manager: Arc<WsSessionManager>,
|
||||
sinks: Arc<WsSinkManager>,
|
||||
service: ImService,
|
||||
dedup: Arc<DedupManager>,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
local_limiter: Arc<LocalRateLimiter>,
|
||||
handler_limiter: Arc<HandlerLimiter>,
|
||||
reconnect: Arc<ReconnectManager>,
|
||||
session: Option<WsSession>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl WsHandler {
|
||||
pub fn new(
|
||||
manager: Arc<WsSessionManager>,
|
||||
sinks: Arc<WsSinkManager>,
|
||||
service: ImService,
|
||||
nats: Arc<ImNats>,
|
||||
dedup: Arc<DedupManager>,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
reconnect: Arc<ReconnectManager>,
|
||||
) -> Self {
|
||||
Self {
|
||||
nats,
|
||||
manager,
|
||||
sinks,
|
||||
service,
|
||||
dedup,
|
||||
rate_limiter,
|
||||
local_limiter: Arc::new(LocalRateLimiter::new(WS_MAX_MESSAGES_PER_SEC)),
|
||||
handler_limiter: Arc::new(HandlerLimiter::new(1024)),
|
||||
reconnect,
|
||||
session: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn session(&self) -> Option<&WsSession> {
|
||||
self.session.as_ref()
|
||||
}
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.session.is_some()
|
||||
}
|
||||
|
||||
pub fn handle_disconnect(&self) {
|
||||
if let Some(s) = &self.session
|
||||
&& let Err(e) = self.manager.unregister_connection(s)
|
||||
{
|
||||
tracing::warn!(conn = %s.connection_id, error = %e, "unregister failed");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle(&mut self, msg: WsInbound) -> Vec<WsOutbound> {
|
||||
match msg {
|
||||
WsInbound::Auth { request_id, token } => self.handle_auth(request_id, token).await,
|
||||
m => {
|
||||
let Some(s) = &self.session else {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id: request_id_of(&m),
|
||||
code: "not_authenticated".into(),
|
||||
message: "authenticate first".into(),
|
||||
}];
|
||||
};
|
||||
if !self.manager.is_deliverable(s.connection_id) {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id: request_id_of(&m),
|
||||
code: "session_not_active".into(),
|
||||
message: "session is not active".into(),
|
||||
}];
|
||||
}
|
||||
let Ok(_permit) = self.handler_limiter.try_acquire() else {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id: request_id_of(&m),
|
||||
code: "overloaded".into(),
|
||||
message: "too many inflight messages".into(),
|
||||
}];
|
||||
};
|
||||
if !self.local_limiter.check() {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id: request_id_of(&m),
|
||||
code: "rate_limit_exceeded".into(),
|
||||
message: "too many messages".into(),
|
||||
}];
|
||||
}
|
||||
match self.rate_limiter.check(s.connection_id) {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id: request_id_of(&m),
|
||||
code: "rate_limit_exceeded".into(),
|
||||
message: "too many messages".into(),
|
||||
}];
|
||||
}
|
||||
Err(e) => tracing::warn!(error = %e, "rate limit check failed"),
|
||||
}
|
||||
self.dispatch(s, m).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn dispatch(&self, session: &WsSession, msg: WsInbound) -> Vec<WsOutbound> {
|
||||
match msg {
|
||||
WsInbound::Heartbeat { request_id } => {
|
||||
if let Err(e) = self.manager.heartbeat(session) {
|
||||
tracing::warn!(user = %session.user_id, error = %e, "heartbeat failed");
|
||||
}
|
||||
vec![WsOutbound::HeartbeatAck {
|
||||
request_id,
|
||||
timestamp_ms: chrono::Utc::now().timestamp_millis(),
|
||||
}]
|
||||
}
|
||||
WsInbound::JoinChannel {
|
||||
request_id,
|
||||
channel_id,
|
||||
} => match self.service.resolve_channel(channel_id).await {
|
||||
Ok(channel) => match self
|
||||
.service
|
||||
.ensure_channel_readable(session.user_id, &channel)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
self.manager
|
||||
.subscribe_channel(session.connection_id, channel_id);
|
||||
vec![]
|
||||
}
|
||||
Err(e) => vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "join_channel_failed".into(),
|
||||
message: e.to_string(),
|
||||
}],
|
||||
},
|
||||
Err(e) => vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "join_channel_failed".into(),
|
||||
message: e.to_string(),
|
||||
}],
|
||||
},
|
||||
WsInbound::LeaveChannel {
|
||||
request_id: _,
|
||||
channel_id,
|
||||
} => {
|
||||
self.manager
|
||||
.unsubscribe_channel(session.connection_id, channel_id);
|
||||
vec![]
|
||||
}
|
||||
WsInbound::TypingStart {
|
||||
request_id,
|
||||
channel_id,
|
||||
thread_id,
|
||||
} => {
|
||||
let _ = self
|
||||
.manager
|
||||
.set_typing(channel_id, thread_id, session.user_id);
|
||||
self.nats
|
||||
.emit(
|
||||
&ImNats::typing_subject(channel_id),
|
||||
request_id,
|
||||
&TypingEvent {
|
||||
channel_id,
|
||||
thread_id,
|
||||
user_id: session.user_id,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
vec![]
|
||||
}
|
||||
WsInbound::TypingStop {
|
||||
request_id: _,
|
||||
channel_id,
|
||||
thread_id,
|
||||
} => {
|
||||
let _ = self
|
||||
.manager
|
||||
.clear_typing(channel_id, thread_id, session.user_id);
|
||||
vec![]
|
||||
}
|
||||
WsInbound::MessageSend {
|
||||
request_id,
|
||||
channel_id,
|
||||
body,
|
||||
thread_id,
|
||||
reply_to,
|
||||
message_type,
|
||||
} => {
|
||||
if body.len() > WS_MAX_MESSAGE_BYTES {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "message_too_large".into(),
|
||||
message: "message body too large".into(),
|
||||
}];
|
||||
}
|
||||
match self.dedup.check_and_mark(request_id, channel_id) {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "duplicate".into(),
|
||||
message: "duplicate message".into(),
|
||||
}];
|
||||
}
|
||||
Err(e) => tracing::warn!(error = %e, "dedup check failed"),
|
||||
}
|
||||
let ctx = ImSession::new(session.user_id);
|
||||
let params = SendMessageParams {
|
||||
body,
|
||||
message_type,
|
||||
thread_id,
|
||||
reply_to_message_id: reply_to,
|
||||
pinned: None,
|
||||
attachments: None,
|
||||
embeds: None,
|
||||
};
|
||||
match self
|
||||
.service
|
||||
.message_send(
|
||||
&ctx,
|
||||
&session.workspace_name,
|
||||
channel_id,
|
||||
params,
|
||||
request_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(msg) => vec![WsOutbound::SeqAck {
|
||||
request_id,
|
||||
channel_id,
|
||||
seq: msg.seq,
|
||||
}],
|
||||
Err(e) => {
|
||||
if let Err(clear_err) = self.dedup.clear(request_id, channel_id) {
|
||||
tracing::warn!(error = %clear_err, "dedup clear failed after message send error");
|
||||
}
|
||||
vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "message_send_failed".into(),
|
||||
message: e.to_string(),
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
WsInbound::MessageEdit {
|
||||
request_id,
|
||||
channel_id,
|
||||
message_id,
|
||||
body,
|
||||
} => {
|
||||
if body.len() > WS_MAX_MESSAGE_BYTES {
|
||||
return vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "message_too_large".into(),
|
||||
message: "message body too large".into(),
|
||||
}];
|
||||
}
|
||||
let ctx = ImSession::new(session.user_id);
|
||||
let params = EditMessageParams { body };
|
||||
match self
|
||||
.service
|
||||
.message_edit(
|
||||
&ctx,
|
||||
&session.workspace_name,
|
||||
channel_id,
|
||||
message_id,
|
||||
params,
|
||||
request_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => vec![],
|
||||
Err(e) => vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "message_edit_failed".into(),
|
||||
message: e.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
WsInbound::MessageDelete {
|
||||
request_id,
|
||||
channel_id,
|
||||
message_id,
|
||||
} => {
|
||||
let ctx = ImSession::new(session.user_id);
|
||||
match self
|
||||
.service
|
||||
.message_delete(
|
||||
&ctx,
|
||||
&session.workspace_name,
|
||||
channel_id,
|
||||
message_id,
|
||||
request_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => vec![],
|
||||
Err(e) => vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "message_delete_failed".into(),
|
||||
message: e.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
WsInbound::PresenceUpdate {
|
||||
request_id,
|
||||
status,
|
||||
custom_status_text,
|
||||
custom_status_emoji,
|
||||
} => {
|
||||
let ctx = ImSession::new(session.user_id);
|
||||
let params = UpdatePresenceParams {
|
||||
status,
|
||||
custom_status_text: custom_status_text.clone(),
|
||||
custom_status_emoji: custom_status_emoji.clone(),
|
||||
};
|
||||
match self
|
||||
.service
|
||||
.presence_update(&ctx, &session.workspace_name, params)
|
||||
.await
|
||||
{
|
||||
Ok(p) => {
|
||||
self.nats
|
||||
.emit(
|
||||
&ImNats::presence_subject(session.user_id),
|
||||
request_id,
|
||||
&PresenceEvent {
|
||||
user_id: session.user_id,
|
||||
status: p.status.to_string(),
|
||||
custom_status_text,
|
||||
custom_status_emoji,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
vec![]
|
||||
}
|
||||
Err(e) => vec![WsOutbound::Error {
|
||||
request_id,
|
||||
code: "presence_update_failed".into(),
|
||||
message: e.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
WsInbound::ReadReceipt {
|
||||
request_id,
|
||||
channel_id,
|
||||
last_read_message_id,
|
||||
last_seq,
|
||||
} => {
|
||||
if let Some(seq) = last_seq
|
||||
&& let Err(e) =
|
||||
self.reconnect
|
||||
.save_read_position(session.user_id, channel_id, seq)
|
||||
{
|
||||
tracing::warn!(error = %e, "save read position failed");
|
||||
}
|
||||
vec![WsOutbound::ReadReceiptAck {
|
||||
request_id,
|
||||
channel_id,
|
||||
last_read_message_id,
|
||||
last_seq,
|
||||
}]
|
||||
}
|
||||
WsInbound::Auth { .. } => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn close_replaced_connection(&self, old_id: Uuid, new_id: Uuid) {
|
||||
let _ = self.sinks.send(
|
||||
old_id,
|
||||
WsOutbound::Error {
|
||||
request_id: Uuid::nil(),
|
||||
code: "session_replaced".into(),
|
||||
message: format!("session replaced by {new_id}"),
|
||||
},
|
||||
);
|
||||
self.sinks.detach(old_id);
|
||||
if let Some(old) = self.manager.get_session(old_id)
|
||||
&& let Err(e) = self.manager.unregister_connection(&old)
|
||||
{
|
||||
tracing::warn!(conn = %old_id, error = %e, "unregister replaced connection failed");
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_auth(&mut self, request_id: Uuid, token: String) -> Vec<WsOutbound> {
|
||||
match self.manager.redeem_token(&token) {
|
||||
Ok(session) => {
|
||||
match self.manager.register_connection_with_replacement(&session) {
|
||||
Ok(Some(old_id)) => {
|
||||
self.close_replaced_connection(old_id, session.connection_id)
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!(error = %e, "register connection failed"),
|
||||
}
|
||||
let cid = session.connection_id;
|
||||
let interval = self.manager.heartbeat_interval_secs();
|
||||
self.session = Some(session);
|
||||
vec![WsOutbound::AuthOk {
|
||||
request_id,
|
||||
connection_id: cid,
|
||||
heartbeat_interval_secs: interval,
|
||||
}]
|
||||
}
|
||||
Err(e) => vec![WsOutbound::AuthError {
|
||||
request_id,
|
||||
message: e.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn request_id_of(msg: &WsInbound) -> Uuid {
|
||||
match msg {
|
||||
WsInbound::Auth { request_id, .. } => *request_id,
|
||||
WsInbound::Heartbeat { request_id } => *request_id,
|
||||
WsInbound::JoinChannel { request_id, .. } => *request_id,
|
||||
WsInbound::LeaveChannel { request_id, .. } => *request_id,
|
||||
WsInbound::TypingStart { request_id, .. } => *request_id,
|
||||
WsInbound::TypingStop { request_id, .. } => *request_id,
|
||||
WsInbound::MessageSend { request_id, .. } => *request_id,
|
||||
WsInbound::MessageEdit { request_id, .. } => *request_id,
|
||||
WsInbound::MessageDelete { request_id, .. } => *request_id,
|
||||
WsInbound::PresenceUpdate { request_id, .. } => *request_id,
|
||||
WsInbound::ReadReceipt { request_id, .. } => *request_id,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum WsInbound {
|
||||
Auth {
|
||||
request_id: Uuid,
|
||||
token: String,
|
||||
},
|
||||
Heartbeat {
|
||||
request_id: Uuid,
|
||||
},
|
||||
JoinChannel {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
},
|
||||
LeaveChannel {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
},
|
||||
TypingStart {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
},
|
||||
TypingStop {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
},
|
||||
MessageSend {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
body: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
thread_id: Option<Uuid>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reply_to: Option<Uuid>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
message_type: Option<String>,
|
||||
},
|
||||
MessageEdit {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
message_id: Uuid,
|
||||
body: String,
|
||||
},
|
||||
MessageDelete {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
message_id: Uuid,
|
||||
},
|
||||
PresenceUpdate {
|
||||
request_id: Uuid,
|
||||
status: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
custom_status_text: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
custom_status_emoji: Option<String>,
|
||||
},
|
||||
ReadReceipt {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
last_read_message_id: Uuid,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
last_seq: Option<i64>,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct HandlerLimitError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HandlerLimiter {
|
||||
sem: Arc<Semaphore>,
|
||||
max_inflight: usize,
|
||||
rejected: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl HandlerLimiter {
|
||||
pub fn new(max_inflight: usize) -> Self {
|
||||
Self {
|
||||
sem: Arc::new(Semaphore::new(max_inflight)),
|
||||
max_inflight,
|
||||
rejected: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_acquire(&self) -> Result<OwnedSemaphorePermit, HandlerLimitError> {
|
||||
match self.sem.clone().try_acquire_owned() {
|
||||
Ok(permit) => Ok(permit),
|
||||
Err(_) => {
|
||||
self.rejected.fetch_add(1, Ordering::Relaxed);
|
||||
Err(HandlerLimitError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inflight(&self) -> usize {
|
||||
self.max_inflight - self.sem.available_permits()
|
||||
}
|
||||
|
||||
pub fn available(&self) -> usize {
|
||||
self.sem.available_permits()
|
||||
}
|
||||
|
||||
pub fn rejected_total(&self) -> u64 {
|
||||
self.rejected.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
mod bridge;
|
||||
mod dedup;
|
||||
mod envelope;
|
||||
mod handler;
|
||||
mod inbound;
|
||||
mod limiter;
|
||||
mod nats;
|
||||
mod outbound;
|
||||
mod rate_limit;
|
||||
mod reconnect;
|
||||
mod redis_keys;
|
||||
mod runtime;
|
||||
mod seq;
|
||||
mod session;
|
||||
mod session_redis;
|
||||
mod sink;
|
||||
mod typing;
|
||||
|
||||
pub use bridge::NatsWsBridge;
|
||||
pub use dedup::DedupManager;
|
||||
pub use envelope::TransportEnvelope;
|
||||
pub use inbound::WsInbound;
|
||||
pub use limiter::HandlerLimiter;
|
||||
pub use nats::ImNats;
|
||||
pub use outbound::{
|
||||
ArticleAction, ArticleEvent, CategoryAction, CategoryEvent, ChannelAction, ChannelEvent,
|
||||
DraftAction, DraftEvent, FollowAction, FollowEvent, MemberAction, MemberEvent, MessageAction,
|
||||
MessageEvent, PollAction, PollEvent, PresenceEvent, ReactionAction, ReactionEvent,
|
||||
ThreadAction, ThreadEvent, TypingEvent, WsOutbound,
|
||||
};
|
||||
pub use rate_limit::{LocalRateLimiter, RateLimiter};
|
||||
pub use reconnect::ReconnectManager;
|
||||
pub use runtime::WsRuntime;
|
||||
pub use seq::SeqAllocator;
|
||||
pub use session::{WsSession, WsSessionManager, WsSessionState};
|
||||
pub use sink::{WsReceiver, WsSender, WsSinkManager};
|
||||
@@ -0,0 +1,81 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::Serialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::queue::NatsQueue;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ImNats {
|
||||
inner: Arc<NatsQueue>,
|
||||
}
|
||||
|
||||
impl ImNats {
|
||||
pub fn new(nats: Arc<NatsQueue>) -> Self {
|
||||
Self { inner: nats }
|
||||
}
|
||||
|
||||
pub async fn emit<T: Serialize>(&self, subject: &str, request_id: Uuid, event: &T) {
|
||||
if let Err(e) = self
|
||||
.inner
|
||||
.publish_with_headers(
|
||||
subject,
|
||||
&serde_json::to_vec(event).unwrap_or_default(),
|
||||
vec![("X-Request-Id".into(), request_id.to_string())],
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(subject, error = %e, "nats emit failed");
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn channel_subject(channel_id: Uuid) -> String {
|
||||
format!("im.channel.{channel_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn message_subject(channel_id: Uuid) -> String {
|
||||
format!("im.message.{channel_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn thread_subject(channel_id: Uuid, thread_id: Uuid) -> String {
|
||||
format!("im.thread.{channel_id}.{thread_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn member_subject(channel_id: Uuid) -> String {
|
||||
format!("im.member.{channel_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn reaction_subject(channel_id: Uuid) -> String {
|
||||
format!("im.reaction.{channel_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn typing_subject(channel_id: Uuid) -> String {
|
||||
format!("im.typing.{channel_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn presence_subject(user_id: Uuid) -> String {
|
||||
format!("im.presence.{user_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn poll_subject(channel_id: Uuid) -> String {
|
||||
format!("im.poll.{channel_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn article_subject(channel_id: Uuid) -> String {
|
||||
format!("im.article.{channel_id}")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn workspace_channels_subject(workspace_name: &str) -> String {
|
||||
format!("im.ws_channels.{workspace_name}")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum WsOutbound {
|
||||
AuthOk {
|
||||
request_id: Uuid,
|
||||
connection_id: Uuid,
|
||||
heartbeat_interval_secs: u64,
|
||||
},
|
||||
AuthError {
|
||||
request_id: Uuid,
|
||||
message: String,
|
||||
},
|
||||
HeartbeatAck {
|
||||
request_id: Uuid,
|
||||
timestamp_ms: i64,
|
||||
},
|
||||
Error {
|
||||
request_id: Uuid,
|
||||
code: String,
|
||||
message: String,
|
||||
},
|
||||
Typing {
|
||||
request_id: Uuid,
|
||||
data: TypingEvent,
|
||||
},
|
||||
Presence {
|
||||
request_id: Uuid,
|
||||
data: PresenceEvent,
|
||||
},
|
||||
Message {
|
||||
request_id: Uuid,
|
||||
data: MessageEvent,
|
||||
},
|
||||
Channel {
|
||||
request_id: Uuid,
|
||||
data: ChannelEvent,
|
||||
},
|
||||
Thread {
|
||||
request_id: Uuid,
|
||||
data: ThreadEvent,
|
||||
},
|
||||
Member {
|
||||
request_id: Uuid,
|
||||
data: MemberEvent,
|
||||
},
|
||||
Reaction {
|
||||
request_id: Uuid,
|
||||
data: ReactionEvent,
|
||||
},
|
||||
Poll {
|
||||
request_id: Uuid,
|
||||
data: PollEvent,
|
||||
},
|
||||
Article {
|
||||
request_id: Uuid,
|
||||
data: ArticleEvent,
|
||||
},
|
||||
Category {
|
||||
request_id: Uuid,
|
||||
data: CategoryEvent,
|
||||
},
|
||||
Draft {
|
||||
request_id: Uuid,
|
||||
data: DraftEvent,
|
||||
},
|
||||
Follow {
|
||||
request_id: Uuid,
|
||||
data: FollowEvent,
|
||||
},
|
||||
ReadReceiptAck {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
last_read_message_id: Uuid,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
last_seq: Option<i64>,
|
||||
},
|
||||
SeqAck {
|
||||
request_id: Uuid,
|
||||
channel_id: Uuid,
|
||||
seq: i64,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TypingEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub thread_id: Option<Uuid>,
|
||||
pub user_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PresenceEvent {
|
||||
pub user_id: Uuid,
|
||||
pub status: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub custom_status_text: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub custom_status_emoji: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessageEvent {
|
||||
pub channel_id: Uuid,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thread_id: Option<Uuid>,
|
||||
pub message_id: Uuid,
|
||||
pub author_id: Uuid,
|
||||
pub action: MessageAction,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub body: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seq: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum MessageAction {
|
||||
Created,
|
||||
Edited,
|
||||
Deleted,
|
||||
Pinned,
|
||||
Unpinned,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChannelEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub action: ChannelAction,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub workspace_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ChannelAction {
|
||||
Created,
|
||||
Updated,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ThreadEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub thread_id: Uuid,
|
||||
pub action: ThreadAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ThreadAction {
|
||||
Created,
|
||||
Updated,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemberEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub action: MemberAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum MemberAction {
|
||||
Joined,
|
||||
Left,
|
||||
Kicked,
|
||||
Updated,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReactionEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub message_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub action: ReactionAction,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ReactionAction {
|
||||
Added,
|
||||
Removed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PollEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub poll_id: Uuid,
|
||||
pub action: PollAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PollAction {
|
||||
Created,
|
||||
Voted,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ArticleEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub article_id: Uuid,
|
||||
pub action: ArticleAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ArticleAction {
|
||||
Created,
|
||||
Updated,
|
||||
Published,
|
||||
Unpublished,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CategoryEvent {
|
||||
pub workspace_name: String,
|
||||
pub category_id: Uuid,
|
||||
pub action: CategoryAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum CategoryAction {
|
||||
Created,
|
||||
Updated,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DraftEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub thread_id: Option<Uuid>,
|
||||
pub action: DraftAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum DraftAction {
|
||||
Saved,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FollowEvent {
|
||||
pub channel_id: Uuid,
|
||||
pub follow_id: Uuid,
|
||||
pub action: FollowAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum FollowAction {
|
||||
Created,
|
||||
Deleted,
|
||||
Retried,
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::AppResult;
|
||||
use ::redis::Cmd;
|
||||
|
||||
use super::redis_keys::*;
|
||||
|
||||
pub struct RateLimiter {
|
||||
redis: AppRedis,
|
||||
max_per_sec: u32,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(redis: AppRedis) -> Self {
|
||||
Self {
|
||||
redis,
|
||||
max_per_sec: WS_MAX_MESSAGES_PER_SEC,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_limit(redis: AppRedis, max_per_sec: u32) -> Self {
|
||||
Self { redis, max_per_sec }
|
||||
}
|
||||
|
||||
pub fn check(&self, connection_id: Uuid) -> AppResult<bool> {
|
||||
let key = format!("{WS_RATE_PREFIX}{connection_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let count: i64 = Cmd::new()
|
||||
.arg("INCR")
|
||||
.arg(&key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(crate::error::AppError::Redis)?;
|
||||
if count == 1 {
|
||||
let _ = Cmd::new()
|
||||
.arg("EXPIRE")
|
||||
.arg(&key)
|
||||
.arg(1_u64)
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
}
|
||||
Ok(count <= self.max_per_sec as i64)
|
||||
}
|
||||
|
||||
pub fn check_sliding(&self, connection_id: Uuid) -> AppResult<bool> {
|
||||
let key = format!("{WS_RATE_PREFIX}{connection_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let count: i64 = Cmd::new()
|
||||
.arg("INCR")
|
||||
.arg(&key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(crate::error::AppError::Redis)?;
|
||||
if count == 1 {
|
||||
let _ = Cmd::new()
|
||||
.arg("EXPIRE")
|
||||
.arg(&key)
|
||||
.arg(2_u64)
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
}
|
||||
Ok(count <= self.max_per_sec as i64)
|
||||
}
|
||||
|
||||
pub fn remaining(&self, connection_id: Uuid) -> AppResult<u32> {
|
||||
let key = format!("{WS_RATE_PREFIX}{connection_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let count: Option<i64> = Cmd::new()
|
||||
.arg("GET")
|
||||
.arg(&key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(crate::error::AppError::Redis)?;
|
||||
Ok(self.max_per_sec.saturating_sub(count.unwrap_or(0) as u32))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LocalRateLimiter {
|
||||
count: std::sync::atomic::AtomicU32,
|
||||
start: std::sync::Mutex<Instant>,
|
||||
max_per_sec: u32,
|
||||
}
|
||||
|
||||
impl LocalRateLimiter {
|
||||
pub fn new(max_per_sec: u32) -> Self {
|
||||
Self {
|
||||
count: std::sync::atomic::AtomicU32::new(0),
|
||||
start: std::sync::Mutex::new(Instant::now()),
|
||||
max_per_sec,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check(&self) -> bool {
|
||||
let mut start = self.start.lock().unwrap();
|
||||
if start.elapsed().as_secs() >= 1 {
|
||||
self.count.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
*start = Instant::now();
|
||||
}
|
||||
drop(start);
|
||||
self.count
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
< self.max_per_sec
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::{AppError, AppResult};
|
||||
use ::redis::Cmd;
|
||||
|
||||
use super::redis_keys::*;
|
||||
|
||||
pub struct ReconnectManager {
|
||||
redis: AppRedis,
|
||||
}
|
||||
|
||||
impl ReconnectManager {
|
||||
pub fn new(redis: AppRedis) -> Self {
|
||||
Self { redis }
|
||||
}
|
||||
|
||||
pub fn save_read_position(&self, user_id: Uuid, channel_id: Uuid, seq: i64) -> AppResult<()> {
|
||||
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&key)
|
||||
.arg(WS_RECONNECT_STATE_TTL_SECS)
|
||||
.arg(seq.to_string())
|
||||
.query::<()>(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn save_read_positions(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
positions: &HashMap<Uuid, i64>,
|
||||
) -> AppResult<()> {
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
for (channel_id, seq) in positions {
|
||||
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
|
||||
Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&key)
|
||||
.arg(WS_RECONNECT_STATE_TTL_SECS)
|
||||
.arg(seq.to_string())
|
||||
.query::<()>(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_last_seq(&self, user_id: Uuid, channel_id: Uuid) -> AppResult<Option<i64>> {
|
||||
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let val: Option<String> = Cmd::new()
|
||||
.arg("GET")
|
||||
.arg(&key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
Ok(val.and_then(|v| v.parse().ok()))
|
||||
}
|
||||
|
||||
pub fn get_all_positions(&self, user_id: Uuid) -> AppResult<HashMap<Uuid, i64>> {
|
||||
let pattern = format!("{WS_RECONNECT_PREFIX}{user_id}:*");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let keys: Vec<String> = Cmd::new()
|
||||
.arg("KEYS")
|
||||
.arg(&pattern)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
let mut result = HashMap::new();
|
||||
let prefix_len = format!("{WS_RECONNECT_PREFIX}{user_id}:").len();
|
||||
for key in &keys {
|
||||
if let Some(channel_str) = key.get(prefix_len..)
|
||||
&& let Ok(channel_id) = channel_str.parse::<Uuid>()
|
||||
{
|
||||
let val: Option<String> = Cmd::new()
|
||||
.arg("GET")
|
||||
.arg(key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
if let Some(v) = val
|
||||
&& let Ok(seq) = v.parse::<i64>()
|
||||
{
|
||||
result.insert(channel_id, seq);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn cleanup_channel(&self, user_id: Uuid, channel_id: Uuid) -> AppResult<()> {
|
||||
let key = format!("{WS_RECONNECT_PREFIX}{user_id}:{channel_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let _ = Cmd::new()
|
||||
.arg("DEL")
|
||||
.arg(&key)
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
#![allow(dead_code)]
|
||||
pub const WS_TOKEN_PREFIX: &str = "im:ws:token:";
|
||||
pub const WS_ONLINE_PREFIX: &str = "im:ws:online:";
|
||||
pub const WS_CONNS_PREFIX: &str = "im:ws:conns:";
|
||||
pub const WS_SEQ_PREFIX: &str = "im:seq:";
|
||||
pub const WS_DEDUP_PREFIX: &str = "im:dedup:";
|
||||
pub const WS_RATE_PREFIX: &str = "im:rate:";
|
||||
pub const WS_RECONNECT_PREFIX: &str = "im:reconnect:";
|
||||
|
||||
pub const WS_TOKEN_TTL_SECS: u64 = 30;
|
||||
pub const WS_ONLINE_TTL_SECS: u64 = 60;
|
||||
pub const WS_HEARTBEAT_INTERVAL_SECS: u64 = 30;
|
||||
pub const WS_HEARTBEAT_TIMEOUT_SECS: u64 = 60;
|
||||
pub const WS_MAX_IDLE_SECS: u64 = 300;
|
||||
pub const WS_MAX_MESSAGE_BYTES: usize = 64 * 1024;
|
||||
pub const WS_MAX_MESSAGES_PER_SEC: u32 = 100;
|
||||
pub const WS_SEQ_SEGMENT_SIZE: u64 = 1024;
|
||||
pub const WS_DEDUP_WINDOW_SECS: u64 = 300;
|
||||
pub const WS_RECONNECT_STATE_TTL_SECS: u64 = 86400;
|
||||
@@ -0,0 +1,52 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::queue::NatsQueue;
|
||||
|
||||
use super::{NatsWsBridge, WsReceiver, WsSender, WsSessionManager, WsSinkManager};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsRuntime {
|
||||
sessions: Arc<WsSessionManager>,
|
||||
sinks: Arc<WsSinkManager>,
|
||||
bridge: NatsWsBridge,
|
||||
}
|
||||
|
||||
impl WsRuntime {
|
||||
pub fn new(queue: Arc<NatsQueue>, sessions: Arc<WsSessionManager>) -> Self {
|
||||
let sinks = Arc::new(WsSinkManager::new());
|
||||
let bridge = NatsWsBridge::new(queue, sessions.clone(), sinks.clone());
|
||||
Self {
|
||||
sessions,
|
||||
sinks,
|
||||
bridge,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sinks(&self) -> Arc<WsSinkManager> {
|
||||
self.sinks.clone()
|
||||
}
|
||||
|
||||
pub fn sessions(&self) -> Arc<WsSessionManager> {
|
||||
self.sessions.clone()
|
||||
}
|
||||
|
||||
pub fn attach(&self, connection_id: Uuid) -> WsReceiver {
|
||||
let (tx, rx): (WsSender, WsReceiver) = WsSinkManager::channel();
|
||||
self.sinks.attach(connection_id, tx);
|
||||
rx
|
||||
}
|
||||
|
||||
pub fn detach(&self, connection_id: Uuid) {
|
||||
self.sinks.detach(connection_id);
|
||||
self.sessions.unsubscribe_all(connection_id);
|
||||
}
|
||||
|
||||
pub fn start_nats_bridge(&self) {
|
||||
let bridge = self.bridge.clone();
|
||||
tokio::spawn(async move {
|
||||
bridge.run_ephemeral("im.>").await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::{AppError, AppResult};
|
||||
use ::redis::Cmd;
|
||||
|
||||
use super::redis_keys::*;
|
||||
|
||||
struct Segment {
|
||||
end: i64,
|
||||
next: AtomicI64,
|
||||
}
|
||||
|
||||
pub struct SeqAllocator {
|
||||
redis: AppRedis,
|
||||
segments: DashMap<Uuid, Arc<Segment>>,
|
||||
locks: DashMap<Uuid, Arc<Mutex<()>>>,
|
||||
segment_size: u64,
|
||||
}
|
||||
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
|
||||
impl SeqAllocator {
|
||||
pub fn new(redis: AppRedis) -> Self {
|
||||
Self {
|
||||
redis,
|
||||
segments: DashMap::new(),
|
||||
locks: DashMap::new(),
|
||||
segment_size: WS_SEQ_SEGMENT_SIZE,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn next(&self, channel_id: Uuid) -> AppResult<i64> {
|
||||
for _ in 0..MAX_RETRIES {
|
||||
if let Some(seq) = self.try_allocate(&channel_id) {
|
||||
return Ok(seq);
|
||||
}
|
||||
let lock = self
|
||||
.locks
|
||||
.entry(channel_id)
|
||||
.or_insert_with(|| Arc::new(Mutex::new(())))
|
||||
.clone();
|
||||
let _guard = lock.lock().await;
|
||||
if let Some(seq) = self.try_allocate(&channel_id) {
|
||||
return Ok(seq);
|
||||
}
|
||||
self.refresh(channel_id).await?;
|
||||
}
|
||||
Err(AppError::InternalServerError(
|
||||
"seq allocation exhausted retries".into(),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn bootstrap(&self, channel_id: Uuid, db_max: i64) -> AppResult<i64> {
|
||||
let key = format!("{WS_SEQ_PREFIX}{channel_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let current: i64 = Cmd::new()
|
||||
.arg("SET")
|
||||
.arg(&key)
|
||||
.arg(db_max)
|
||||
.arg("NX")
|
||||
.arg("EX")
|
||||
.arg(86400)
|
||||
.query::<Option<String>>(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or_else(|| {
|
||||
let existing: i64 = Cmd::new()
|
||||
.arg("GET")
|
||||
.arg(&key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)
|
||||
.unwrap_or(db_max);
|
||||
if existing < db_max { db_max } else { existing }
|
||||
});
|
||||
self.segments.remove(&channel_id);
|
||||
Ok(current)
|
||||
}
|
||||
|
||||
fn try_allocate(&self, channel_id: &Uuid) -> Option<i64> {
|
||||
let state = self.segments.get(channel_id)?;
|
||||
let next = state.next.fetch_add(1, Ordering::Relaxed);
|
||||
if next < state.end { Some(next) } else { None }
|
||||
}
|
||||
|
||||
async fn refresh(&self, channel_id: Uuid) -> AppResult<()> {
|
||||
let key = format!("{WS_SEQ_PREFIX}{channel_id}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let counter: i64 = Cmd::new()
|
||||
.arg("INCRBY")
|
||||
.arg(&key)
|
||||
.arg(self.segment_size as i64)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
let start = counter - self.segment_size as i64 + 1;
|
||||
let end = counter + 1;
|
||||
self.segments.insert(
|
||||
channel_id,
|
||||
Arc::new(Segment {
|
||||
end,
|
||||
next: AtomicI64::new(start),
|
||||
}),
|
||||
);
|
||||
let _ = Cmd::new()
|
||||
.arg("EXPIRE")
|
||||
.arg(&key)
|
||||
.arg(86400_u64)
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::{AppError, AppResult};
|
||||
use crate::queue::NatsQueue;
|
||||
use ::redis::Cmd;
|
||||
|
||||
use super::redis_keys::*;
|
||||
use super::session_redis::{heartbeat_redis, register_redis_online, unregister_redis_online};
|
||||
use super::typing;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum WsSessionState {
|
||||
Connecting,
|
||||
Authenticated,
|
||||
Replaced,
|
||||
Closing,
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl WsSessionState {
|
||||
pub fn is_deliverable(self) -> bool {
|
||||
matches!(self, Self::Authenticated)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WsSession {
|
||||
pub user_id: Uuid,
|
||||
pub device_id: String,
|
||||
pub connection_id: Uuid,
|
||||
pub workspace_name: String,
|
||||
pub connected_at: i64,
|
||||
pub authenticated_at: Option<i64>,
|
||||
pub state: WsSessionState,
|
||||
pub superseded_by: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WsSessionManager {
|
||||
redis: AppRedis,
|
||||
#[allow(dead_code)]
|
||||
nats: Arc<NatsQueue>,
|
||||
user_devices: Arc<DashMap<Uuid, HashMap<String, Uuid>>>,
|
||||
sessions: Arc<DashMap<Uuid, WsSession>>,
|
||||
channel_routes: Arc<DashMap<Uuid, HashSet<Uuid>>>,
|
||||
session_channels: Arc<DashMap<Uuid, HashSet<Uuid>>>,
|
||||
}
|
||||
|
||||
impl WsSessionManager {
|
||||
pub fn new(redis: AppRedis, nats: Arc<NatsQueue>) -> Self {
|
||||
Self {
|
||||
redis,
|
||||
nats,
|
||||
user_devices: Arc::new(DashMap::new()),
|
||||
sessions: Arc::new(DashMap::new()),
|
||||
channel_routes: Arc::new(DashMap::new()),
|
||||
session_channels: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn issue_token(&self, user_id: Uuid, workspace_name: &str) -> AppResult<String> {
|
||||
self.issue_token_for_device(user_id, workspace_name, "default")
|
||||
}
|
||||
|
||||
pub fn issue_token_for_device(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
workspace_name: &str,
|
||||
device_id: &str,
|
||||
) -> AppResult<String> {
|
||||
let token = format!("ws_{}", Uuid::now_v7());
|
||||
let session = WsSession {
|
||||
user_id,
|
||||
device_id: device_id.to_string(),
|
||||
connection_id: Uuid::nil(),
|
||||
workspace_name: workspace_name.to_string(),
|
||||
connected_at: 0,
|
||||
authenticated_at: None,
|
||||
state: WsSessionState::Connecting,
|
||||
superseded_by: None,
|
||||
};
|
||||
let json = serde_json::to_string(&session)?;
|
||||
let key = format!("{WS_TOKEN_PREFIX}{token}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&key)
|
||||
.arg(WS_TOKEN_TTL_SECS)
|
||||
.arg(&json)
|
||||
.query::<()>(&mut *conn.inner_mut())?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub fn redeem_token(&self, token: &str) -> AppResult<WsSession> {
|
||||
let key = format!("{WS_TOKEN_PREFIX}{token}");
|
||||
let mut conn = self.redis.get_connection()?;
|
||||
let json: Option<String> = Cmd::new()
|
||||
.arg("GETDEL")
|
||||
.arg(&key)
|
||||
.query::<Option<String>>(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
let json = json.ok_or(AppError::Unauthorized)?;
|
||||
let mut session: WsSession = serde_json::from_str(&json)
|
||||
.map_err(|e| AppError::Config(format!("invalid ws session: {e}")))?;
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
session.connection_id = Uuid::now_v7();
|
||||
session.connected_at = now;
|
||||
session.authenticated_at = Some(now);
|
||||
session.state = WsSessionState::Authenticated;
|
||||
session.superseded_by = None;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub fn register_connection(&self, session: &WsSession) -> AppResult<()> {
|
||||
let _ = self.register_connection_with_replacement(session)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn register_connection_with_replacement(
|
||||
&self,
|
||||
session: &WsSession,
|
||||
) -> AppResult<Option<Uuid>> {
|
||||
let mut current = session.clone();
|
||||
current.state = WsSessionState::Authenticated;
|
||||
current.superseded_by = None;
|
||||
self.sessions.insert(current.connection_id, current.clone());
|
||||
|
||||
let replaced = {
|
||||
let mut entry = self.user_devices.entry(current.user_id).or_default();
|
||||
entry.insert(current.device_id.clone(), current.connection_id)
|
||||
};
|
||||
|
||||
if let Some(old_id) = replaced
|
||||
&& old_id != current.connection_id
|
||||
{
|
||||
if let Some(mut old) = self.sessions.get_mut(&old_id) {
|
||||
old.state = WsSessionState::Replaced;
|
||||
old.superseded_by = Some(current.connection_id);
|
||||
}
|
||||
self.unsubscribe_all(old_id);
|
||||
}
|
||||
|
||||
register_redis_online(&self.redis, ¤t)?;
|
||||
Ok(replaced.filter(|old| *old != current.connection_id))
|
||||
}
|
||||
|
||||
pub fn unregister_connection(&self, session: &WsSession) -> AppResult<()> {
|
||||
let removed = self.sessions.remove(&session.connection_id).map(|(_, s)| s);
|
||||
let current = removed.as_ref().unwrap_or(session);
|
||||
self.unsubscribe_all(current.connection_id);
|
||||
|
||||
if let Some(mut devices) = self.user_devices.get_mut(¤t.user_id)
|
||||
&& devices.get(¤t.device_id).copied() == Some(current.connection_id)
|
||||
{
|
||||
devices.remove(¤t.device_id);
|
||||
}
|
||||
self.user_devices
|
||||
.remove_if(¤t.user_id, |_, devices| devices.is_empty());
|
||||
unregister_redis_online(&self.redis, current)
|
||||
}
|
||||
|
||||
pub fn heartbeat(&self, session: &WsSession) -> AppResult<()> {
|
||||
if !self.is_deliverable(session.connection_id) {
|
||||
return Err(AppError::Unauthorized);
|
||||
}
|
||||
heartbeat_redis(&self.redis, session)
|
||||
}
|
||||
|
||||
pub fn subscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
|
||||
self.channel_routes
|
||||
.entry(channel_id)
|
||||
.or_default()
|
||||
.insert(connection_id);
|
||||
self.session_channels
|
||||
.entry(connection_id)
|
||||
.or_default()
|
||||
.insert(channel_id);
|
||||
}
|
||||
|
||||
pub fn unsubscribe_channel(&self, connection_id: Uuid, channel_id: Uuid) {
|
||||
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
|
||||
sessions.remove(&connection_id);
|
||||
}
|
||||
self.channel_routes
|
||||
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
|
||||
if let Some(mut channels) = self.session_channels.get_mut(&connection_id) {
|
||||
channels.remove(&channel_id);
|
||||
}
|
||||
self.session_channels
|
||||
.remove_if(&connection_id, |_, channels| channels.is_empty());
|
||||
}
|
||||
|
||||
pub fn unsubscribe_all(&self, connection_id: Uuid) {
|
||||
let channels = self
|
||||
.session_channels
|
||||
.remove(&connection_id)
|
||||
.map(|(_, channels)| channels)
|
||||
.unwrap_or_default();
|
||||
for channel_id in channels {
|
||||
if let Some(mut sessions) = self.channel_routes.get_mut(&channel_id) {
|
||||
sessions.remove(&connection_id);
|
||||
}
|
||||
self.channel_routes
|
||||
.remove_if(&channel_id, |_, sessions| sessions.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subscribers(&self, channel_id: Uuid) -> Vec<Uuid> {
|
||||
self.channel_routes
|
||||
.get(&channel_id)
|
||||
.map(|sessions| sessions.iter().copied().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn user_connections(&self, user_id: Uuid) -> Vec<Uuid> {
|
||||
self.user_devices
|
||||
.get(&user_id)
|
||||
.map(|devices| devices.values().copied().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn workspace_connections(&self, workspace_name: &str) -> Vec<Uuid> {
|
||||
self.sessions
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let session = entry.value();
|
||||
(session.workspace_name == workspace_name && session.state.is_deliverable())
|
||||
.then_some(session.connection_id)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_session(&self, connection_id: Uuid) -> Option<WsSession> {
|
||||
self.sessions
|
||||
.get(&connection_id)
|
||||
.map(|session| session.clone())
|
||||
}
|
||||
|
||||
pub fn is_deliverable(&self, connection_id: Uuid) -> bool {
|
||||
self.sessions
|
||||
.get(&connection_id)
|
||||
.map(|session| session.state.is_deliverable() && session.superseded_by.is_none())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn is_user_online(&self, user_id: Uuid) -> AppResult<bool> {
|
||||
Ok(self
|
||||
.user_devices
|
||||
.get(&user_id)
|
||||
.map(|devices| !devices.is_empty())
|
||||
.unwrap_or(false))
|
||||
}
|
||||
|
||||
pub fn get_connection_count(&self, user_id: Uuid) -> AppResult<u32> {
|
||||
Ok(self
|
||||
.user_devices
|
||||
.get(&user_id)
|
||||
.map(|devices| devices.len() as u32)
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
pub fn set_typing(
|
||||
&self,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
user_id: Uuid,
|
||||
) -> AppResult<()> {
|
||||
typing::set_typing(&self.redis, channel_id, thread_id, user_id)
|
||||
}
|
||||
|
||||
pub fn clear_typing(
|
||||
&self,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
user_id: Uuid,
|
||||
) -> AppResult<()> {
|
||||
typing::clear_typing(&self.redis, channel_id, thread_id, user_id)
|
||||
}
|
||||
|
||||
pub fn get_typing_users(
|
||||
&self,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
) -> AppResult<Vec<Uuid>> {
|
||||
typing::get_typing_users(&self.redis, channel_id, thread_id)
|
||||
}
|
||||
|
||||
pub fn heartbeat_interval(&self) -> Duration {
|
||||
Duration::from_secs(WS_HEARTBEAT_INTERVAL_SECS)
|
||||
}
|
||||
pub fn heartbeat_interval_secs(&self) -> u64 {
|
||||
WS_HEARTBEAT_INTERVAL_SECS
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::{AppError, AppResult};
|
||||
use crate::service::im::util::PRESENCE_PREFIX;
|
||||
use ::redis::Cmd;
|
||||
|
||||
use super::redis_keys::*;
|
||||
use super::session::WsSession;
|
||||
|
||||
pub fn register_redis_online(redis: &AppRedis, session: &WsSession) -> AppResult<()> {
|
||||
let set_key = format!("{WS_ONLINE_PREFIX}{}", session.user_id);
|
||||
let conn_id = session.connection_id.to_string();
|
||||
let meta_key = format!("{WS_CONNS_PREFIX}{}", session.connection_id);
|
||||
let mut conn = redis.get_connection()?;
|
||||
|
||||
Cmd::new()
|
||||
.arg("SADD")
|
||||
.arg(&set_key)
|
||||
.arg(&conn_id)
|
||||
.query::<i32>(&mut *conn.inner_mut())?;
|
||||
Cmd::new()
|
||||
.arg("EXPIRE")
|
||||
.arg(&set_key)
|
||||
.arg(WS_ONLINE_TTL_SECS)
|
||||
.query::<()>(&mut *conn.inner_mut())?;
|
||||
Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&meta_key)
|
||||
.arg(WS_ONLINE_TTL_SECS)
|
||||
.arg(session.workspace_name.as_str())
|
||||
.query::<()>(&mut *conn.inner_mut())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn unregister_redis_online(redis: &AppRedis, session: &WsSession) -> AppResult<()> {
|
||||
let set_key = format!("{WS_ONLINE_PREFIX}{}", session.user_id);
|
||||
let conn_id = session.connection_id.to_string();
|
||||
let meta_key = format!("{WS_CONNS_PREFIX}{}", session.connection_id);
|
||||
let mut conn = redis.get_connection()?;
|
||||
|
||||
Cmd::new()
|
||||
.arg("SREM")
|
||||
.arg(&set_key)
|
||||
.arg(&conn_id)
|
||||
.query::<i32>(&mut *conn.inner_mut())?;
|
||||
|
||||
let remaining: i32 = Cmd::new()
|
||||
.arg("SCARD")
|
||||
.arg(&set_key)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
if remaining == 0 {
|
||||
Cmd::new()
|
||||
.arg("DEL")
|
||||
.arg(&set_key)
|
||||
.query::<()>(&mut *conn.inner_mut())?;
|
||||
let pk = format!("{PRESENCE_PREFIX}{}", session.user_id);
|
||||
let _ = Cmd::new()
|
||||
.arg("DEL")
|
||||
.arg(&pk)
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
}
|
||||
let _ = Cmd::new()
|
||||
.arg("DEL")
|
||||
.arg(&meta_key)
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn heartbeat_redis(redis: &AppRedis, session: &WsSession) -> AppResult<()> {
|
||||
let set_key = format!("{WS_ONLINE_PREFIX}{}", session.user_id);
|
||||
let meta_key = format!("{WS_CONNS_PREFIX}{}", session.connection_id);
|
||||
let pk = format!("{PRESENCE_PREFIX}{}", session.user_id);
|
||||
let mut conn = redis.get_connection()?;
|
||||
|
||||
let _ = Cmd::new()
|
||||
.arg("EXPIRE")
|
||||
.arg(&set_key)
|
||||
.arg(WS_ONLINE_TTL_SECS)
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
let _ = Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&meta_key)
|
||||
.arg(WS_ONLINE_TTL_SECS)
|
||||
.arg(session.workspace_name.as_str())
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
let _ = Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&pk)
|
||||
.arg(WS_ONLINE_TTL_SECS)
|
||||
.arg("online")
|
||||
.query::<()>(&mut *conn.inner_mut());
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::WsOutbound;
|
||||
|
||||
pub type WsSender = mpsc::UnboundedSender<WsOutbound>;
|
||||
pub type WsReceiver = mpsc::UnboundedReceiver<WsOutbound>;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct WsSinkManager {
|
||||
sinks: Arc<DashMap<Uuid, WsSender>>,
|
||||
}
|
||||
|
||||
impl WsSinkManager {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn channel() -> (WsSender, WsReceiver) {
|
||||
mpsc::unbounded_channel()
|
||||
}
|
||||
|
||||
pub fn attach(&self, connection_id: Uuid, sender: WsSender) {
|
||||
self.sinks.insert(connection_id, sender);
|
||||
}
|
||||
|
||||
pub fn detach(&self, connection_id: Uuid) {
|
||||
self.sinks.remove(&connection_id);
|
||||
}
|
||||
|
||||
pub fn send(&self, connection_id: Uuid, message: WsOutbound) -> bool {
|
||||
self.sinks
|
||||
.get(&connection_id)
|
||||
.map(|sink| sink.send(message).is_ok())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn send_many<I>(&self, ids: I, message: WsOutbound) -> usize
|
||||
where
|
||||
I: IntoIterator<Item = Uuid>,
|
||||
{
|
||||
ids.into_iter()
|
||||
.filter(|id| self.send(*id, message.clone()))
|
||||
.count()
|
||||
}
|
||||
|
||||
pub fn contains(&self, connection_id: Uuid) -> bool {
|
||||
self.sinks.contains_key(&connection_id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::{AppError, AppResult};
|
||||
use crate::service::im::util::{TYPING_PREFIX, TYPING_TTL_SECS};
|
||||
use ::redis::Cmd;
|
||||
|
||||
pub fn set_typing(
|
||||
redis: &AppRedis,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
user_id: Uuid,
|
||||
) -> AppResult<()> {
|
||||
let key = typing_key(channel_id, thread_id, user_id);
|
||||
let mut conn = redis.get_connection()?;
|
||||
Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&key)
|
||||
.arg(TYPING_TTL_SECS as u64)
|
||||
.arg("1")
|
||||
.query::<()>(&mut *conn.inner_mut())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn clear_typing(
|
||||
redis: &AppRedis,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
user_id: Uuid,
|
||||
) -> AppResult<()> {
|
||||
let key = typing_key(channel_id, thread_id, user_id);
|
||||
let mut conn = redis.get_connection()?;
|
||||
Cmd::new()
|
||||
.arg("DEL")
|
||||
.arg(&key)
|
||||
.query::<()>(&mut *conn.inner_mut())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_typing_users(
|
||||
redis: &AppRedis,
|
||||
channel_id: Uuid,
|
||||
thread_id: Option<Uuid>,
|
||||
) -> AppResult<Vec<Uuid>> {
|
||||
let pattern = match thread_id {
|
||||
Some(tid) => format!("{TYPING_PREFIX}{channel_id}:{tid}:*"),
|
||||
None => format!("{TYPING_PREFIX}{channel_id}:*"),
|
||||
};
|
||||
let mut conn = redis.get_connection()?;
|
||||
let keys: Vec<String> = Cmd::new()
|
||||
.arg("KEYS")
|
||||
.arg(&pattern)
|
||||
.query(&mut *conn.inner_mut())
|
||||
.map_err(AppError::Redis)?;
|
||||
let mut ids = Vec::with_capacity(keys.len());
|
||||
for key in &keys {
|
||||
if let Some(part) = key.rsplit(':').next()
|
||||
&& let Ok(uid) = part.parse::<Uuid>()
|
||||
{
|
||||
ids.push(uid);
|
||||
}
|
||||
}
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
fn typing_key(channel_id: Uuid, thread_id: Option<Uuid>, user_id: Uuid) -> String {
|
||||
match thread_id {
|
||||
Some(tid) => format!("{TYPING_PREFIX}{channel_id}:{tid}:{user_id}"),
|
||||
None => format!("{TYPING_PREFIX}{channel_id}:{user_id}"),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user