use std::collections::HashSet; use std::sync::Arc; use async_trait::async_trait; use dashmap::DashMap; use fred::clients::Client; use fred::interfaces::{KeysInterface, SetsInterface}; use tokio::sync::mpsc; use crate::socket::adapter::{ Adapter, AdapterError, BroadcastOptions, BusMessage, LocalBroadcastFn, SocketInfo, }; use crate::socket::message_bus::MessageBus; use crate::socket::packet::Packet; use crate::socket::parser; use crate::socket::socket::Socket; const KEY_PREFIX_ROOMS: &str = "socket.io:rooms"; const KEY_PREFIX_SOCKET_ROOMS: &str = "socket.io:socket_rooms"; fn room_key(ns: &str, room: &str) -> String { format!("{}:{}:{}", KEY_PREFIX_ROOMS, ns, room) } fn socket_rooms_key(ns: &str, sid: &str) -> String { format!("{}:{}:{}", KEY_PREFIX_SOCKET_ROOMS, ns, sid) } /// Handle incoming bus messages from other servers. /// Only performs local state updates — the remote server already wrote to Redis. async fn handle_bus_message( msg: BusMessage, on_local_broadcast: &LocalBroadcastFn, server_id: &str, ) { match msg { BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id, } => { if sender_id == server_id { return; } if let Ok(decoded_packet) = parser::decode(&packet) { on_local_broadcast(&decoded_packet, &opts); } } BusMessage::SocketJoin { server_id: sender_id, .. } | BusMessage::SocketLeave { server_id: sender_id, .. } | BusMessage::SocketDisconnect { server_id: sender_id, .. } => { // Skip messages from this server; remote server already updated Redis if sender_id == server_id {} // No duplicate Redis writes — the sender already persisted the state change } } } pub struct RedisAdapter { message_bus: Arc, redis_client: Client, room_subscribers: DashMap>>, socket_rooms: DashMap>, rooms: DashMap>, /// socket_sid → engine_sid mapping for local inspection. socket_sids: DashMap, sockets: DashMap>, server_id: String, namespace: String, on_local_broadcast: LocalBroadcastFn, } impl RedisAdapter { pub fn new( message_bus: Arc, redis_client: Client, server_id: String, namespace: String, on_local_broadcast: LocalBroadcastFn, ) -> Self { Self { message_bus, redis_client, server_id, namespace, on_local_broadcast, room_subscribers: DashMap::new(), socket_rooms: DashMap::new(), rooms: DashMap::new(), socket_sids: DashMap::new(), sockets: DashMap::new(), } } pub async fn init(&self) -> Result<(), AdapterError> { let channels = ["broadcast", "join", "leave", "disconnect"]; let prefix = format!("socket.io:{}:", self.namespace); for channel_type in channels { let channel = format!("{}{}", prefix, channel_type); match self.message_bus.subscribe(&channel).await { Ok(rx) => { self.room_subscribers.insert(channel_type.to_string(), rx); } Err(e) => return Err(AdapterError::MessageBus(e.to_string())), } } self.spawn_listener(); Ok(()) } fn spawn_listener(&self) { let server_id = self.server_id.clone(); let on_local_broadcast = self.on_local_broadcast.clone(); let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx); let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx); let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx); let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx); tokio::spawn(async move { loop { tokio::select! { Some(data) = async { broadcast_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { join_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { leave_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } Some(data) = async { disconnect_rx.as_mut()?.recv().await } => { if let Ok(msg) = serde_json::from_slice::(&data) { handle_bus_message(msg, &on_local_broadcast, &server_id).await; } } else => break, } } }); } } #[async_trait] impl Adapter for RedisAdapter { async fn broadcast( &self, packet: &Packet, opts: &BroadcastOptions, ) -> Result<(), AdapterError> { if opts.flags.local_only { (self.on_local_broadcast)(packet, opts); return Ok(()); } let msg = BusMessage::Broadcast { namespace: packet.namespace.clone(), packet: parser::encode(packet), opts: opts.clone(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:broadcast", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; (self.on_local_broadcast)(packet, opts); Ok(()) } async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> { let rk = room_key(ns, room); let srk = socket_rooms_key(ns, sid); self.redis_client .sadd::<(), _, _>(&rk, sid) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.redis_client .sadd::<(), _, _>(&srk, room) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.socket_rooms .entry(sid.to_string()) .and_modify(|set| { set.insert(room.to_string()); }) .or_insert_with(|| HashSet::from([room.to_string()])); self.rooms .entry(room.to_string()) .and_modify(|set| { set.insert(sid.to_string()); }) .or_insert_with(|| HashSet::from([sid.to_string()])); let msg = BusMessage::SocketJoin { namespace: ns.to_string(), sid: sid.to_string(), room: room.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:join", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> { let rk = room_key(ns, room); let srk = socket_rooms_key(ns, sid); self.redis_client .srem::<(), _, _>(&rk, sid) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.redis_client .srem::<(), _, _>(&srk, room) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; if let Some(mut entry) = self.socket_rooms.get_mut(sid) { entry.value_mut().remove(room); } if self .socket_rooms .get(sid) .map(|e| e.value().is_empty()) .unwrap_or(true) { self.socket_rooms.remove(sid); } if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } if self .rooms .get(room) .map(|e| e.value().is_empty()) .unwrap_or(true) { self.rooms.remove(room); } let msg = BusMessage::SocketLeave { namespace: ns.to_string(), sid: sid.to_string(), room: room.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:leave", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> { if let Some((_, rooms)) = self.socket_rooms.remove(sid) { for room in &rooms { if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } if self .rooms .get(room) .map(|e| e.value().is_empty()) .unwrap_or(true) { self.rooms.remove(room); } let rk = room_key(ns, room); if let Err(e) = self.redis_client.srem::<(), _, _>(&rk, sid).await { tracing::warn!("Redis SREM room error: {}", e); } } } let srk = socket_rooms_key(ns, sid); self.redis_client .del::<(), _>(&srk) .await .map_err(|e| AdapterError::Redis(e.to_string()))?; self.socket_sids.remove(sid); self.sockets.remove(sid); let msg = BusMessage::SocketDisconnect { namespace: ns.to_string(), sid: sid.to_string(), server_id: self.server_id.clone(), }; let payload = serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish( &format!("socket.io:{}:disconnect", self.namespace), &payload, ) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } async fn register( &self, socket_sid: &str, engine_sid: &str, _ns: &str, ) -> Result<(), AdapterError> { self.socket_sids .insert(socket_sid.to_string(), engine_sid.to_string()); Ok(()) } async fn unregister(&self, socket_sid: &str, ns: &str) -> Result<(), AdapterError> { self.del_all(socket_sid, ns).await } async fn fetch_sockets( &self, opts: &BroadcastOptions, ) -> Result, AdapterError> { let mut result = Vec::new(); let target_sids: HashSet = if opts.rooms.is_empty() { self.socket_sids.iter().map(|e| e.key().clone()).collect() } else { let mut sids = HashSet::new(); for room in &opts.rooms { if let Some(entry) = self.rooms.get(room) { sids.extend(entry.value().iter().cloned()); } } sids }; for sid in target_sids { if opts.except.contains(&sid) { continue; } let rooms = self .socket_rooms .get(&sid) .map(|e| e.value().clone()) .unwrap_or_default(); result.push(SocketInfo { sid: sid.clone(), namespace: self.namespace.clone(), rooms, }); } Ok(result) } async fn socket_rooms(&self, sid: &str) -> Result, AdapterError> { Ok(self .socket_rooms .get(sid) .map(|e| e.value().clone()) .unwrap_or_default()) } fn server_id(&self) -> &str { &self.server_id } async fn close(&self) -> Result<(), AdapterError> { self.message_bus .close() .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } }