Files
imks/engine/webtransport.rs
T
zhenyi 821537186e refactor(tests): reformat code and update dependency management
- Reorganized import statements in adapter tests for better readability
- Replaced or_insert_with(Vec::new) with or_default() in test closures
- Updated Cargo.lock with new dependency versions and checksums
- Added TLS features to tonic dependency configuration
- Included sqlx, chrono, and uuid dependencies with specific features
- Added jsonwebtoken and arc-swap as project dependencies
- Reformatted assertion statements to comply with line length limits
- Adjusted base64 import order in engine codec module
- Updated protobuf include statement formatting
2026-06-11 12:11:05 +08:00

239 lines
7.9 KiB
Rust

use std::sync::Arc;
use wtransport::{Connection, Endpoint, Identity, ServerConfig};
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
use crate::{ImksError, ImksResult};
pub async fn run_webtransport_server(
port: u16,
cert_path: &str,
key_path: &str,
store: SessionStore,
config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> ImksResult<()> {
let identity = Identity::load_pemfiles(cert_path, key_path)
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let server_config = ServerConfig::builder()
.with_bind_default(port)
.with_identity(identity)
.build();
let server =
Endpoint::server(server_config).map_err(|e| ImksError::WebTransport(e.to_string()))?;
tracing::info!("WebTransport server listening on UDP port {}", port);
loop {
let incoming = server.accept().await;
let store = store.clone();
let config = config.clone();
let on_message = on_message.clone();
tokio::spawn(async move {
match handle_webtransport_session(incoming, store, config, on_message).await {
Ok(_) => {}
Err(e) => {
tracing::error!("WebTransport session error: {}", e);
}
}
});
}
}
async fn handle_webtransport_session(
incoming: wtransport::endpoint::IncomingSession,
store: SessionStore,
config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> ImksResult<()> {
let request = incoming
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let connection = request
.accept()
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let sid = crate::engine::session::generate_sid();
let mut rx = store.create(sid.clone(), TransportType::WebTransport);
if let Some(s) = store.get(&sid) {
let mut s = s.write().await;
s.set_state(SessionState::Open);
}
let handshake = crate::engine::packet::HandshakeData {
sid: sid.clone(),
upgrades: vec![],
ping_interval: config.ping_interval,
ping_timeout: config.ping_timeout,
max_payload: config.max_payload,
};
let open_packet = Packet::open(&handshake);
send_wt_packet(&connection, &open_packet).await?;
let store_clone = store.clone();
let sid_clone = sid.clone();
let on_message_clone = on_message.clone();
let connection_recv = connection.clone();
let max_payload = config.max_payload;
// Reuse buffer across recv iterations instead of allocating 65KB each time
let recv_handle = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
while let Ok((mut send, mut recv)) = connection_recv.accept_bi().await {
// Reset buffer length for the next read without deallocating
buf.resize(65536, 0);
match recv.read(&mut buf).await {
Ok(Some(n)) => {
if n > max_payload {
tracing::warn!(
"WebTransport payload too large ({}) for session {}",
n,
sid_clone
);
continue;
}
if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) {
match packet.packet_type {
PacketType::Ping => {
let pong = Packet::pong("");
if send_wt_packet_on_stream(&mut send, &pong).await.is_err() {
break;
}
}
PacketType::Pong => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.update_ping();
}
}
PacketType::Message => {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
PacketType::Close => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
break;
}
_ => {}
}
}
}
Ok(None) => break,
Err(_) => break,
}
}
Ok::<(), ImksError>(())
});
let connection_send = connection.clone();
let send_handle = tokio::spawn(async move {
while let Some(packet) = rx.recv().await {
if send_wt_packet(&connection_send, &packet).await.is_err() {
break;
}
}
});
let connection_ping = connection.clone();
let store_ping = store.clone();
let sid_ping = sid.clone();
let ping_interval = config.ping_interval;
let ping_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
loop {
interval.tick().await;
if let Some(s) = store_ping.get(&sid_ping) {
let state = {
let s = s.read().await;
s.state
};
if state == SessionState::Closed {
break;
}
let ping = Packet::ping("");
if send_wt_packet(&connection_ping, &ping).await.is_err() {
break;
}
} else {
break;
}
}
});
tokio::select! {
_ = recv_handle => {},
_ = send_handle => {},
_ = ping_handle => {},
}
store.remove(&sid);
Ok(())
}
async fn send_wt_packet(connection: &Connection, packet: &Packet) -> ImksResult<()> {
let (mut send, _recv) = connection
.open_bi()
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let encoded = codec::encode_packet_binary_ws(packet);
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
send.write_all(&header)
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
send.write_all(&encoded)
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
send.finish()
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
Ok(())
}
/// Writes `packet` onto an already-open WebTransport send stream, then finishes it.
///
/// The bytes written are the header from `codec::encode_webtransport_header`
/// (built from the encoded length and whether the packet data is binary),
/// followed by the output of `codec::encode_packet_binary_ws`.
///
/// # Errors
///
/// Returns [`ImksError::WebTransport`] if either write or the final `finish`
/// fails.
async fn send_wt_packet_on_stream(
    send: &mut wtransport::SendStream,
    packet: &Packet,
) -> ImksResult<()> {
    let body = codec::encode_packet_binary_ws(packet);
    let binary_payload = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
    let frame_header = codec::encode_webtransport_header(body.len(), binary_payload);

    // Header first, then the encoded packet, each written in full.
    for part in [&frame_header[..], &body[..]] {
        send.write_all(part)
            .await
            .map_err(|e| ImksError::WebTransport(e.to_string()))?;
    }

    send.finish()
        .await
        .map_err(|e| ImksError::WebTransport(e.to_string()))?;
    Ok(())
}