use std::sync::Arc; use wtransport::{Connection, Endpoint, ServerConfig, Identity}; use crate::engine::codec; use crate::engine::packet::{Packet, PacketType}; use crate::engine::server::EngineConfig; use crate::engine::session::{SessionState, SessionStore, TransportType}; pub async fn run_webtransport_server( port: u16, cert_path: &str, key_path: &str, store: SessionStore, config: EngineConfig, on_message: Arc, ) -> Result<(), Box> { let identity = Identity::load_pemfiles(cert_path, key_path).await?; let server_config = ServerConfig::builder() .with_bind_default(port) .with_identity(identity) .build(); let server = Endpoint::server(server_config)?; tracing::info!("WebTransport server listening on UDP port {}", port); loop { let incoming = server.accept().await; let store = store.clone(); let config = config.clone(); let on_message = on_message.clone(); tokio::spawn(async move { match handle_webtransport_session(incoming, store, config, on_message).await { Ok(_) => {} Err(e) => { tracing::error!("WebTransport session error: {}", e); } } }); } } async fn handle_webtransport_session( incoming: wtransport::endpoint::IncomingSession, store: SessionStore, config: EngineConfig, on_message: Arc, ) -> Result<(), Box> { let request = incoming.await?; let connection = request.accept().await?; let sid = crate::engine::session::generate_sid(); let mut rx = store.create(sid.clone(), TransportType::WebTransport); if let Some(s) = store.get(&sid) { let mut s = s.write().await; s.set_state(SessionState::Open); } let handshake = crate::engine::packet::HandshakeData { sid: sid.clone(), upgrades: vec![], ping_interval: config.ping_interval, ping_timeout: config.ping_timeout, max_payload: config.max_payload, }; let open_packet = Packet::open(&handshake); send_wt_packet(&connection, &open_packet).await?; let store_clone = store.clone(); let sid_clone = sid.clone(); let on_message_clone = on_message.clone(); let connection_recv = connection.clone(); let max_payload = config.max_payload; // Reuse buffer across recv iterations instead of allocating 65KB each time let recv_handle = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; loop { match connection_recv.accept_bi().await { Ok((mut send, mut recv)) => { // Reset buffer length for the next read without deallocating buf.resize(65536, 0); match recv.read(&mut buf).await { Ok(Some(n)) => { if n > max_payload { tracing::warn!( "WebTransport payload too large ({}) for session {}", n, sid_clone ); continue; } if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) { match packet.packet_type { PacketType::Ping => { let pong = Packet::pong(""); if send_wt_packet_on_stream(&mut send, &pong) .await .is_err() { break; } } PacketType::Pong => { if let Some(s) = store_clone.get(&sid_clone) { let mut s = s.write().await; s.update_ping(); } } PacketType::Message => { let on_msg = on_message_clone.clone(); let sid = sid_clone.clone(); tokio::spawn(async move { on_msg(sid, packet); }); } PacketType::Close => { if let Some(s) = store_clone.get(&sid_clone) { let mut s = s.write().await; s.set_state(SessionState::Closed); } store_clone.remove(&sid_clone); break; } _ => {} } } } Ok(None) => break, Err(_) => break, } } Err(_) => break, } } Ok::<(), Box>(()) }); let connection_send = connection.clone(); let send_handle = tokio::spawn(async move { while let Some(packet) = rx.recv().await { if send_wt_packet(&connection_send, &packet).await.is_err() { break; } } }); let connection_ping = connection.clone(); let store_ping = store.clone(); let sid_ping = sid.clone(); let ping_interval = config.ping_interval; let ping_handle = tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval)); loop { interval.tick().await; if let Some(s) = store_ping.get(&sid_ping) { let state = { let s = s.read().await; s.state }; if state == SessionState::Closed { break; } let ping = Packet::ping(""); if send_wt_packet(&connection_ping, &ping).await.is_err() { break; } } else { break; } } }); tokio::select! { _ = recv_handle => {}, _ = send_handle => {}, _ = ping_handle => {}, } store.remove(&sid); Ok(()) } async fn send_wt_packet( connection: &Connection, packet: &Packet, ) -> Result<(), Box> { let (mut send, _recv) = connection.open_bi().await?.await?; let encoded = codec::encode_packet_binary_ws(packet); let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_)); let header = codec::encode_webtransport_header(encoded.len(), is_binary); send.write_all(&header).await?; send.write_all(&encoded).await?; send.finish().await?; Ok(()) } async fn send_wt_packet_on_stream( send: &mut wtransport::SendStream, packet: &Packet, ) -> Result<(), Box> { let encoded = codec::encode_packet_binary_ws(packet); let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_)); let header = codec::encode_webtransport_header(encoded.len(), is_binary); send.write_all(&header).await?; send.write_all(&encoded).await?; send.finish().await?; Ok(()) }