use std::sync::Arc; use actix_web::{HttpRequest, HttpResponse, web}; use actix_ws::Message; use crate::engine::codec; use crate::engine::packet::{Packet, PacketData, PacketType}; use crate::engine::server::EngineConfig; use crate::engine::session::{SessionState, SessionStore, TransportType}; #[derive(Debug, serde::Deserialize)] pub struct WsQuery { #[serde(rename = "EIO")] pub eio: Option, pub transport: Option, pub sid: Option, } pub async fn websocket_handler( req: HttpRequest, body: web::Payload, query: web::Query, store: web::Data, config: web::Data, on_message: web::Data>, ) -> Result { if query.eio.as_deref() != Some("4") { return Ok(HttpResponse::BadRequest().body("invalid EIO version")); } if query.transport.as_deref() != Some("websocket") { return Ok(HttpResponse::BadRequest().body("invalid transport")); } let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, body)?; let sid = query.sid.clone(); if let Some(ref sid) = sid && !store.exists(sid) { return Ok(HttpResponse::BadRequest().body("unknown session")); } // Create or reuse session, obtaining the mpsc receiver for the forwarding task let (session_sid, mut session_rx) = if let Some(ref sid) = sid { // Upgrade: session already exists, replace its channel and drain pending packets let session_arc = match store.get(sid) { Some(s) => s, None => { tracing::error!("Session {} not found for upgrade", sid); return Ok(HttpResponse::InternalServerError().body("session not found")); } }; let (new_tx, new_rx) = tokio::sync::mpsc::channel(256); { let mut s = session_arc.write().await; // Swap tx atomically: old_tx will be dropped, closing its channel. // Any packets in the old rx are consumed by the old send_handle, // which then exits when it sees the channel close. // Drain pending_packets (from polling buffering) into new channel. let pending = s.take_pending(); for packet in pending { let _ = new_tx.try_send(packet); } s.tx = new_tx; s.set_transport(TransportType::WebSocket); } (sid.clone(), new_rx) } else { // New connection: generate SID and create session let new_sid = crate::engine::session::generate_sid(); let rx = store.create(new_sid.clone(), TransportType::WebSocket); if let Some(s) = store.get(&new_sid) { let mut s = s.write().await; s.set_state(SessionState::Open); } (new_sid, rx) }; let handshake = crate::engine::packet::HandshakeData { sid: session_sid.clone(), upgrades: vec![], ping_interval: config.ping_interval, ping_timeout: config.ping_timeout, max_payload: config.max_payload, }; let open_packet = Packet::open(&handshake); let open_msg = codec::encode_packet(&open_packet); if ws_session.text(open_msg).await.is_err() { tracing::warn!( "Failed to send open packet to WebSocket session {}", session_sid ); store.remove(&session_sid); return Ok(response); } let store_clone = store.get_ref().clone(); let on_message_clone = on_message.get_ref().clone(); let sid_clone = session_sid.clone(); let ws_session_clone = ws_session.clone(); let max_payload = config.max_payload; // Task 1: Forward engine session packets → WebSocket (reads from session mpsc rx) let sid_for_send = session_sid.clone(); let store_for_send = store.get_ref().clone(); let mut ws_for_send = ws_session.clone(); let send_handle = actix_rt::spawn(async move { while let Some(packet) = session_rx.recv().await { let encoded = codec::encode_packet(&packet); if ws_for_send.text(encoded).await.is_err() { break; } } // Session channel closed — clean up store_for_send.remove(&sid_for_send); }); // Task 2: Read incoming WebSocket messages → dispatch let recv_handle = actix_rt::spawn(async move { let mut ws_session = ws_session_clone; while let Some(Ok(msg)) = msg_stream.recv().await { match msg { Message::Text(text) => { if text.len() > max_payload { tracing::warn!( "Text payload too large ({}) for session {}", text.len(), sid_clone ); let _ = ws_session.close(None).await; break; } if let Ok(packet) = codec::decode_packet(&text) { match packet.packet_type { PacketType::Ping => { if let PacketData::Text(ref data) = packet.data && data == "probe" { let pong = Packet::pong("probe"); let pong_msg = codec::encode_packet(&pong); let _ = ws_session.text(pong_msg).await; continue; } let pong = Packet::pong(""); let pong_msg = codec::encode_packet(&pong); let _ = ws_session.text(pong_msg).await; } PacketType::Pong => { if let Some(s) = store_clone.get(&sid_clone) { let mut s = s.write().await; s.update_ping(); } } PacketType::Upgrade => { if let Some(s) = store_clone.get(&sid_clone) { let mut s = s.write().await; s.set_transport(TransportType::WebSocket); s.set_state(SessionState::Open); } } PacketType::Message => { let on_msg = on_message_clone.clone(); let sid = sid_clone.clone(); tokio::spawn(async move { on_msg(sid, packet); }); } PacketType::Close => { if let Some(s) = store_clone.get(&sid_clone) { let mut s = s.write().await; s.set_state(SessionState::Closed); } store_clone.remove(&sid_clone); let _ = ws_session.close(None).await; break; } _ => {} } } } Message::Binary(bin) => { // Enforce max payload size for binary frames if bin.len() > max_payload { tracing::warn!( "Binary payload too large ({}) for session {}", bin.len(), sid_clone ); continue; } if let Ok(packet) = codec::decode_packet_ws(&bin) && packet.packet_type == PacketType::Message { let on_msg = on_message_clone.clone(); let sid = sid_clone.clone(); tokio::spawn(async move { on_msg(sid, packet); }); } } Message::Close(_) => { if let Some(s) = store_clone.get(&sid_clone) { let mut s = s.write().await; s.set_state(SessionState::Closed); } store_clone.remove(&sid_clone); break; } _ => {} } } }); // Task 3: Heartbeat ping sender let sid_for_ping = session_sid.clone(); let store_for_ping = store.get_ref().clone(); let mut ws_for_ping = ws_session.clone(); let ping_interval = config.ping_interval; let ping_handle = tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval)); loop { interval.tick().await; if let Some(s) = store_for_ping.get(&sid_for_ping) { let session_state = { let s = s.read().await; s.state }; if session_state == SessionState::Closed { break; } let ping = Packet::ping(""); let ping_msg = codec::encode_packet(&ping); if ws_for_ping.text(ping_msg).await.is_err() { break; } } else { break; } } }); // Wait for any task to finish, then clean up actix_rt::spawn(async move { // actix_rt::spawn returns JoinHandle which is compatible with tokio::select! tokio::select! { _ = send_handle => {}, _ = recv_handle => {}, _ = ping_handle => {}, } store.remove(&session_sid); }); Ok(response) } pub fn configure_websocket(cfg: &mut web::ServiceConfig) { cfg.route("/engine.io/", web::get().to(websocket_handler)); }