use std::sync::Arc; use dashmap::DashMap; use tokio::sync::mpsc; use crate::engine::packet::Packet as EnginePacket; use crate::engine::packet::PacketData as EnginePacketData; use crate::engine::server::{EngineConfig, EngineServer}; use crate::engine::session::SessionStore; use crate::socket::adapter::{Adapter, LocalAdapter}; use crate::socket::namespace::NamespaceManager; use crate::socket::packet::{Packet, PacketType}; use crate::socket::parser; use crate::socket::socket::Socket; pub struct SocketServer { pub engine: Arc, pub namespaces: Arc, pub adapter: Arc, socket_txs: Arc>>, } impl SocketServer { pub fn new(config: EngineConfig) -> Self { SocketServerBuilder::new(config).build() } pub fn builder(config: EngineConfig) -> SocketServerBuilder { SocketServerBuilder::new(config) } pub fn of(&self, path: impl Into) -> Arc { self.namespaces.get_or_create_namespace(&path.into()) } pub async fn run_http(self: Arc, addr: &str) -> std::io::Result<()> { self.engine.clone().run_http(addr).await } pub fn register_socket(&self, sid: String, tx: mpsc::Sender) { self.socket_txs.insert(sid, tx); } pub fn unregister_socket(&self, sid: &str) { self.socket_txs.remove(sid); } } pub struct SocketServerBuilder { config: EngineConfig, adapter: Option>, } impl SocketServerBuilder { pub fn new(config: EngineConfig) -> Self { Self { config, adapter: None, } } pub fn adapter(mut self, adapter: Arc) -> Self { self.adapter = Some(adapter); self } pub fn build(self) -> SocketServer { let namespaces = Arc::new(NamespaceManager::new()); let socket_txs: Arc>> = Arc::new(DashMap::new()); let engine_store = SessionStore::new(); let namespaces_clone = namespaces.clone(); let socket_txs_clone = socket_txs.clone(); let engine_store_clone = engine_store.clone(); let adapter: Arc = self.adapter.unwrap_or_else(|| { let ns_clone = namespaces.clone(); let send_fn = move |engine_sid: &str, packet: &Packet| { if let Some(ns) = ns_clone.get_namespace(&packet.namespace) { if let Some(socket) = ns.get_socket_by_engine_sid(engine_sid) { socket.send_packet(packet).map_err(|e| e.to_string()) } else { Err(format!( "Socket with engine_sid {} not found in namespace {}", engine_sid, packet.namespace )) } } else { Err(format!("Namespace {} not found", packet.namespace)) } }; Arc::new(LocalAdapter::new(send_fn)) }); let adapter_clone = adapter.clone(); let engine = Arc::new(EngineServer::with_store( self.config, engine_store, move |sid, engine_packet| { let namespaces = namespaces_clone.clone(); let socket_txs = socket_txs_clone.clone(); let engine_store = engine_store_clone.clone(); let adapter = adapter_clone.clone(); tokio::spawn(async move { handle_engine_message( sid, engine_packet, &namespaces, &socket_txs, &engine_store, &adapter, ) .await; }); }, )); let server = SocketServer { engine, namespaces, adapter, socket_txs, }; for ns in server.namespaces.all_namespaces() { let adapter_ref = server.adapter.clone(); tokio::spawn(async move { ns.set_adapter(adapter_ref).await; }); } server } } async fn handle_engine_message( engine_sid: String, engine_packet: EnginePacket, namespaces: &Arc, socket_txs: &Arc>>, engine_store: &SessionStore, adapter: &Arc, ) { if let EnginePacketData::Text(ref text) = engine_packet.data { match parser::decode(text) { Ok(socket_packet) => match socket_packet.packet_type { PacketType::Connect => { handle_connect( &engine_sid, &socket_packet, namespaces, socket_txs, engine_store, adapter, ) .await; } PacketType::Disconnect => { handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs); } PacketType::Event => { handle_event(&engine_sid, &socket_packet, namespaces); } PacketType::Ack => { handle_ack(&engine_sid, &socket_packet); } _ => {} }, Err(e) => { tracing::warn!(engine_sid = %engine_sid, error = %e, "Invalid Socket.IO packet"); } } } } async fn handle_connect( engine_sid: &str, packet: &Packet, namespaces: &Arc, socket_txs: &Arc>>, engine_store: &SessionStore, adapter: &Arc, ) { // Validate namespace path to prevent DoS via arbitrary namespace creation if !crate::socket::namespace::is_valid_namespace(&packet.namespace) { tracing::warn!( "Rejected connect with invalid namespace: {}", packet.namespace ); return; } let namespace = namespaces.get_or_create_namespace(&packet.namespace); // Ensure newly created namespaces get the shared adapter before registration. { let ns_adapter = namespace.adapter.read().await; if ns_adapter.is_none() { drop(ns_adapter); namespace.set_adapter(adapter.clone()).await; } } let socket_sid = crate::engine::session::generate_sid(); let (tx, mut rx) = mpsc::channel::(256); socket_txs.insert(socket_sid.clone(), tx.clone()); let socket = Arc::new(Socket::new( socket_sid.clone(), packet.namespace.clone(), engine_sid.to_string(), tx, )); // Run connect handler and add to namespace. // If the handler rejects, clean up and do NOT send a Connect response. if let Err(msg) = namespace .add_socket(socket.clone(), packet.data.as_ref()) .await { tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg); socket_txs.remove(&socket_sid); return; } // Connect handler passed — spawn forwarding task let engine_store_clone = engine_store.clone(); let engine_sid_clone = engine_sid.to_string(); let socket_sid_clone = socket_sid.clone(); let socket_txs_clone = socket_txs.clone(); let namespace_clone = namespace.clone(); tokio::spawn(async move { while let Some(socket_packet) = rx.recv().await { let encoded = parser::encode(&socket_packet); let engine_packet = EnginePacket::message_text(encoded); if let Some(session) = engine_store_clone.get(&engine_sid_clone) { let mut s = session.write().await; if s.state == crate::engine::session::SessionState::Closed { break; } s.push_packet(engine_packet); } else { break; } } // Forwarding task ended — ensure socket is cleaned up from namespace socket_txs_clone.remove(&socket_sid_clone); namespace_clone .remove_socket_by_sid(&socket_sid_clone) .await; }); // Send Connect response (only after handler passed) let response = Packet::connect( &socket.namespace, Some(serde_json::json!({ "sid": &socket.sid })), ); if socket.send_packet(&response).is_err() { tracing::warn!("Failed to send connect response to socket {}", socket.sid); } } fn handle_disconnect( engine_sid: &str, packet: &Packet, namespaces: &Arc, socket_txs: &Arc>>, ) { if let Some(namespace) = namespaces.get_namespace(&packet.namespace) { // Look up socket by engine_sid, then remove by socket_sid if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) { socket_txs.remove(&socket.sid); let socket_sid = socket.sid.clone(); let ns_clone = namespace.clone(); tokio::spawn(async move { ns_clone.remove_socket_by_sid(&socket_sid).await; }); } } } fn handle_event(engine_sid: &str, packet: &Packet, namespaces: &Arc) { if let Some(namespace) = namespaces.get_namespace(&packet.namespace) && let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) && let Some(ref data) = packet.data && let Some(arr) = data.as_array() && let Some(event) = arr.first().and_then(|v| v.as_str()) { let event_data = if arr.len() > 1 { serde_json::Value::Array(arr[1..].to_vec()) } else { serde_json::Value::Null }; let namespace_clone = namespace.clone(); let event = event.to_string(); let socket_clone = socket.clone(); tokio::spawn(async move { namespace_clone .handle_event(socket_clone, &event, &event_data) .await; }); } } fn handle_ack(engine_sid: &str, packet: &Packet) { tracing::debug!( "Received ACK from {} for namespace {} with id {:?}", engine_sid, packet.namespace, packet.id ); }