use std::sync::Arc; use std::time::Instant; use dashmap::DashMap; use tokio::sync::{Notify, mpsc}; use crate::engine::packet::Packet; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TransportType { Polling, WebSocket, WebTransport, } impl TransportType { pub fn as_str(&self) -> &'static str { match self { Self::Polling => "polling", Self::WebSocket => "websocket", Self::WebTransport => "webtransport", } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SessionState { Connecting, Open, Upgrading, Closing, Closed, } pub struct Session { pub sid: String, pub transport: TransportType, pub state: SessionState, pub created_at: Instant, pub last_ping: Instant, pub tx: mpsc::Sender, pub pending_packets: Vec, pub notify: Arc, pub upgrade_tx: Option>, } impl Session { pub fn new(sid: String, transport: TransportType) -> (Self, mpsc::Receiver) { let (tx, rx) = mpsc::channel(256); let session = Self { sid, transport, state: SessionState::Connecting, created_at: Instant::now(), last_ping: Instant::now(), tx, pending_packets: Vec::new(), notify: Arc::new(Notify::new()), upgrade_tx: None, }; (session, rx) } /// Send a packet through the mpsc channel (for WS/WT transport consumption). pub fn send_packet(&self, packet: Packet) -> Result<(), mpsc::error::TrySendError> { self.tx.try_send(packet) } /// Push a packet using the appropriate mechanism for the current transport. /// Polling: buffer in pending_packets + notify waiting GET request. /// WS/WT: try mpsc channel first; if full, buffer as fallback + notify. pub fn push_packet(&mut self, packet: Packet) { if self.transport == TransportType::Polling { self.pending_packets.push(packet); self.notify.notify_one(); } else { if self.tx.try_send(packet.clone()).is_err() { self.pending_packets.push(packet); self.notify.notify_one(); } } } /// Buffer a packet in pending_packets and notify any waiting polling request. 
pub fn buffer_packet(&mut self, packet: Packet) { self.pending_packets.push(packet); self.notify.notify_one(); } pub fn take_pending(&mut self) -> Vec { std::mem::take(&mut self.pending_packets) } pub fn update_ping(&mut self) { self.last_ping = Instant::now(); } pub fn set_transport(&mut self, transport: TransportType) { self.transport = transport; } pub fn set_state(&mut self, state: SessionState) { self.state = state; } } #[derive(Clone)] pub struct SessionStore { pub sessions: Arc>>>, } impl SessionStore { pub fn new() -> Self { Self { sessions: Arc::new(DashMap::new()), } } /// Create a new session. Returns the mpsc receiver for transport-level packet consumption. /// Logs a warning if the SID collides with an existing session (extremely unlikely with crypto RNG). pub fn create(&self, sid: String, transport: TransportType) -> mpsc::Receiver { let (session, rx) = Session::new(sid.clone(), transport); let old = self .sessions .insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session))); if old.is_some() { tracing::warn!( "Session ID collision for SID {}, replacing existing session", sid ); } if let Some(m) = crate::telemetry::metrics::try_get() { m.engine_sessions_active.add( 1, &[opentelemetry::KeyValue::new( "transport", transport.as_str(), )], ); } rx } pub fn get(&self, sid: &str) -> Option>> { self.sessions.get(sid).map(|r| r.value().clone()) } pub fn remove(&self, sid: &str) { if self.sessions.remove(sid).is_some() && let Some(m) = crate::telemetry::metrics::try_get() { m.engine_sessions_active.add(-1, &[]); } } pub fn exists(&self, sid: &str) -> bool { self.sessions.contains_key(sid) } pub fn len(&self) -> usize { self.sessions.len() } pub fn is_empty(&self) -> bool { self.sessions.is_empty() } } impl Default for SessionStore { fn default() -> Self { Self::new() } } /// Generate a random session ID using a cryptographically secure RNG. /// rand 0.9's default RNG (ChaCha8Rng seeded from OsRng) is crypto-secure. 
pub fn generate_sid() -> String {
    use rand::Rng;

    // 64-symbol URL-safe alphabet (A-Z, a-z, 0-9, '_' and '-').
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-";
    // Number of characters in every generated SID.
    const SID_LEN: usize = 20;

    let mut rng = rand::rng();
    let mut sid = String::with_capacity(SID_LEN);
    for _ in 0..SID_LEN {
        let pick = rng.random_range(0..CHARSET.len());
        sid.push(char::from(CHARSET[pick]));
    }
    sid
}