use std::collections::HashSet; use std::sync::Arc; use async_trait::async_trait; use dashmap::DashMap; use fred::clients::Client; use fred::interfaces::{KeysInterface, SetsInterface}; use tokio::sync::mpsc; use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo}; use crate::socket::message_bus::MessageBus; use crate::socket::packet::Packet; use crate::socket::parser; use crate::socket::socket::Socket; const KEY_PREFIX_ROOMS: &str = "socket.io:rooms"; const KEY_PREFIX_SOCKET_ROOMS: &str = "socket.io:socket_rooms"; fn room_key(ns: &str, room: &str) -> String { format!("{}:{}:{}", KEY_PREFIX_ROOMS, ns, room) } fn socket_rooms_key(ns: &str, sid: &str) -> String { format!("{}:{}:{}", KEY_PREFIX_SOCKET_ROOMS, ns, sid) } /// Handle incoming bus messages from other servers. /// Only performs local state updates — the remote server already wrote to Redis. async fn handle_bus_message( msg: BusMessage, on_local_broadcast: &Arc, server_id: &str, ) { match msg { BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => { if sender_id == server_id { return; } if let Ok(decoded_packet) = parser::decode(&packet) { on_local_broadcast(&decoded_packet, &opts); } } BusMessage::SocketJoin { server_id: sender_id, .. } | BusMessage::SocketLeave { server_id: sender_id, .. } | BusMessage::SocketDisconnect { server_id: sender_id, .. } => { // Skip messages from this server; remote server already updated Redis if sender_id == server_id { return; } // No duplicate Redis writes — the sender already persisted the state change } } } pub struct RedisAdapter { message_bus: Arc, redis_client: Client, room_subscribers: DashMap>>, socket_rooms: DashMap>, rooms: DashMap>, sockets: DashMap>, server_id: String, namespace: String, on_local_broadcast: Arc, } impl RedisAdapter { pub fn new( message_bus: Arc, redis_client: Client, server_id: String, namespace: String, on_local_broadcast: Arc, ) -> Self { Self { message_bus, redis_client, server_id, namespace, on_local_broadcast, room_subscribers: DashMap::new(), socket_rooms: DashMap::new(), rooms: DashMap::new(), sockets: DashMap::new(), } } pub async fn init(&self) -> Result<(), AdapterError> { let channels = ["broadcast", "join", "leave", "disconnect"]; let prefix = format!("socket.io:{}:", self.namespace); for channel_type in channels { let channel = format!("{}{}", prefix, channel_type); match self.message_bus.subscribe(&channel).await { Ok(rx) => { self.room_subscribers.insert(channel_type.to_string(), rx); } Err(e) => return Err(AdapterError::MessageBus(e.to_string())), } } self.spawn_listener(); Ok(()) } fn spawn_listener(&self) { let server_id = self.server_id.clone(); let on_local_broadcast = self.on_local_broadcast.clone(); let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx); let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx); let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx); let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx); tokio::spawn(async move { loop { tokio::select! { Some(data) = async { broadcast_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { join_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { leave_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { disconnect_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } else => break, } } }); } } #[async_trait] impl Adapter for RedisAdapter { async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { if opts.flags.local_only { (self.on_local_broadcast)(packet, opts); return Ok(()); } let msg = BusMessage::Broadcast { namespace: packet.namespace.clone(), packet: parser::encode(packet), opts: opts.clone(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:broadcast", packet.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; (self.on_local_broadcast)(packet, opts); Ok(()) } async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> { let rk = room_key(ns, room); let srk = socket_rooms_key(ns, sid); self.redis_client .sadd::<(), _, _>(&rk, sid) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.redis_client .sadd::<(), _, _>(&srk, room) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.socket_rooms .entry(sid.to_string()) .and_modify(|set| { set.insert(room.to_string()); }) .or_insert_with(|| HashSet::from([room.to_string()])); self.rooms .entry(room.to_string()) .and_modify(|set| { set.insert(sid.to_string()); }) .or_insert_with(|| HashSet::from([sid.to_string()])); let msg = BusMessage::SocketJoin { namespace: ns.to_string(), sid: sid.to_string(), room: room.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:join", ns), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> { let rk = room_key(ns, room); let srk = socket_rooms_key(ns, sid); self.redis_client .srem::<(), _, _>(&rk, sid) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.redis_client .srem::<(), _, _>(&srk, room) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; if let Some(mut entry) = self.socket_rooms.get_mut(sid) { entry.value_mut().remove(room); } if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) { self.socket_rooms.remove(sid); } if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { self.rooms.remove(room); } let msg = BusMessage::SocketLeave { namespace: ns.to_string(), sid: sid.to_string(), room: room.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:leave", ns), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> { if let Some((_, rooms)) = self.socket_rooms.remove(sid) { for room in &rooms { if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { self.rooms.remove(room); } let rk = room_key(ns, room); if let Err(e) = self.redis_client.srem::<(), _, _>(&rk, sid).await { tracing::warn!("Redis SREM room error: {}", e); } } } let srk = socket_rooms_key(ns, sid); self.redis_client .del::<(), _>(&srk) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.sockets.remove(sid); let msg = BusMessage::SocketDisconnect { namespace: ns.to_string(), sid: sid.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg) .map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:disconnect", ns), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result, AdapterError> { let mut result = Vec::new(); let target_sids: HashSet = if opts.rooms.is_empty() { self.sockets.iter().map(|e| e.key().clone()).collect() } else { let mut sids = HashSet::new(); for room in &opts.rooms { if let Some(entry) = self.rooms.get(room) { sids.extend(entry.value().iter().cloned()); } } sids }; for sid in target_sids { if opts.except.contains(&sid) { continue; } let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default(); result.push(SocketInfo { sid: sid.clone(), namespace: self.namespace.clone(), rooms, }); } Ok(result) } async fn socket_rooms(&self, sid: &str) -> Result, AdapterError> { Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default()) } fn server_id(&self) -> &str { &self.server_id } async fn close(&self) -> Result<(), AdapterError> { self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } }