Files
zhenyi 821537186e refactor(tests): reformat code and update dependency management
- Reorganized import statements in adapter tests for better readability
- Replaced or_insert_with(Vec::new) with or_default() in test closures
- Updated Cargo.lock with new dependency versions and checksums
- Added TLS features to tonic dependency configuration
- Included sqlx, chrono, and uuid dependencies with specific features
- Added jsonwebtoken and arc-swap as project dependencies
- Reformatted assertion statements to comply with line length limits
- Adjusted base64 import order in engine codec module
- Updated protobuf include statement formatting
2026-06-11 12:11:05 +08:00

186 lines
5.6 KiB
Rust

use std::sync::Arc;
use std::time::Duration;
use actix_web::{HttpRequest, HttpResponse, web};
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
/// Query-string parameters read by the polling endpoints in this file.
///
/// Every field is optional at the deserialization level; the handlers
/// themselves decide which values are acceptable.
#[derive(Debug, serde::Deserialize)]
pub struct PollingQuery {
    /// Value of the upper-case `EIO` query parameter.
    #[serde(rename = "EIO")]
    pub eio: Option<String>,

    /// Value of the `transport` query parameter.
    pub transport: Option<String>,

    /// Value of the `sid` query parameter, if the client supplied one.
    pub sid: Option<String>,
}
/// Serves `GET /engine.io/` for the HTTP long-polling transport.
///
/// Behaviour, in order:
///
/// 1. Responds `400` unless the `EIO` parameter is `"4"` and `transport` is
///    `"polling"`.
/// 2. Without a `sid`, delegates to `handle_handshake` to open a new session.
/// 3. With a `sid`, responds `400` when the session is unknown or its state
///    is [`SessionState::Closed`].
/// 4. If packets are already buffered for the session they are returned
///    immediately as an encoded payload.
/// 5. Otherwise the request is parked on the session's notifier for at most
///    `ping_interval + ping_timeout` milliseconds, after which whatever has
///    been buffered is returned, or a single `noop` packet if nothing arrived.
///
/// `_req` and `_on_message` are accepted but not used by this handler.
pub async fn polling_get(
    _req: HttpRequest,
    query: web::Query<PollingQuery>,
    store: web::Data<SessionStore>,
    config: web::Data<EngineConfig>,
    _on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> HttpResponse {
    if query.eio.as_deref() != Some("4") {
        return HttpResponse::BadRequest().body("invalid EIO version");
    }
    if query.transport.as_deref() != Some("polling") {
        return HttpResponse::BadRequest().body("invalid transport");
    }

    // A request without a session id is treated as a handshake.
    let sid = match &query.sid {
        Some(sid) => sid.clone(),
        None => return handle_handshake(&store, &config).await,
    };

    let session = match store.get(&sid) {
        Some(s) => s,
        None => return HttpResponse::BadRequest().body("unknown session"),
    };

    // Check the session state and flush anything that is already buffered.
    let mut session_guard = session.write().await;
    if session_guard.state == SessionState::Closed {
        return HttpResponse::BadRequest().body("session closed");
    }
    let pending = session_guard.take_pending();
    if !pending.is_empty() {
        let payload = codec::encode_payload(&pending);
        return HttpResponse::Ok()
            .content_type("text/plain; charset=UTF-8")
            .body(payload);
    }

    // Register interest in the wake-up *before* releasing the session lock.
    // Previously the future was created only after the guard was dropped, so
    // a packet queued in that gap could signal the notifier with nobody
    // listening, and this request would then sleep for the full timeout.
    // NOTE(review): assumes `notify` is a `tokio::sync::Notify` (whose
    // `Notified` future observes `notify_waiters()` from creation) and that
    // producers queue packets under this same session lock — confirm.
    let notify = session_guard.notify.clone();
    let notified = notify.notified();
    drop(session_guard);

    let timeout = Duration::from_millis(config.ping_interval + config.ping_timeout);
    // The outcome is deliberately ignored: a wake-up and a timeout are handled
    // identically by draining the pending buffer below.
    let _ = tokio::time::timeout(timeout, notified).await;

    // The session may have been removed from the store while this request was
    // parked (the original code mentions heartbeat handling as one cause).
    if store.get(&sid).is_none() {
        return HttpResponse::BadRequest().body("session closed");
    }

    let mut session_guard = session.write().await;
    let packets = session_guard.take_pending();
    if packets.is_empty() {
        // Nothing arrived in time: answer with a noop packet.
        let noop = codec::encode_packet(&Packet::noop());
        HttpResponse::Ok()
            .content_type("text/plain; charset=UTF-8")
            .body(noop)
    } else {
        let payload = codec::encode_payload(&packets);
        HttpResponse::Ok()
            .content_type("text/plain; charset=UTF-8")
            .body(payload)
    }
}
/// Serves `POST /engine.io/` for the HTTP long-polling transport.
///
/// Behaviour, in order:
///
/// 1. Responds `400` unless the `EIO` parameter is `"4"` and `transport` is
///    `"polling"`.
/// 2. Responds `413` when the raw body is longer than `config.max_payload`.
/// 3. Responds `400` when `sid` is missing, the session is unknown, the body
///    is not valid UTF-8, or the payload cannot be decoded.
/// 4. Responds `400` with `"session closed"` when the session's state is
///    [`SessionState::Closed`], matching what `polling_get` does.
/// 5. Processes the decoded packets one by one:
///    * `Pong` calls `update_ping` on the session;
///    * `Message` invokes `on_message` with the session id and the packet on
///      a spawned tokio task;
///    * `Close` marks the session closed, removes it from the store and stops
///      processing, so later packets in the same payload are dropped;
///    * every other packet type is ignored.
///
/// On success the response is `200` with body `"ok"`. `_req` is unused.
pub async fn polling_post(
    _req: HttpRequest,
    body: web::Bytes,
    query: web::Query<PollingQuery>,
    store: web::Data<SessionStore>,
    config: web::Data<EngineConfig>,
    on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> HttpResponse {
    if query.eio.as_deref() != Some("4") {
        return HttpResponse::BadRequest().body("invalid EIO version");
    }
    if query.transport.as_deref() != Some("polling") {
        return HttpResponse::BadRequest().body("invalid transport");
    }

    // Check payload size BEFORE attempting to decode.
    if body.len() > config.max_payload {
        return HttpResponse::PayloadTooLarge().body("payload too large");
    }

    let sid = match &query.sid {
        Some(sid) => sid,
        None => return HttpResponse::BadRequest().body("missing sid"),
    };
    let session = match store.get(sid) {
        Some(s) => s,
        None => return HttpResponse::BadRequest().body("unknown session"),
    };

    let body_str = match std::str::from_utf8(&body) {
        Ok(s) => s,
        Err(_) => return HttpResponse::BadRequest().body("invalid utf8"),
    };
    let packets = match codec::decode_payload(body_str) {
        Ok(p) => p,
        Err(_) => return HttpResponse::BadRequest().body("invalid payload"),
    };

    let mut session_guard = session.write().await;

    // A session that is closed but still present in the store must not have
    // its ping refreshed or its messages delivered; `polling_get` rejects the
    // same situation with the same response.
    if session_guard.state == SessionState::Closed {
        return HttpResponse::BadRequest().body("session closed");
    }

    for packet in packets {
        match packet.packet_type {
            PacketType::Pong => {
                session_guard.update_ping();
            }
            PacketType::Message => {
                // Run the callback on its own task so this handler does not
                // wait for it while holding the session lock.
                let on_msg = Arc::clone(on_message.get_ref());
                let sid_owned = sid.clone();
                tokio::spawn(async move {
                    on_msg(sid_owned, packet);
                });
            }
            PacketType::Close => {
                // Closing ends processing; remaining packets are discarded.
                session_guard.set_state(SessionState::Closed);
                store.remove(sid);
                return HttpResponse::Ok().body("ok");
            }
            _ => {}
        }
    }

    HttpResponse::Ok().body("ok")
}
/// Opens a new polling session and answers with the encoded `open` packet.
///
/// A fresh session id is generated and registered in `store` with
/// [`TransportType::Polling`]. If the store hands the session back, it is
/// switched to [`SessionState::Open`]. The response body is the encoded open
/// packet carrying that id, the advertised `websocket` upgrade, and the ping
/// and payload limits copied from `config`.
async fn handle_handshake(store: &SessionStore, config: &EngineConfig) -> HttpResponse {
    let new_sid = crate::engine::session::generate_sid();

    // Whatever `create` returns is kept bound until the function returns,
    // exactly as before; it is not otherwise used here.
    let _rx = store.create(new_sid.clone(), TransportType::Polling);

    if let Some(handle) = store.get(&new_sid) {
        // The write guard is a temporary and is released at the end of this
        // statement.
        handle.write().await.set_state(SessionState::Open);
    }

    let handshake = crate::engine::packet::HandshakeData {
        sid: new_sid,
        upgrades: vec![String::from("websocket")],
        ping_interval: config.ping_interval,
        ping_timeout: config.ping_timeout,
        max_payload: config.max_payload,
    };
    let open_packet = Packet::open(&handshake);

    HttpResponse::Ok()
        .content_type("text/plain; charset=UTF-8")
        .body(codec::encode_packet(&open_packet))
}
/// Registers the long-polling handlers on `cfg`.
///
/// Both routes share the `/engine.io/` path: `GET` is served by
/// [`polling_get`] and `POST` by [`polling_post`].
pub fn configure_polling(cfg: &mut web::ServiceConfig) {
    const ENDPOINT: &str = "/engine.io/";

    cfg.route(ENDPOINT, web::get().to(polling_get));
    cfg.route(ENDPOINT, web::post().to(polling_post));
}