use std::collections::HashSet;
use std::sync::Arc;

use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::mpsc;

use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
use crate::socket::message_bus::MessageBus;
use crate::socket::packet::Packet;
use crate::socket::parser;
use crate::socket::socket::Socket;

/// Dispatches a single message that arrived over the bus from some server.
///
/// Only local delivery happens here; nothing is written to room state.
///
/// * `msg` - the already-deserialized bus message.
/// * `on_local_broadcast` - callback invoked with the decoded packet and the
///   broadcast options carried by the message.
/// * `server_id` - id of this server, used to recognise messages it sent itself.
async fn handle_bus_message(
    msg: BusMessage,
    on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync>,
    server_id: &str,
) {
    match msg {
        BusMessage::Broadcast { packet, opts, server_id: origin, .. } => {
            // Messages this server published itself are skipped.
            if origin != server_id {
                // A payload that fails to decode is dropped without further action.
                if let Ok(decoded) = parser::decode(&packet) {
                    on_local_broadcast(&decoded, &opts);
                }
            }
        }
        // Room state is kept per server, so join/leave/disconnect notices from
        // other servers are informational only and need no state writes here.
        BusMessage::SocketJoin { .. }
        | BusMessage::SocketLeave { .. }
        | BusMessage::SocketDisconnect { .. } => {}
    }
}

/// NATS-based adapter that manages room state locally and uses NATS
/// for cross-server broadcast only. Does NOT depend on Redis.
pub struct NatsAdapter { message_bus: Arc, room_subscribers: DashMap>>, socket_rooms: DashMap>, rooms: DashMap>, /// socket_sid → engine_sid mapping for local delivery socket_sids: DashMap, sockets: DashMap>, server_id: String, namespace: String, on_local_broadcast: Arc, } impl NatsAdapter { pub fn new( message_bus: Arc, server_id: String, namespace: String, on_local_broadcast: Arc, ) -> Self { Self { message_bus, server_id, namespace, on_local_broadcast, room_subscribers: DashMap::new(), socket_rooms: DashMap::new(), rooms: DashMap::new(), socket_sids: DashMap::new(), sockets: DashMap::new(), } } pub async fn init(&self) -> Result<(), AdapterError> { let channels = ["broadcast", "join", "leave", "disconnect"]; let prefix = format!("socket.io:{}:", self.namespace); for channel_type in channels { let subject = format!("{}{}", prefix, channel_type); match self.message_bus.subscribe(&subject).await { Ok(rx) => { self.room_subscribers.insert(channel_type.to_string(), rx); } Err(e) => return Err(AdapterError::MessageBus(e.to_string())), } } self.spawn_listener(); Ok(()) } fn spawn_listener(&self) { let server_id = self.server_id.clone(); let on_local_broadcast = self.on_local_broadcast.clone(); let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx); let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx); let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx); let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx); tokio::spawn(async move { loop { tokio::select! 
{ Some(data) = async { broadcast_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { join_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { leave_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { disconnect_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } else => break, } } }); } } #[async_trait] impl Adapter for NatsAdapter { async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { if opts.flags.local_only { (self.on_local_broadcast)(packet, opts); return Ok(()); } let msg = BusMessage::Broadcast { namespace: self.namespace.clone(), packet: parser::encode(packet), opts: opts.clone(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:broadcast", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; (self.on_local_broadcast)(packet, opts); Ok(()) } async fn register(&self, socket_sid: &str, engine_sid: &str, _ns: &str) -> Result<(), AdapterError> { self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string()); Ok(()) } async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> { self.socket_rooms .entry(sid.to_string()) .and_modify(|set| { set.insert(room.to_string()); }) .or_insert_with(|| HashSet::from([room.to_string()])); self.rooms .entry(room.to_string()) .and_modify(|set| { set.insert(sid.to_string()); }) .or_insert_with(|| HashSet::from([sid.to_string()])); 
let msg = BusMessage::SocketJoin { namespace: self.namespace.clone(), sid: sid.to_string(), room: room.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:join", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn del(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> { if let Some(mut entry) = self.socket_rooms.get_mut(sid) { entry.value_mut().remove(room); } if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) { self.socket_rooms.remove(sid); } if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { self.rooms.remove(room); } let msg = BusMessage::SocketLeave { namespace: self.namespace.clone(), sid: sid.to_string(), room: room.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:leave", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn del_all(&self, sid: &str, _ns: &str) -> Result<(), AdapterError> { if let Some((_, rooms)) = self.socket_rooms.remove(sid) { for room in &rooms { if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { self.rooms.remove(room); } } } self.socket_sids.remove(sid); self.sockets.remove(sid); let msg = BusMessage::SocketDisconnect { namespace: self.namespace.clone(), sid: sid.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:disconnect", self.namespace), 
&payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result, AdapterError> { let mut result = Vec::new(); let target_sids: HashSet = if opts.rooms.is_empty() { self.socket_sids.iter().map(|e| e.key().clone()).collect() } else { let mut sids = HashSet::new(); for room in &opts.rooms { if let Some(entry) = self.rooms.get(room) { sids.extend(entry.value().iter().cloned()); } } sids }; for sid in target_sids { if opts.except.contains(&sid) { continue; } let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default(); result.push(SocketInfo { sid: sid.clone(), namespace: self.namespace.clone(), rooms, }); } Ok(result) } async fn socket_rooms(&self, sid: &str) -> Result, AdapterError> { Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default()) } fn server_id(&self) -> &str { &self.server_id } async fn close(&self) -> Result<(), AdapterError> { self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } }