201 lines
6.9 KiB
Rust
201 lines
6.9 KiB
Rust
use std::sync::Arc;
|
|
|
|
use futures_util::StreamExt;
|
|
use uuid::Uuid;
|
|
|
|
use crate::queue::NatsQueue;
|
|
|
|
use super::{
|
|
ArticleEvent, CategoryEvent, DraftEvent, FollowEvent, MemberEvent, MessageEvent, PollEvent,
|
|
PresenceEvent, ReactionEvent, ThreadEvent, TypingEvent, WsOutbound, WsSessionManager,
|
|
WsSinkManager,
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
pub struct NatsWsBridge {
|
|
queue: Arc<NatsQueue>,
|
|
sessions: Arc<WsSessionManager>,
|
|
sinks: Arc<WsSinkManager>,
|
|
}
|
|
|
|
impl NatsWsBridge {
|
|
pub fn new(
|
|
queue: Arc<NatsQueue>,
|
|
sessions: Arc<WsSessionManager>,
|
|
sinks: Arc<WsSinkManager>,
|
|
) -> Self {
|
|
Self {
|
|
queue,
|
|
sessions,
|
|
sinks,
|
|
}
|
|
}
|
|
|
|
pub async fn run_ephemeral(self, subject: &str) {
|
|
let Ok(mut sub) = self.queue.subscribe_ephemeral(subject.to_string()).await else {
|
|
tracing::warn!(subject, "nats ws bridge subscribe failed");
|
|
return;
|
|
};
|
|
while let Some(msg) = sub.next().await {
|
|
self.dispatch(msg.subject.as_str(), msg.payload.as_ref(), request_id(&msg))
|
|
.await;
|
|
}
|
|
}
|
|
|
|
async fn dispatch(&self, subject: &str, payload: &[u8], request_id: Uuid) {
|
|
if subject.starts_with("im.message.") {
|
|
self.channel_event(payload, |data| WsOutbound::Message { request_id, data });
|
|
} else if subject.starts_with("im.thread.") {
|
|
self.channel_event(payload, |data| WsOutbound::Thread { request_id, data });
|
|
} else if subject.starts_with("im.member.") {
|
|
self.channel_event(payload, |data| WsOutbound::Member { request_id, data });
|
|
} else if subject.starts_with("im.reaction.") {
|
|
self.channel_event(payload, |data| WsOutbound::Reaction { request_id, data });
|
|
} else if subject.starts_with("im.poll.") {
|
|
self.channel_event(payload, |data| WsOutbound::Poll { request_id, data });
|
|
} else if subject.starts_with("im.article.") {
|
|
self.channel_event(payload, |data| WsOutbound::Article { request_id, data });
|
|
} else if subject.starts_with("im.typing.") {
|
|
self.channel_event(payload, |data| WsOutbound::Typing { request_id, data });
|
|
} else if subject.starts_with("im.presence.") {
|
|
self.presence_event(payload, request_id);
|
|
} else if subject.starts_with("im.channel.") {
|
|
self.channel_meta_event(subject, payload, request_id);
|
|
} else if subject.starts_with("im.category.") {
|
|
self.category_event(payload, request_id);
|
|
} else if subject.starts_with("im.draft.") {
|
|
self.draft_event(payload, request_id);
|
|
} else if subject.starts_with("im.follow.") {
|
|
self.channel_event(payload, |data| WsOutbound::Follow { request_id, data });
|
|
}
|
|
}
|
|
|
|
fn channel_event<T, F>(&self, payload: &[u8], build: F)
|
|
where
|
|
T: serde::de::DeserializeOwned + ChannelScoped,
|
|
F: Fn(T) -> WsOutbound,
|
|
{
|
|
let Ok(data) = serde_json::from_slice::<T>(payload) else {
|
|
tracing::warn!("nats ws bridge decode channel event failed");
|
|
return;
|
|
};
|
|
let channel_id = data.channel_id();
|
|
let subscribers = self.sessions.subscribers(channel_id);
|
|
let delivered = self.sinks.send_many(subscribers, build(data));
|
|
tracing::debug!(%channel_id, delivered, "nats event forwarded to ws subscribers");
|
|
}
|
|
|
|
fn presence_event(&self, payload: &[u8], request_id: Uuid) {
|
|
let Ok(data) = serde_json::from_slice::<PresenceEvent>(payload) else {
|
|
tracing::warn!("nats ws bridge decode presence event failed");
|
|
return;
|
|
};
|
|
let ids = self.sessions.user_connections(data.user_id);
|
|
let delivered = self
|
|
.sinks
|
|
.send_many(ids, WsOutbound::Presence { request_id, data });
|
|
tracing::debug!(delivered, "nats presence forwarded to ws subscribers");
|
|
}
|
|
|
|
fn category_event(&self, payload: &[u8], request_id: Uuid) {
|
|
let Ok(data) = serde_json::from_slice::<CategoryEvent>(payload) else {
|
|
tracing::warn!("nats ws bridge decode category event failed");
|
|
return;
|
|
};
|
|
let targets = self.sessions.workspace_connections(&data.workspace_name);
|
|
let delivered = self
|
|
.sinks
|
|
.send_many(targets, WsOutbound::Category { request_id, data });
|
|
tracing::debug!(delivered, "nats category event forwarded to ws subscribers");
|
|
}
|
|
|
|
fn draft_event(&self, payload: &[u8], request_id: Uuid) {
|
|
let Ok(data) = serde_json::from_slice::<DraftEvent>(payload) else {
|
|
tracing::warn!("nats ws bridge decode draft event failed");
|
|
return;
|
|
};
|
|
let targets = self.sessions.user_connections(data.user_id);
|
|
let delivered = self
|
|
.sinks
|
|
.send_many(targets, WsOutbound::Draft { request_id, data });
|
|
tracing::debug!(delivered, "nats draft event forwarded to ws subscribers");
|
|
}
|
|
|
|
fn channel_meta_event(&self, subject: &str, payload: &[u8], request_id: Uuid) {
|
|
let Ok(data) = serde_json::from_slice::<super::ChannelEvent>(payload) else {
|
|
tracing::warn!("nats ws bridge decode channel event failed");
|
|
return;
|
|
};
|
|
let mut targets = data
|
|
.workspace_name
|
|
.as_deref()
|
|
.map(|workspace| self.sessions.workspace_connections(workspace))
|
|
.unwrap_or_else(|| self.sessions.subscribers(data.channel_id));
|
|
if targets.is_empty()
|
|
&& let Some(id) = subject
|
|
.rsplit('.')
|
|
.next()
|
|
.and_then(|v| v.parse::<Uuid>().ok())
|
|
{
|
|
targets = self.sessions.subscribers(id);
|
|
}
|
|
let delivered = self
|
|
.sinks
|
|
.send_many(targets, WsOutbound::Channel { request_id, data });
|
|
tracing::debug!(delivered, "nats channel event forwarded to ws subscribers");
|
|
}
|
|
}
|
|
|
|
pub trait ChannelScoped {
|
|
fn channel_id(&self) -> Uuid;
|
|
}
|
|
|
|
impl ChannelScoped for MessageEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
impl ChannelScoped for ThreadEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
impl ChannelScoped for MemberEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
impl ChannelScoped for ReactionEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
impl ChannelScoped for PollEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
impl ChannelScoped for ArticleEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
impl ChannelScoped for TypingEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
impl ChannelScoped for FollowEvent {
|
|
fn channel_id(&self) -> Uuid {
|
|
self.channel_id
|
|
}
|
|
}
|
|
|
|
fn request_id(msg: &async_nats::Message) -> Uuid {
|
|
msg.headers
|
|
.as_ref()
|
|
.and_then(|h| h.get("X-Request-Id"))
|
|
.and_then(|v| v.as_str().parse().ok())
|
|
.unwrap_or_else(Uuid::nil)
|
|
}
|