821537186e
- Reorganized import statements in adapter tests for better readability - Replaced or_insert_with(Vec::new) with or_default() in test closures - Updated Cargo.lock with new dependency versions and checksums - Added TLS features to tonic dependency configuration - Included sqlx, chrono, and uuid dependencies with specific features - Added jsonwebtoken and arc-swap as project dependencies - Reformatted assertion statements to comply with line length limits - Adjusted base64 import order in engine codec module - Updated protobuf include statement formatting
268 lines
10 KiB
Rust
268 lines
10 KiB
Rust
use std::sync::Arc;
|
|
|
|
use actix_web::{HttpRequest, HttpResponse, web};
|
|
use actix_ws::Message;
|
|
|
|
use crate::engine::codec;
|
|
use crate::engine::packet::{Packet, PacketData, PacketType};
|
|
use crate::engine::server::EngineConfig;
|
|
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
|
|
|
#[derive(Debug, serde::Deserialize)]
|
|
pub struct WsQuery {
|
|
#[serde(rename = "EIO")]
|
|
pub eio: Option<String>,
|
|
pub transport: Option<String>,
|
|
pub sid: Option<String>,
|
|
}
|
|
|
|
pub async fn websocket_handler(
|
|
req: HttpRequest,
|
|
body: web::Payload,
|
|
query: web::Query<WsQuery>,
|
|
store: web::Data<SessionStore>,
|
|
config: web::Data<EngineConfig>,
|
|
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
|
|
) -> Result<HttpResponse, actix_web::Error> {
|
|
if query.eio.as_deref() != Some("4") {
|
|
return Ok(HttpResponse::BadRequest().body("invalid EIO version"));
|
|
}
|
|
|
|
if query.transport.as_deref() != Some("websocket") {
|
|
return Ok(HttpResponse::BadRequest().body("invalid transport"));
|
|
}
|
|
|
|
let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, body)?;
|
|
|
|
let sid = query.sid.clone();
|
|
|
|
if let Some(ref sid) = sid
|
|
&& !store.exists(sid)
|
|
{
|
|
return Ok(HttpResponse::BadRequest().body("unknown session"));
|
|
}
|
|
|
|
// Create or reuse session, obtaining the mpsc receiver for the forwarding task
|
|
let (session_sid, mut session_rx) = if let Some(ref sid) = sid {
|
|
// Upgrade: session already exists, replace its channel and drain pending packets
|
|
let session_arc = match store.get(sid) {
|
|
Some(s) => s,
|
|
None => {
|
|
tracing::error!("Session {} not found for upgrade", sid);
|
|
return Ok(HttpResponse::InternalServerError().body("session not found"));
|
|
}
|
|
};
|
|
let (new_tx, new_rx) = tokio::sync::mpsc::channel(256);
|
|
{
|
|
let mut s = session_arc.write().await;
|
|
// Swap tx atomically: old_tx will be dropped, closing its channel.
|
|
// Any packets in the old rx are consumed by the old send_handle,
|
|
// which then exits when it sees the channel close.
|
|
// Drain pending_packets (from polling buffering) into new channel.
|
|
let pending = s.take_pending();
|
|
for packet in pending {
|
|
let _ = new_tx.try_send(packet);
|
|
}
|
|
s.tx = new_tx;
|
|
s.set_transport(TransportType::WebSocket);
|
|
}
|
|
(sid.clone(), new_rx)
|
|
} else {
|
|
// New connection: generate SID and create session
|
|
let new_sid = crate::engine::session::generate_sid();
|
|
let rx = store.create(new_sid.clone(), TransportType::WebSocket);
|
|
if let Some(s) = store.get(&new_sid) {
|
|
let mut s = s.write().await;
|
|
s.set_state(SessionState::Open);
|
|
}
|
|
(new_sid, rx)
|
|
};
|
|
|
|
let handshake = crate::engine::packet::HandshakeData {
|
|
sid: session_sid.clone(),
|
|
upgrades: vec![],
|
|
ping_interval: config.ping_interval,
|
|
ping_timeout: config.ping_timeout,
|
|
max_payload: config.max_payload,
|
|
};
|
|
|
|
let open_packet = Packet::open(&handshake);
|
|
let open_msg = codec::encode_packet(&open_packet);
|
|
if ws_session.text(open_msg).await.is_err() {
|
|
tracing::warn!(
|
|
"Failed to send open packet to WebSocket session {}",
|
|
session_sid
|
|
);
|
|
store.remove(&session_sid);
|
|
return Ok(response);
|
|
}
|
|
|
|
let store_clone = store.get_ref().clone();
|
|
let on_message_clone = on_message.get_ref().clone();
|
|
let sid_clone = session_sid.clone();
|
|
let ws_session_clone = ws_session.clone();
|
|
let max_payload = config.max_payload;
|
|
|
|
// Task 1: Forward engine session packets → WebSocket (reads from session mpsc rx)
|
|
let sid_for_send = session_sid.clone();
|
|
let store_for_send = store.get_ref().clone();
|
|
let mut ws_for_send = ws_session.clone();
|
|
let send_handle = actix_rt::spawn(async move {
|
|
while let Some(packet) = session_rx.recv().await {
|
|
let encoded = codec::encode_packet(&packet);
|
|
if ws_for_send.text(encoded).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
// Session channel closed — clean up
|
|
store_for_send.remove(&sid_for_send);
|
|
});
|
|
|
|
// Task 2: Read incoming WebSocket messages → dispatch
|
|
let recv_handle = actix_rt::spawn(async move {
|
|
let mut ws_session = ws_session_clone;
|
|
while let Some(Ok(msg)) = msg_stream.recv().await {
|
|
match msg {
|
|
Message::Text(text) => {
|
|
if text.len() > max_payload {
|
|
tracing::warn!(
|
|
"Text payload too large ({}) for session {}",
|
|
text.len(),
|
|
sid_clone
|
|
);
|
|
let _ = ws_session.close(None).await;
|
|
break;
|
|
}
|
|
|
|
if let Ok(packet) = codec::decode_packet(&text) {
|
|
match packet.packet_type {
|
|
PacketType::Ping => {
|
|
if let PacketData::Text(ref data) = packet.data
|
|
&& data == "probe"
|
|
{
|
|
let pong = Packet::pong("probe");
|
|
let pong_msg = codec::encode_packet(&pong);
|
|
let _ = ws_session.text(pong_msg).await;
|
|
continue;
|
|
}
|
|
let pong = Packet::pong("");
|
|
let pong_msg = codec::encode_packet(&pong);
|
|
let _ = ws_session.text(pong_msg).await;
|
|
}
|
|
PacketType::Pong => {
|
|
if let Some(s) = store_clone.get(&sid_clone) {
|
|
let mut s = s.write().await;
|
|
s.update_ping();
|
|
}
|
|
}
|
|
PacketType::Upgrade => {
|
|
if let Some(s) = store_clone.get(&sid_clone) {
|
|
let mut s = s.write().await;
|
|
s.set_transport(TransportType::WebSocket);
|
|
s.set_state(SessionState::Open);
|
|
}
|
|
}
|
|
PacketType::Message => {
|
|
let on_msg = on_message_clone.clone();
|
|
let sid = sid_clone.clone();
|
|
tokio::spawn(async move {
|
|
on_msg(sid, packet);
|
|
});
|
|
}
|
|
PacketType::Close => {
|
|
if let Some(s) = store_clone.get(&sid_clone) {
|
|
let mut s = s.write().await;
|
|
s.set_state(SessionState::Closed);
|
|
}
|
|
store_clone.remove(&sid_clone);
|
|
let _ = ws_session.close(None).await;
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
}
|
|
Message::Binary(bin) => {
|
|
// Enforce max payload size for binary frames
|
|
if bin.len() > max_payload {
|
|
tracing::warn!(
|
|
"Binary payload too large ({}) for session {}",
|
|
bin.len(),
|
|
sid_clone
|
|
);
|
|
continue;
|
|
}
|
|
|
|
if let Ok(packet) = codec::decode_packet_ws(&bin)
|
|
&& packet.packet_type == PacketType::Message
|
|
{
|
|
let on_msg = on_message_clone.clone();
|
|
let sid = sid_clone.clone();
|
|
tokio::spawn(async move {
|
|
on_msg(sid, packet);
|
|
});
|
|
}
|
|
}
|
|
Message::Close(_) => {
|
|
if let Some(s) = store_clone.get(&sid_clone) {
|
|
let mut s = s.write().await;
|
|
s.set_state(SessionState::Closed);
|
|
}
|
|
store_clone.remove(&sid_clone);
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Task 3: Heartbeat ping sender
|
|
let sid_for_ping = session_sid.clone();
|
|
let store_for_ping = store.get_ref().clone();
|
|
let mut ws_for_ping = ws_session.clone();
|
|
let ping_interval = config.ping_interval;
|
|
let ping_handle = tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
|
|
|
|
loop {
|
|
interval.tick().await;
|
|
|
|
if let Some(s) = store_for_ping.get(&sid_for_ping) {
|
|
let session_state = {
|
|
let s = s.read().await;
|
|
s.state
|
|
};
|
|
|
|
if session_state == SessionState::Closed {
|
|
break;
|
|
}
|
|
|
|
let ping = Packet::ping("");
|
|
let ping_msg = codec::encode_packet(&ping);
|
|
if ws_for_ping.text(ping_msg).await.is_err() {
|
|
break;
|
|
}
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
// Wait for any task to finish, then clean up
|
|
actix_rt::spawn(async move {
|
|
// actix_rt::spawn returns JoinHandle which is compatible with tokio::select!
|
|
tokio::select! {
|
|
_ = send_handle => {},
|
|
_ = recv_handle => {},
|
|
_ = ping_handle => {},
|
|
}
|
|
store.remove(&session_sid);
|
|
});
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
/// Registers the Engine.IO WebSocket endpoint on the given service configuration.
///
/// `GET /engine.io/` is routed to `websocket_handler`.
pub fn configure_websocket(cfg: &mut web::ServiceConfig) {
    let ws_route = web::get().to(websocket_handler);
    cfg.route("/engine.io/", ws_route);
}
|