use std::sync::Arc; use uuid::Uuid; use crate::immediate::dedup::DedupManager; use crate::immediate::limiter::HandlerLimiter; use crate::immediate::nats::ImNats; use crate::immediate::outbound::*; use crate::immediate::rate_limit::{LocalRateLimiter, RateLimiter}; use crate::immediate::reconnect::ReconnectManager; use crate::immediate::session::{WsSession, WsSessionManager}; use crate::service::ImService; use crate::service::im::messages::EditMessageParams; use crate::service::im::messages::SendMessageParams; use crate::service::im::presence::UpdatePresenceParams; use crate::service::im::session::ImSession; use super::inbound::WsInbound; use super::redis_keys::*; use super::sink::WsSinkManager; #[allow(dead_code)] #[derive(Clone)] pub struct WsHandler { nats: Arc, manager: Arc, sinks: Arc, service: ImService, dedup: Arc, rate_limiter: Arc, local_limiter: Arc, handler_limiter: Arc, reconnect: Arc, session: Option, } #[allow(dead_code)] impl WsHandler { pub fn new( manager: Arc, sinks: Arc, service: ImService, nats: Arc, dedup: Arc, rate_limiter: Arc, reconnect: Arc, ) -> Self { Self { nats, manager, sinks, service, dedup, rate_limiter, local_limiter: Arc::new(LocalRateLimiter::new(WS_MAX_MESSAGES_PER_SEC)), handler_limiter: Arc::new(HandlerLimiter::new(1024)), reconnect, session: None, } } pub fn session(&self) -> Option<&WsSession> { self.session.as_ref() } pub fn is_authenticated(&self) -> bool { self.session.is_some() } pub fn handle_disconnect(&self) { if let Some(s) = &self.session && let Err(e) = self.manager.unregister_connection(s) { tracing::warn!(conn = %s.connection_id, error = %e, "unregister failed"); } } pub async fn handle(&mut self, msg: WsInbound) -> Vec { match msg { WsInbound::Auth { request_id, token } => self.handle_auth(request_id, token).await, m => { let Some(s) = &self.session else { return vec![WsOutbound::Error { request_id: request_id_of(&m), code: "not_authenticated".into(), message: "authenticate first".into(), }]; }; if !self.manager.is_deliverable(s.connection_id) { return vec![WsOutbound::Error { request_id: request_id_of(&m), code: "session_not_active".into(), message: "session is not active".into(), }]; } let Ok(_permit) = self.handler_limiter.try_acquire() else { return vec![WsOutbound::Error { request_id: request_id_of(&m), code: "overloaded".into(), message: "too many inflight messages".into(), }]; }; if !self.local_limiter.check() { return vec![WsOutbound::Error { request_id: request_id_of(&m), code: "rate_limit_exceeded".into(), message: "too many messages".into(), }]; } match self.rate_limiter.check(s.connection_id) { Ok(true) => {} Ok(false) => { return vec![WsOutbound::Error { request_id: request_id_of(&m), code: "rate_limit_exceeded".into(), message: "too many messages".into(), }]; } Err(e) => tracing::warn!(error = %e, "rate limit check failed"), } self.dispatch(s, m).await } } } async fn dispatch(&self, session: &WsSession, msg: WsInbound) -> Vec { match msg { WsInbound::Heartbeat { request_id } => { if let Err(e) = self.manager.heartbeat(session) { tracing::warn!(user = %session.user_id, error = %e, "heartbeat failed"); } vec![WsOutbound::HeartbeatAck { request_id, timestamp_ms: chrono::Utc::now().timestamp_millis(), }] } WsInbound::JoinChannel { request_id, channel_id, } => match self.service.resolve_channel(channel_id).await { Ok(channel) => match self .service .ensure_channel_readable(session.user_id, &channel) .await { Ok(()) => { self.manager .subscribe_channel(session.connection_id, channel_id); vec![] } Err(e) => vec![WsOutbound::Error { request_id, code: "join_channel_failed".into(), message: e.to_string(), }], }, Err(e) => vec![WsOutbound::Error { request_id, code: "join_channel_failed".into(), message: e.to_string(), }], }, WsInbound::LeaveChannel { request_id: _, channel_id, } => { self.manager .unsubscribe_channel(session.connection_id, channel_id); vec![] } WsInbound::TypingStart { request_id, channel_id, thread_id, } => { let _ = self .manager .set_typing(channel_id, thread_id, session.user_id); self.nats .emit( &ImNats::typing_subject(channel_id), request_id, &TypingEvent { channel_id, thread_id, user_id: session.user_id, }, ) .await; vec![] } WsInbound::TypingStop { request_id: _, channel_id, thread_id, } => { let _ = self .manager .clear_typing(channel_id, thread_id, session.user_id); vec![] } WsInbound::MessageSend { request_id, channel_id, body, thread_id, reply_to, message_type, } => { if body.len() > WS_MAX_MESSAGE_BYTES { return vec![WsOutbound::Error { request_id, code: "message_too_large".into(), message: "message body too large".into(), }]; } match self.dedup.check_and_mark(request_id, channel_id) { Ok(true) => {} Ok(false) => { return vec![WsOutbound::Error { request_id, code: "duplicate".into(), message: "duplicate message".into(), }]; } Err(e) => tracing::warn!(error = %e, "dedup check failed"), } let ctx = ImSession::new(session.user_id); let params = SendMessageParams { body, message_type, thread_id, reply_to_message_id: reply_to, pinned: None, attachments: None, embeds: None, }; match self .service .message_send( &ctx, &session.workspace_name, channel_id, params, request_id, ) .await { Ok(msg) => vec![WsOutbound::SeqAck { request_id, channel_id, seq: msg.seq, }], Err(e) => { if let Err(clear_err) = self.dedup.clear(request_id, channel_id) { tracing::warn!(error = %clear_err, "dedup clear failed after message send error"); } vec![WsOutbound::Error { request_id, code: "message_send_failed".into(), message: e.to_string(), }] } } } WsInbound::MessageEdit { request_id, channel_id, message_id, body, } => { if body.len() > WS_MAX_MESSAGE_BYTES { return vec![WsOutbound::Error { request_id, code: "message_too_large".into(), message: "message body too large".into(), }]; } let ctx = ImSession::new(session.user_id); let params = EditMessageParams { body }; match self .service .message_edit( &ctx, &session.workspace_name, channel_id, message_id, params, request_id, ) .await { Ok(_) => vec![], Err(e) => vec![WsOutbound::Error { request_id, code: "message_edit_failed".into(), message: e.to_string(), }], } } WsInbound::MessageDelete { request_id, channel_id, message_id, } => { let ctx = ImSession::new(session.user_id); match self .service .message_delete( &ctx, &session.workspace_name, channel_id, message_id, request_id, ) .await { Ok(()) => vec![], Err(e) => vec![WsOutbound::Error { request_id, code: "message_delete_failed".into(), message: e.to_string(), }], } } WsInbound::PresenceUpdate { request_id, status, custom_status_text, custom_status_emoji, } => { let ctx = ImSession::new(session.user_id); let params = UpdatePresenceParams { status, custom_status_text: custom_status_text.clone(), custom_status_emoji: custom_status_emoji.clone(), }; match self .service .presence_update(&ctx, &session.workspace_name, params) .await { Ok(p) => { self.nats .emit( &ImNats::presence_subject(session.user_id), request_id, &PresenceEvent { user_id: session.user_id, status: p.status.to_string(), custom_status_text, custom_status_emoji, }, ) .await; vec![] } Err(e) => vec![WsOutbound::Error { request_id, code: "presence_update_failed".into(), message: e.to_string(), }], } } WsInbound::ReadReceipt { request_id, channel_id, last_read_message_id, last_seq, } => { if let Some(seq) = last_seq && let Err(e) = self.reconnect .save_read_position(session.user_id, channel_id, seq) { tracing::warn!(error = %e, "save read position failed"); } vec![WsOutbound::ReadReceiptAck { request_id, channel_id, last_read_message_id, last_seq, }] } WsInbound::Auth { .. } => unreachable!(), } } fn close_replaced_connection(&self, old_id: Uuid, new_id: Uuid) { let _ = self.sinks.send( old_id, WsOutbound::Error { request_id: Uuid::nil(), code: "session_replaced".into(), message: format!("session replaced by {new_id}"), }, ); self.sinks.detach(old_id); if let Some(old) = self.manager.get_session(old_id) && let Err(e) = self.manager.unregister_connection(&old) { tracing::warn!(conn = %old_id, error = %e, "unregister replaced connection failed"); } } async fn handle_auth(&mut self, request_id: Uuid, token: String) -> Vec { match self.manager.redeem_token(&token) { Ok(session) => { match self.manager.register_connection_with_replacement(&session) { Ok(Some(old_id)) => { self.close_replaced_connection(old_id, session.connection_id) } Ok(None) => {} Err(e) => tracing::warn!(error = %e, "register connection failed"), } let cid = session.connection_id; let interval = self.manager.heartbeat_interval_secs(); self.session = Some(session); vec![WsOutbound::AuthOk { request_id, connection_id: cid, heartbeat_interval_secs: interval, }] } Err(e) => vec![WsOutbound::AuthError { request_id, message: e.to_string(), }], } } } #[allow(dead_code)] fn request_id_of(msg: &WsInbound) -> Uuid { match msg { WsInbound::Auth { request_id, .. } => *request_id, WsInbound::Heartbeat { request_id } => *request_id, WsInbound::JoinChannel { request_id, .. } => *request_id, WsInbound::LeaveChannel { request_id, .. } => *request_id, WsInbound::TypingStart { request_id, .. } => *request_id, WsInbound::TypingStop { request_id, .. } => *request_id, WsInbound::MessageSend { request_id, .. } => *request_id, WsInbound::MessageEdit { request_id, .. } => *request_id, WsInbound::MessageDelete { request_id, .. } => *request_id, WsInbound::PresenceUpdate { request_id, .. } => *request_id, WsInbound::ReadReceipt { request_id, .. } => *request_id, } }