821537186e
- Reorganized import statements in adapter tests for better readability
- Replaced or_insert_with(Vec::new) with or_default() in test closures
- Updated Cargo.lock with new dependency versions and checksums
- Added TLS features to tonic dependency configuration
- Included sqlx, chrono, and uuid dependencies with specific features
- Added jsonwebtoken and arc-swap as project dependencies
- Reformatted assertion statements to comply with line length limits
- Adjusted base64 import order in engine codec module
- Updated protobuf include statement formatting
239 lines
7.9 KiB
Rust
239 lines
7.9 KiB
Rust
use std::sync::Arc;
|
|
|
|
use wtransport::{Connection, Endpoint, Identity, ServerConfig};
|
|
|
|
use crate::engine::codec;
|
|
use crate::engine::packet::{Packet, PacketType};
|
|
use crate::engine::server::EngineConfig;
|
|
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
|
use crate::{ImksError, ImksResult};
|
|
|
|
pub async fn run_webtransport_server(
|
|
port: u16,
|
|
cert_path: &str,
|
|
key_path: &str,
|
|
store: SessionStore,
|
|
config: EngineConfig,
|
|
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
|
) -> ImksResult<()> {
|
|
let identity = Identity::load_pemfiles(cert_path, key_path)
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
|
|
let server_config = ServerConfig::builder()
|
|
.with_bind_default(port)
|
|
.with_identity(identity)
|
|
.build();
|
|
|
|
let server =
|
|
Endpoint::server(server_config).map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
|
|
tracing::info!("WebTransport server listening on UDP port {}", port);
|
|
|
|
loop {
|
|
let incoming = server.accept().await;
|
|
|
|
let store = store.clone();
|
|
let config = config.clone();
|
|
let on_message = on_message.clone();
|
|
|
|
tokio::spawn(async move {
|
|
match handle_webtransport_session(incoming, store, config, on_message).await {
|
|
Ok(_) => {}
|
|
Err(e) => {
|
|
tracing::error!("WebTransport session error: {}", e);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
async fn handle_webtransport_session(
|
|
incoming: wtransport::endpoint::IncomingSession,
|
|
store: SessionStore,
|
|
config: EngineConfig,
|
|
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
|
) -> ImksResult<()> {
|
|
let request = incoming
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
let connection = request
|
|
.accept()
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
|
|
let sid = crate::engine::session::generate_sid();
|
|
let mut rx = store.create(sid.clone(), TransportType::WebTransport);
|
|
|
|
if let Some(s) = store.get(&sid) {
|
|
let mut s = s.write().await;
|
|
s.set_state(SessionState::Open);
|
|
}
|
|
|
|
let handshake = crate::engine::packet::HandshakeData {
|
|
sid: sid.clone(),
|
|
upgrades: vec![],
|
|
ping_interval: config.ping_interval,
|
|
ping_timeout: config.ping_timeout,
|
|
max_payload: config.max_payload,
|
|
};
|
|
|
|
let open_packet = Packet::open(&handshake);
|
|
send_wt_packet(&connection, &open_packet).await?;
|
|
|
|
let store_clone = store.clone();
|
|
let sid_clone = sid.clone();
|
|
let on_message_clone = on_message.clone();
|
|
let connection_recv = connection.clone();
|
|
let max_payload = config.max_payload;
|
|
|
|
// Reuse buffer across recv iterations instead of allocating 65KB each time
|
|
let recv_handle = tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
while let Ok((mut send, mut recv)) = connection_recv.accept_bi().await {
|
|
// Reset buffer length for the next read without deallocating
|
|
buf.resize(65536, 0);
|
|
match recv.read(&mut buf).await {
|
|
Ok(Some(n)) => {
|
|
if n > max_payload {
|
|
tracing::warn!(
|
|
"WebTransport payload too large ({}) for session {}",
|
|
n,
|
|
sid_clone
|
|
);
|
|
continue;
|
|
}
|
|
if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) {
|
|
match packet.packet_type {
|
|
PacketType::Ping => {
|
|
let pong = Packet::pong("");
|
|
if send_wt_packet_on_stream(&mut send, &pong).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
PacketType::Pong => {
|
|
if let Some(s) = store_clone.get(&sid_clone) {
|
|
let mut s = s.write().await;
|
|
s.update_ping();
|
|
}
|
|
}
|
|
PacketType::Message => {
|
|
let on_msg = on_message_clone.clone();
|
|
let sid = sid_clone.clone();
|
|
tokio::spawn(async move {
|
|
on_msg(sid, packet);
|
|
});
|
|
}
|
|
PacketType::Close => {
|
|
if let Some(s) = store_clone.get(&sid_clone) {
|
|
let mut s = s.write().await;
|
|
s.set_state(SessionState::Closed);
|
|
}
|
|
store_clone.remove(&sid_clone);
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
}
|
|
Ok(None) => break,
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
Ok::<(), ImksError>(())
|
|
});
|
|
|
|
let connection_send = connection.clone();
|
|
let send_handle = tokio::spawn(async move {
|
|
while let Some(packet) = rx.recv().await {
|
|
if send_wt_packet(&connection_send, &packet).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
let connection_ping = connection.clone();
|
|
let store_ping = store.clone();
|
|
let sid_ping = sid.clone();
|
|
let ping_interval = config.ping_interval;
|
|
let ping_handle = tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
|
|
|
|
loop {
|
|
interval.tick().await;
|
|
|
|
if let Some(s) = store_ping.get(&sid_ping) {
|
|
let state = {
|
|
let s = s.read().await;
|
|
s.state
|
|
};
|
|
|
|
if state == SessionState::Closed {
|
|
break;
|
|
}
|
|
|
|
let ping = Packet::ping("");
|
|
if send_wt_packet(&connection_ping, &ping).await.is_err() {
|
|
break;
|
|
}
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
tokio::select! {
|
|
_ = recv_handle => {},
|
|
_ = send_handle => {},
|
|
_ = ping_handle => {},
|
|
}
|
|
|
|
store.remove(&sid);
|
|
Ok(())
|
|
}
|
|
|
|
async fn send_wt_packet(connection: &Connection, packet: &Packet) -> ImksResult<()> {
|
|
let (mut send, _recv) = connection
|
|
.open_bi()
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
let encoded = codec::encode_packet_binary_ws(packet);
|
|
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
|
|
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
|
|
|
|
send.write_all(&header)
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
send.write_all(&encoded)
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
send.finish()
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn send_wt_packet_on_stream(
|
|
send: &mut wtransport::SendStream,
|
|
packet: &Packet,
|
|
) -> ImksResult<()> {
|
|
let encoded = codec::encode_packet_binary_ws(packet);
|
|
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
|
|
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
|
|
|
|
send.write_all(&header)
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
send.write_all(&encoded)
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
send.finish()
|
|
.await
|
|
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|