use std::sync::Arc; use futures_util::StreamExt; use uuid::Uuid; use crate::queue::NatsQueue; use super::{ ArticleEvent, CategoryEvent, DraftEvent, FollowEvent, MemberEvent, MessageEvent, PollEvent, PresenceEvent, ReactionEvent, ThreadEvent, TypingEvent, WsOutbound, WsSessionManager, WsSinkManager, }; #[derive(Clone)] pub struct NatsWsBridge { queue: Arc, sessions: Arc, sinks: Arc, } impl NatsWsBridge { pub fn new( queue: Arc, sessions: Arc, sinks: Arc, ) -> Self { Self { queue, sessions, sinks, } } pub async fn run_ephemeral(self, subject: &str) { let Ok(mut sub) = self.queue.subscribe_ephemeral(subject.to_string()).await else { tracing::warn!(subject, "nats ws bridge subscribe failed"); return; }; while let Some(msg) = sub.next().await { self.dispatch(msg.subject.as_str(), msg.payload.as_ref(), request_id(&msg)) .await; } } async fn dispatch(&self, subject: &str, payload: &[u8], request_id: Uuid) { if subject.starts_with("im.message.") { self.channel_event(payload, |data| WsOutbound::Message { request_id, data }); } else if subject.starts_with("im.thread.") { self.channel_event(payload, |data| WsOutbound::Thread { request_id, data }); } else if subject.starts_with("im.member.") { self.channel_event(payload, |data| WsOutbound::Member { request_id, data }); } else if subject.starts_with("im.reaction.") { self.channel_event(payload, |data| WsOutbound::Reaction { request_id, data }); } else if subject.starts_with("im.poll.") { self.channel_event(payload, |data| WsOutbound::Poll { request_id, data }); } else if subject.starts_with("im.article.") { self.channel_event(payload, |data| WsOutbound::Article { request_id, data }); } else if subject.starts_with("im.typing.") { self.channel_event(payload, |data| WsOutbound::Typing { request_id, data }); } else if subject.starts_with("im.presence.") { self.presence_event(payload, request_id); } else if subject.starts_with("im.channel.") { self.channel_meta_event(subject, payload, request_id); } else if subject.starts_with("im.category.") { self.category_event(payload, request_id); } else if subject.starts_with("im.draft.") { self.draft_event(payload, request_id); } else if subject.starts_with("im.follow.") { self.channel_event(payload, |data| WsOutbound::Follow { request_id, data }); } } fn channel_event(&self, payload: &[u8], build: F) where T: serde::de::DeserializeOwned + ChannelScoped, F: Fn(T) -> WsOutbound, { let Ok(data) = serde_json::from_slice::(payload) else { tracing::warn!("nats ws bridge decode channel event failed"); return; }; let channel_id = data.channel_id(); let subscribers = self.sessions.subscribers(channel_id); let delivered = self.sinks.send_many(subscribers, build(data)); tracing::debug!(%channel_id, delivered, "nats event forwarded to ws subscribers"); } fn presence_event(&self, payload: &[u8], request_id: Uuid) { let Ok(data) = serde_json::from_slice::(payload) else { tracing::warn!("nats ws bridge decode presence event failed"); return; }; let ids = self.sessions.user_connections(data.user_id); let delivered = self .sinks .send_many(ids, WsOutbound::Presence { request_id, data }); tracing::debug!(delivered, "nats presence forwarded to ws subscribers"); } fn category_event(&self, payload: &[u8], request_id: Uuid) { let Ok(data) = serde_json::from_slice::(payload) else { tracing::warn!("nats ws bridge decode category event failed"); return; }; let targets = self.sessions.workspace_connections(&data.workspace_name); let delivered = self .sinks .send_many(targets, WsOutbound::Category { request_id, data }); tracing::debug!(delivered, "nats category event forwarded to ws subscribers"); } fn draft_event(&self, payload: &[u8], request_id: Uuid) { let Ok(data) = serde_json::from_slice::(payload) else { tracing::warn!("nats ws bridge decode draft event failed"); return; }; let targets = self.sessions.user_connections(data.user_id); let delivered = self .sinks .send_many(targets, WsOutbound::Draft { request_id, data }); tracing::debug!(delivered, "nats draft event forwarded to ws subscribers"); } fn channel_meta_event(&self, subject: &str, payload: &[u8], request_id: Uuid) { let Ok(data) = serde_json::from_slice::(payload) else { tracing::warn!("nats ws bridge decode channel event failed"); return; }; let mut targets = data .workspace_name .as_deref() .map(|workspace| self.sessions.workspace_connections(workspace)) .unwrap_or_else(|| self.sessions.subscribers(data.channel_id)); if targets.is_empty() && let Some(id) = subject .rsplit('.') .next() .and_then(|v| v.parse::().ok()) { targets = self.sessions.subscribers(id); } let delivered = self .sinks .send_many(targets, WsOutbound::Channel { request_id, data }); tracing::debug!(delivered, "nats channel event forwarded to ws subscribers"); } } pub trait ChannelScoped { fn channel_id(&self) -> Uuid; } impl ChannelScoped for MessageEvent { fn channel_id(&self) -> Uuid { self.channel_id } } impl ChannelScoped for ThreadEvent { fn channel_id(&self) -> Uuid { self.channel_id } } impl ChannelScoped for MemberEvent { fn channel_id(&self) -> Uuid { self.channel_id } } impl ChannelScoped for ReactionEvent { fn channel_id(&self) -> Uuid { self.channel_id } } impl ChannelScoped for PollEvent { fn channel_id(&self) -> Uuid { self.channel_id } } impl ChannelScoped for ArticleEvent { fn channel_id(&self) -> Uuid { self.channel_id } } impl ChannelScoped for TypingEvent { fn channel_id(&self) -> Uuid { self.channel_id } } impl ChannelScoped for FollowEvent { fn channel_id(&self) -> Uuid { self.channel_id } } fn request_id(msg: &async_nats::Message) -> Uuid { msg.headers .as_ref() .and_then(|h| h.get("X-Request-Id")) .and_then(|v| v.as_str().parse().ok()) .unwrap_or_else(Uuid::nil) }