448 lines
16 KiB
Rust
448 lines
16 KiB
Rust
use std::sync::Arc;
|
|
|
|
use uuid::Uuid;
|
|
|
|
use crate::immediate::dedup::DedupManager;
|
|
use crate::immediate::limiter::HandlerLimiter;
|
|
use crate::immediate::nats::ImNats;
|
|
use crate::immediate::outbound::*;
|
|
use crate::immediate::rate_limit::{LocalRateLimiter, RateLimiter};
|
|
use crate::immediate::reconnect::ReconnectManager;
|
|
use crate::immediate::session::{WsSession, WsSessionManager};
|
|
use crate::service::ImService;
|
|
use crate::service::im::messages::EditMessageParams;
|
|
use crate::service::im::messages::SendMessageParams;
|
|
use crate::service::im::presence::UpdatePresenceParams;
|
|
use crate::service::im::session::ImSession;
|
|
|
|
use super::inbound::WsInbound;
|
|
use super::redis_keys::*;
|
|
use super::sink::WsSinkManager;
|
|
|
|
#[allow(dead_code)]
|
|
#[derive(Clone)]
|
|
pub struct WsHandler {
|
|
nats: Arc<ImNats>,
|
|
manager: Arc<WsSessionManager>,
|
|
sinks: Arc<WsSinkManager>,
|
|
service: ImService,
|
|
dedup: Arc<DedupManager>,
|
|
rate_limiter: Arc<RateLimiter>,
|
|
local_limiter: Arc<LocalRateLimiter>,
|
|
handler_limiter: Arc<HandlerLimiter>,
|
|
reconnect: Arc<ReconnectManager>,
|
|
session: Option<WsSession>,
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
impl WsHandler {
|
|
pub fn new(
|
|
manager: Arc<WsSessionManager>,
|
|
sinks: Arc<WsSinkManager>,
|
|
service: ImService,
|
|
nats: Arc<ImNats>,
|
|
dedup: Arc<DedupManager>,
|
|
rate_limiter: Arc<RateLimiter>,
|
|
reconnect: Arc<ReconnectManager>,
|
|
) -> Self {
|
|
Self {
|
|
nats,
|
|
manager,
|
|
sinks,
|
|
service,
|
|
dedup,
|
|
rate_limiter,
|
|
local_limiter: Arc::new(LocalRateLimiter::new(WS_MAX_MESSAGES_PER_SEC)),
|
|
handler_limiter: Arc::new(HandlerLimiter::new(1024)),
|
|
reconnect,
|
|
session: None,
|
|
}
|
|
}
|
|
|
|
pub fn session(&self) -> Option<&WsSession> {
|
|
self.session.as_ref()
|
|
}
|
|
pub fn is_authenticated(&self) -> bool {
|
|
self.session.is_some()
|
|
}
|
|
|
|
pub fn handle_disconnect(&self) {
|
|
if let Some(s) = &self.session
|
|
&& let Err(e) = self.manager.unregister_connection(s)
|
|
{
|
|
tracing::warn!(conn = %s.connection_id, error = %e, "unregister failed");
|
|
}
|
|
}
|
|
|
|
pub async fn handle(&mut self, msg: WsInbound) -> Vec<WsOutbound> {
|
|
match msg {
|
|
WsInbound::Auth { request_id, token } => self.handle_auth(request_id, token).await,
|
|
m => {
|
|
let Some(s) = &self.session else {
|
|
return vec![WsOutbound::Error {
|
|
request_id: request_id_of(&m),
|
|
code: "not_authenticated".into(),
|
|
message: "authenticate first".into(),
|
|
}];
|
|
};
|
|
if !self.manager.is_deliverable(s.connection_id) {
|
|
return vec![WsOutbound::Error {
|
|
request_id: request_id_of(&m),
|
|
code: "session_not_active".into(),
|
|
message: "session is not active".into(),
|
|
}];
|
|
}
|
|
let Ok(_permit) = self.handler_limiter.try_acquire() else {
|
|
return vec![WsOutbound::Error {
|
|
request_id: request_id_of(&m),
|
|
code: "overloaded".into(),
|
|
message: "too many inflight messages".into(),
|
|
}];
|
|
};
|
|
if !self.local_limiter.check() {
|
|
return vec![WsOutbound::Error {
|
|
request_id: request_id_of(&m),
|
|
code: "rate_limit_exceeded".into(),
|
|
message: "too many messages".into(),
|
|
}];
|
|
}
|
|
match self.rate_limiter.check(s.connection_id) {
|
|
Ok(true) => {}
|
|
Ok(false) => {
|
|
return vec![WsOutbound::Error {
|
|
request_id: request_id_of(&m),
|
|
code: "rate_limit_exceeded".into(),
|
|
message: "too many messages".into(),
|
|
}];
|
|
}
|
|
Err(e) => tracing::warn!(error = %e, "rate limit check failed"),
|
|
}
|
|
self.dispatch(s, m).await
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn dispatch(&self, session: &WsSession, msg: WsInbound) -> Vec<WsOutbound> {
|
|
match msg {
|
|
WsInbound::Heartbeat { request_id } => {
|
|
if let Err(e) = self.manager.heartbeat(session) {
|
|
tracing::warn!(user = %session.user_id, error = %e, "heartbeat failed");
|
|
}
|
|
vec![WsOutbound::HeartbeatAck {
|
|
request_id,
|
|
timestamp_ms: chrono::Utc::now().timestamp_millis(),
|
|
}]
|
|
}
|
|
WsInbound::JoinChannel {
|
|
request_id,
|
|
channel_id,
|
|
} => match self.service.resolve_channel(channel_id).await {
|
|
Ok(channel) => match self
|
|
.service
|
|
.ensure_channel_readable(session.user_id, &channel)
|
|
.await
|
|
{
|
|
Ok(()) => {
|
|
self.manager
|
|
.subscribe_channel(session.connection_id, channel_id);
|
|
vec![]
|
|
}
|
|
Err(e) => vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "join_channel_failed".into(),
|
|
message: e.to_string(),
|
|
}],
|
|
},
|
|
Err(e) => vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "join_channel_failed".into(),
|
|
message: e.to_string(),
|
|
}],
|
|
},
|
|
WsInbound::LeaveChannel {
|
|
request_id: _,
|
|
channel_id,
|
|
} => {
|
|
self.manager
|
|
.unsubscribe_channel(session.connection_id, channel_id);
|
|
vec![]
|
|
}
|
|
WsInbound::TypingStart {
|
|
request_id,
|
|
channel_id,
|
|
thread_id,
|
|
} => {
|
|
let _ = self
|
|
.manager
|
|
.set_typing(channel_id, thread_id, session.user_id);
|
|
self.nats
|
|
.emit(
|
|
&ImNats::typing_subject(channel_id),
|
|
request_id,
|
|
&TypingEvent {
|
|
channel_id,
|
|
thread_id,
|
|
user_id: session.user_id,
|
|
},
|
|
)
|
|
.await;
|
|
vec![]
|
|
}
|
|
WsInbound::TypingStop {
|
|
request_id: _,
|
|
channel_id,
|
|
thread_id,
|
|
} => {
|
|
let _ = self
|
|
.manager
|
|
.clear_typing(channel_id, thread_id, session.user_id);
|
|
vec![]
|
|
}
|
|
WsInbound::MessageSend {
|
|
request_id,
|
|
channel_id,
|
|
body,
|
|
thread_id,
|
|
reply_to,
|
|
message_type,
|
|
} => {
|
|
if body.len() > WS_MAX_MESSAGE_BYTES {
|
|
return vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "message_too_large".into(),
|
|
message: "message body too large".into(),
|
|
}];
|
|
}
|
|
match self.dedup.check_and_mark(request_id, channel_id) {
|
|
Ok(true) => {}
|
|
Ok(false) => {
|
|
return vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "duplicate".into(),
|
|
message: "duplicate message".into(),
|
|
}];
|
|
}
|
|
Err(e) => tracing::warn!(error = %e, "dedup check failed"),
|
|
}
|
|
let ctx = ImSession::new(session.user_id);
|
|
let params = SendMessageParams {
|
|
body,
|
|
message_type,
|
|
thread_id,
|
|
reply_to_message_id: reply_to,
|
|
pinned: None,
|
|
attachments: None,
|
|
embeds: None,
|
|
};
|
|
match self
|
|
.service
|
|
.message_send(
|
|
&ctx,
|
|
&session.workspace_name,
|
|
channel_id,
|
|
params,
|
|
request_id,
|
|
)
|
|
.await
|
|
{
|
|
Ok(msg) => vec![WsOutbound::SeqAck {
|
|
request_id,
|
|
channel_id,
|
|
seq: msg.seq,
|
|
}],
|
|
Err(e) => {
|
|
if let Err(clear_err) = self.dedup.clear(request_id, channel_id) {
|
|
tracing::warn!(error = %clear_err, "dedup clear failed after message send error");
|
|
}
|
|
vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "message_send_failed".into(),
|
|
message: e.to_string(),
|
|
}]
|
|
}
|
|
}
|
|
}
|
|
WsInbound::MessageEdit {
|
|
request_id,
|
|
channel_id,
|
|
message_id,
|
|
body,
|
|
} => {
|
|
if body.len() > WS_MAX_MESSAGE_BYTES {
|
|
return vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "message_too_large".into(),
|
|
message: "message body too large".into(),
|
|
}];
|
|
}
|
|
let ctx = ImSession::new(session.user_id);
|
|
let params = EditMessageParams { body };
|
|
match self
|
|
.service
|
|
.message_edit(
|
|
&ctx,
|
|
&session.workspace_name,
|
|
channel_id,
|
|
message_id,
|
|
params,
|
|
request_id,
|
|
)
|
|
.await
|
|
{
|
|
Ok(_) => vec![],
|
|
Err(e) => vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "message_edit_failed".into(),
|
|
message: e.to_string(),
|
|
}],
|
|
}
|
|
}
|
|
WsInbound::MessageDelete {
|
|
request_id,
|
|
channel_id,
|
|
message_id,
|
|
} => {
|
|
let ctx = ImSession::new(session.user_id);
|
|
match self
|
|
.service
|
|
.message_delete(
|
|
&ctx,
|
|
&session.workspace_name,
|
|
channel_id,
|
|
message_id,
|
|
request_id,
|
|
)
|
|
.await
|
|
{
|
|
Ok(()) => vec![],
|
|
Err(e) => vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "message_delete_failed".into(),
|
|
message: e.to_string(),
|
|
}],
|
|
}
|
|
}
|
|
WsInbound::PresenceUpdate {
|
|
request_id,
|
|
status,
|
|
custom_status_text,
|
|
custom_status_emoji,
|
|
} => {
|
|
let ctx = ImSession::new(session.user_id);
|
|
let params = UpdatePresenceParams {
|
|
status,
|
|
custom_status_text: custom_status_text.clone(),
|
|
custom_status_emoji: custom_status_emoji.clone(),
|
|
};
|
|
match self
|
|
.service
|
|
.presence_update(&ctx, &session.workspace_name, params)
|
|
.await
|
|
{
|
|
Ok(p) => {
|
|
self.nats
|
|
.emit(
|
|
&ImNats::presence_subject(session.user_id),
|
|
request_id,
|
|
&PresenceEvent {
|
|
user_id: session.user_id,
|
|
status: p.status.to_string(),
|
|
custom_status_text,
|
|
custom_status_emoji,
|
|
},
|
|
)
|
|
.await;
|
|
vec![]
|
|
}
|
|
Err(e) => vec![WsOutbound::Error {
|
|
request_id,
|
|
code: "presence_update_failed".into(),
|
|
message: e.to_string(),
|
|
}],
|
|
}
|
|
}
|
|
WsInbound::ReadReceipt {
|
|
request_id,
|
|
channel_id,
|
|
last_read_message_id,
|
|
last_seq,
|
|
} => {
|
|
if let Some(seq) = last_seq
|
|
&& let Err(e) =
|
|
self.reconnect
|
|
.save_read_position(session.user_id, channel_id, seq)
|
|
{
|
|
tracing::warn!(error = %e, "save read position failed");
|
|
}
|
|
vec![WsOutbound::ReadReceiptAck {
|
|
request_id,
|
|
channel_id,
|
|
last_read_message_id,
|
|
last_seq,
|
|
}]
|
|
}
|
|
WsInbound::Auth { .. } => unreachable!(),
|
|
}
|
|
}
|
|
|
|
fn close_replaced_connection(&self, old_id: Uuid, new_id: Uuid) {
|
|
let _ = self.sinks.send(
|
|
old_id,
|
|
WsOutbound::Error {
|
|
request_id: Uuid::nil(),
|
|
code: "session_replaced".into(),
|
|
message: format!("session replaced by {new_id}"),
|
|
},
|
|
);
|
|
self.sinks.detach(old_id);
|
|
if let Some(old) = self.manager.get_session(old_id)
|
|
&& let Err(e) = self.manager.unregister_connection(&old)
|
|
{
|
|
tracing::warn!(conn = %old_id, error = %e, "unregister replaced connection failed");
|
|
}
|
|
}
|
|
|
|
async fn handle_auth(&mut self, request_id: Uuid, token: String) -> Vec<WsOutbound> {
|
|
match self.manager.redeem_token(&token) {
|
|
Ok(session) => {
|
|
match self.manager.register_connection_with_replacement(&session) {
|
|
Ok(Some(old_id)) => {
|
|
self.close_replaced_connection(old_id, session.connection_id)
|
|
}
|
|
Ok(None) => {}
|
|
Err(e) => tracing::warn!(error = %e, "register connection failed"),
|
|
}
|
|
let cid = session.connection_id;
|
|
let interval = self.manager.heartbeat_interval_secs();
|
|
self.session = Some(session);
|
|
vec![WsOutbound::AuthOk {
|
|
request_id,
|
|
connection_id: cid,
|
|
heartbeat_interval_secs: interval,
|
|
}]
|
|
}
|
|
Err(e) => vec![WsOutbound::AuthError {
|
|
request_id,
|
|
message: e.to_string(),
|
|
}],
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
fn request_id_of(msg: &WsInbound) -> Uuid {
|
|
match msg {
|
|
WsInbound::Auth { request_id, .. } => *request_id,
|
|
WsInbound::Heartbeat { request_id } => *request_id,
|
|
WsInbound::JoinChannel { request_id, .. } => *request_id,
|
|
WsInbound::LeaveChannel { request_id, .. } => *request_id,
|
|
WsInbound::TypingStart { request_id, .. } => *request_id,
|
|
WsInbound::TypingStop { request_id, .. } => *request_id,
|
|
WsInbound::MessageSend { request_id, .. } => *request_id,
|
|
WsInbound::MessageEdit { request_id, .. } => *request_id,
|
|
WsInbound::MessageDelete { request_id, .. } => *request_id,
|
|
WsInbound::PresenceUpdate { request_id, .. } => *request_id,
|
|
WsInbound::ReadReceipt { request_id, .. } => *request_id,
|
|
}
|
|
}
|