use async_trait::async_trait; use dashmap::DashMap; use tokio::sync::{mpsc, watch}; use crate::socket::message_bus::{MessageBus, MessageBusError}; pub struct NatsMessageBus { client: async_nats::Client, shutdowns: DashMap>, } impl NatsMessageBus { pub async fn new(nats_url: &str) -> Result { let client = async_nats::connect(nats_url) .await .map_err(|e| MessageBusError::Nats(e.to_string()))?; Ok(Self { client, shutdowns: DashMap::new(), }) } } #[async_trait] impl MessageBus for NatsMessageBus { async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> { self.client .publish(channel.to_string(), message.to_vec().into()) .await .map_err(|e| MessageBusError::Nats(e.to_string()))?; Ok(()) } async fn subscribe(&self, channel: &str) -> Result>, MessageBusError> { let (tx, rx) = mpsc::channel::>(256); let mut subscriber = self.client .subscribe(channel.to_string()) .await .map_err(|e| MessageBusError::Nats(e.to_string()))?; let (shutdown_tx, mut shutdown_rx) = watch::channel(false); self.shutdowns.insert(channel.to_string(), shutdown_tx); tokio::spawn(async move { use futures_util::StreamExt; loop { tokio::select! { _ = shutdown_rx.changed() => { break; } message = subscriber.next() => { match message { Some(msg) => { let data = msg.payload.to_vec(); if tx.send(data).await.is_err() { break; } } None => break, } } } } if let Err(e) = subscriber.unsubscribe().await { tracing::warn!("NATS unsubscribe error: {}", e); } }); Ok(rx) } async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> { if let Some((_, tx)) = self.shutdowns.remove(channel) { let _ = tx.send(true); } Ok(()) } async fn close(&self) -> Result<(), MessageBusError> { // Signal all subscribers to shutdown self.shutdowns.iter().for_each(|entry| { let _ = entry.value().send(true); }); self.shutdowns.clear(); Ok(()) } }