Files
imks/engine/websocket.rs
T
zhenyi 821537186e refactor(tests): reformat code and update dependency management
- Reorganized import statements in adapter tests for better readability
- Replaced or_insert_with(Vec::new) with or_default() in test closures
- Updated Cargo.lock with new dependency versions and checksums
- Added TLS features to tonic dependency configuration
- Included sqlx, chrono, and uuid dependencies with specific features
- Added jsonwebtoken and arc-swap as project dependencies
- Reformatted assertion statements to comply with line length limits
- Adjusted base64 import order in engine codec module
- Updated protobuf include statement formatting
2026-06-11 12:11:05 +08:00

268 lines
10 KiB
Rust

use std::sync::Arc;
use actix_web::{HttpRequest, HttpResponse, web};
use actix_ws::Message;
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketData, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
/// Query-string parameters of an Engine.IO WebSocket request.
///
/// Every field is optional at parse time; [`websocket_handler`] decides which
/// combinations are acceptable.
#[derive(serde::Deserialize, Debug)]
pub struct WsQuery {
    /// Protocol revision, sent on the wire as `EIO` (the handler accepts only `"4"`).
    #[serde(rename = "EIO")]
    pub eio: Option<String>,
    /// Requested transport name (the handler accepts only `"websocket"`).
    pub transport: Option<String>,
    /// Id of an existing session to upgrade; absent when opening a new session.
    pub sid: Option<String>,
}
/// Engine.IO v4 WebSocket endpoint: `GET /engine.io/?EIO=4&transport=websocket[&sid=…]`.
///
/// Two modes are supported:
///
/// * **Direct WebSocket** (no `sid`): a fresh session is created and marked
///   [`SessionState::Open`], and an `open` packet carrying the handshake data is
///   the first frame sent.
/// * **Upgrade from polling** (`sid` of an existing session): no `open` packet is
///   sent, because the session was already opened over polling. The socket answers
///   the `2probe` → `3probe` exchange. Only when the client sends the `upgrade`
///   packet (`5`) is the session's outbound channel switched to this socket and
///   its buffered packets (`take_pending`) flushed onto it. Until then this socket
///   does not own the session. Nothing is pushed to it and no heartbeat is sent on
///   it. If it goes away (for example, a failed probe), the session is left intact.
///
/// While the socket is live, three tasks run concurrently:
/// 1. outbound: session channel → WebSocket text frames;
/// 2. inbound: WebSocket frames → ping/pong, upgrade, close, `on_message`;
/// 3. heartbeat: an Engine.IO `ping` every `config.ping_interval` milliseconds.
///
/// When any of them ends, the others are aborted and the socket is closed. The
/// session is also removed from `store` if this socket owns it.
///
/// # Errors
///
/// Invalid requests are answered with `400 Bad Request` (returned as `Ok`). That
/// covers a wrong `EIO` version, a transport other than `websocket`, or an unknown
/// `sid`. An `Err` is returned only when `actix_ws::handle` rejects the WebSocket
/// handshake.
pub async fn websocket_handler(
    req: HttpRequest,
    body: web::Payload,
    query: web::Query<WsQuery>,
    store: web::Data<SessionStore>,
    config: web::Data<EngineConfig>,
    on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> Result<HttpResponse, actix_web::Error> {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Capacity of the outbound channel prepared for a polling → WebSocket upgrade.
    const UPGRADE_CHANNEL_CAPACITY: usize = 256;

    if query.eio.as_deref() != Some("4") {
        return Ok(HttpResponse::BadRequest().body("invalid EIO version"));
    }
    if query.transport.as_deref() != Some("websocket") {
        return Ok(HttpResponse::BadRequest().body("invalid transport"));
    }
    // Validate the sid before upgrading, so an unknown session never gets a
    // WebSocket handshake.
    if let Some(sid) = query.sid.as_ref()
        && !store.exists(sid)
    {
        return Ok(HttpResponse::BadRequest().body("unknown session"));
    }

    let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, body)?;

    // `upgrade_tx` is `Some` only for an upgrade: it is the sender that replaces
    // the session's polling channel once the client confirms with `5`.
    // `initially_owned` says whether this socket is the session's active
    // transport from the start.
    let (session_sid, mut session_rx, upgrade_tx, initially_owned) = match query.sid.clone() {
        Some(sid) => {
            // Leave the session untouched: polling keeps serving it until the
            // upgrade packet arrives on this socket.
            let (tx, rx) = tokio::sync::mpsc::channel(UPGRADE_CHANNEL_CAPACITY);
            (sid, rx, Some(tx), false)
        }
        None => {
            let new_sid = crate::engine::session::generate_sid();
            let rx = store.create(new_sid.clone(), TransportType::WebSocket);
            if let Some(s) = store.get(&new_sid) {
                let mut s = s.write().await;
                s.set_state(SessionState::Open);
            }
            (new_sid, rx, None, true)
        }
    };

    if initially_owned {
        // Direct WebSocket connection: the `open` packet must be the first frame.
        let handshake = crate::engine::packet::HandshakeData {
            sid: session_sid.clone(),
            upgrades: vec![],
            ping_interval: config.ping_interval,
            ping_timeout: config.ping_timeout,
            max_payload: config.max_payload,
        };
        let open_msg = codec::encode_packet(&Packet::open(&handshake));
        if ws_session.text(open_msg).await.is_err() {
            tracing::warn!(
                "Failed to send open packet to WebSocket session {}",
                session_sid
            );
            store.remove(&session_sid);
            return Ok(response);
        }
    }

    // Shared by all tasks: true while this socket is the session's active transport.
    let owns_session = Arc::new(AtomicBool::new(initially_owned));

    // Task 1: session packets → WebSocket.
    let mut send_handle = {
        let store_tx = store.get_ref().clone();
        let sid = session_sid.clone();
        let owns = Arc::clone(&owns_session);
        let mut ws = ws_session.clone();
        actix_rt::spawn(async move {
            while let Some(packet) = session_rx.recv().await {
                if ws.text(codec::encode_packet(&packet)).await.is_err() {
                    break;
                }
            }
            // The channel is closed or the socket is gone. Only the owning socket
            // may drop the session; a failed upgrade probe must leave the polling
            // session alone.
            if owns.load(Ordering::SeqCst) {
                store_tx.remove(&sid);
            }
        })
    };

    // Task 2: WebSocket frames → protocol handling and dispatch.
    let mut recv_handle = {
        let store_rx = store.get_ref().clone();
        let on_msg = on_message.get_ref().clone();
        let sid = session_sid.clone();
        let owns = Arc::clone(&owns_session);
        let mut ws = ws_session.clone();
        let max_payload = config.max_payload;
        actix_rt::spawn(async move {
            let mut upgrade_tx = upgrade_tx;
            while let Some(Ok(msg)) = msg_stream.recv().await {
                match msg {
                    Message::Text(text) => {
                        if text.len() > max_payload {
                            tracing::warn!(
                                "Text payload too large ({}) for session {}",
                                text.len(),
                                sid
                            );
                            let _ = ws.close(None).await;
                            break;
                        }
                        let Ok(packet) = codec::decode_packet(&text) else {
                            continue;
                        };
                        match packet.packet_type {
                            PacketType::Ping => {
                                // `2probe` is the upgrade probe and must be echoed
                                // as `3probe`; any other ping gets an empty pong.
                                let pong = if let PacketData::Text(ref data) = packet.data
                                    && data == "probe"
                                {
                                    Packet::pong("probe")
                                } else {
                                    Packet::pong("")
                                };
                                let _ = ws.text(codec::encode_packet(&pong)).await;
                            }
                            PacketType::Pong => {
                                if let Some(s) = store_rx.get(&sid) {
                                    let mut s = s.write().await;
                                    s.update_ping();
                                }
                            }
                            PacketType::Upgrade => {
                                if let Some(s) = store_rx.get(&sid) {
                                    let mut s = s.write().await;
                                    if let Some(tx) = upgrade_tx.take() {
                                        // Flush what polling had buffered, then make
                                        // this socket the session's outbound channel.
                                        // Replacing `s.tx` drops the polling sender,
                                        // which closes the polling channel.
                                        let mut dropped = 0usize;
                                        for pending in s.take_pending() {
                                            if tx.try_send(pending).is_err() {
                                                dropped += 1;
                                            }
                                        }
                                        if dropped > 0 {
                                            tracing::warn!(
                                                "Dropped {} buffered packets while upgrading session {}",
                                                dropped,
                                                sid
                                            );
                                        }
                                        s.tx = tx;
                                        owns.store(true, Ordering::SeqCst);
                                    }
                                    s.set_transport(TransportType::WebSocket);
                                    s.set_state(SessionState::Open);
                                }
                            }
                            PacketType::Message => {
                                let handler = Arc::clone(&on_msg);
                                let msg_sid = sid.clone();
                                tokio::spawn(async move {
                                    handler(msg_sid, packet);
                                });
                            }
                            PacketType::Close => {
                                // An explicit Engine.IO close ends the session.
                                if let Some(s) = store_rx.get(&sid) {
                                    let mut s = s.write().await;
                                    s.set_state(SessionState::Closed);
                                }
                                store_rx.remove(&sid);
                                let _ = ws.close(None).await;
                                break;
                            }
                            _ => {}
                        }
                    }
                    Message::Binary(bin) => {
                        if bin.len() > max_payload {
                            tracing::warn!(
                                "Binary payload too large ({}) for session {}",
                                bin.len(),
                                sid
                            );
                            // Same policy as oversized text frames: close the connection.
                            let _ = ws.close(None).await;
                            break;
                        }
                        if let Ok(packet) = codec::decode_packet_ws(&bin)
                            && packet.packet_type == PacketType::Message
                        {
                            let handler = Arc::clone(&on_msg);
                            let msg_sid = sid.clone();
                            tokio::spawn(async move {
                                handler(msg_sid, packet);
                            });
                        }
                    }
                    Message::Ping(payload) => {
                        // WebSocket-level ping: actix-ws leaves the pong reply to the
                        // application.
                        if ws.pong(&payload).await.is_err() {
                            break;
                        }
                    }
                    Message::Close(_) => {
                        // A socket that never completed the upgrade must not tear
                        // down the session that polling is still serving.
                        if owns.load(Ordering::SeqCst) {
                            if let Some(s) = store_rx.get(&sid) {
                                let mut s = s.write().await;
                                s.set_state(SessionState::Closed);
                            }
                            store_rx.remove(&sid);
                        }
                        break;
                    }
                    // NOTE(review): continuation (fragmented) frames are ignored here.
                    _ => {}
                }
            }
        })
    };

    // Task 3: Engine.IO heartbeat.
    let mut ping_handle = {
        let store_ping = store.get_ref().clone();
        let sid = session_sid.clone();
        let owns = Arc::clone(&owns_session);
        let mut ws = ws_session.clone();
        let period = std::time::Duration::from_millis(config.ping_interval);
        tokio::spawn(async move {
            // `tokio::time::interval` fires immediately. Start one period later so
            // the first ping follows the handshake instead of racing it.
            let mut ticker =
                tokio::time::interval_at(tokio::time::Instant::now() + period, period);
            loop {
                ticker.tick().await;
                let Some(session) = store_ping.get(&sid) else {
                    break;
                };
                let state = {
                    let s = session.read().await;
                    s.state
                };
                if state == SessionState::Closed {
                    break;
                }
                // Until the upgrade completes, polling still carries the session,
                // and a stray ping on this socket would break the client's probe.
                if !owns.load(Ordering::SeqCst) {
                    continue;
                }
                let ping_msg = codec::encode_packet(&Packet::ping(""));
                if ws.text(ping_msg).await.is_err() {
                    break;
                }
            }
        })
    };

    // Supervisor: when any task ends, tear the whole connection down.
    actix_rt::spawn(async move {
        tokio::select! {
            _ = &mut send_handle => {},
            _ = &mut recv_handle => {},
            _ = &mut ping_handle => {},
        }
        // Dropping a JoinHandle merely detaches its task, so abort the survivors
        // explicitly; otherwise they could outlive the connection.
        send_handle.abort();
        recv_handle.abort();
        ping_handle.abort();
        let _ = ws_session.close(None).await;
        if owns_session.load(Ordering::SeqCst) {
            store.remove(&session_sid);
        }
    });

    Ok(response)
}
/// Registers the Engine.IO WebSocket endpoint: `GET /engine.io/` is served by
/// [`websocket_handler`].
pub fn configure_websocket(cfg: &mut web::ServiceConfig) {
    let handler = web::get().to(websocket_handler);
    cfg.route("/engine.io/", handler);
}