feat(auth): add authentication protocol definitions and build configuration
- Add TokenClaims message for JWT payload structure with user id, issuer, timestamps, and scopes - Implement IssueTokenRequest/Response for creating access and refresh tokens with TTL support - Create RefreshTokenRequest/Response for token rotation functionality - Define RevokeTokenRequest/Response with support for single token or user-wide revocation - Add VerifyTokenRequest/Response for validating JWT tokens with detailed claims information - Implement signing key distribution system with GetSigningKeysRequest/Response - Create TokenService gRPC service with IssueToken, RefreshToken, RevokeToken, VerifyToken, and GetSigningKeys methods - Add build.rs configuration to compile proto files using tonic_prost_build - Include channel, channel_settings, member, and permission protocol definitions for IM services - Generate Rust code bindings through pb/core.rs and pb/im.rs modules
This commit is contained in:
+239
@@ -0,0 +1,239 @@
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
|
||||
|
||||
use crate::engine::packet::{Packet, PacketData, PacketError, PacketType};
|
||||
|
||||
const RECORD_SEPARATOR: char = '\x1e';
|
||||
|
||||
pub fn encode_packet(packet: &Packet) -> String {
|
||||
let type_char = packet.packet_type as u8 + b'0';
|
||||
let type_str = type_char as char;
|
||||
|
||||
match &packet.data {
|
||||
PacketData::Text(s) => format!("{type_str}{s}"),
|
||||
PacketData::Binary(b) => format!("{type_str}b{}", BASE64.encode(b)),
|
||||
PacketData::Empty => type_str.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode_packet_binary_ws(packet: &Packet) -> Vec<u8> {
|
||||
let type_byte = packet.packet_type as u8 + b'0';
|
||||
|
||||
match &packet.data {
|
||||
PacketData::Text(s) => {
|
||||
let mut buf = Vec::with_capacity(1 + s.len());
|
||||
buf.push(type_byte);
|
||||
buf.extend_from_slice(s.as_bytes());
|
||||
buf
|
||||
}
|
||||
PacketData::Binary(b) => {
|
||||
let mut buf = Vec::with_capacity(1 + b.len());
|
||||
buf.push(type_byte);
|
||||
buf.extend_from_slice(b);
|
||||
buf
|
||||
}
|
||||
PacketData::Empty => vec![type_byte],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_packet(input: &str) -> Result<Packet, PacketError> {
|
||||
let mut chars = input.chars();
|
||||
let type_char = chars.next().ok_or(PacketError::Empty)?;
|
||||
let packet_type = PacketType::try_from(type_char)?;
|
||||
let rest: String = chars.collect();
|
||||
|
||||
if rest.is_empty() {
|
||||
return Ok(Packet {
|
||||
packet_type,
|
||||
data: PacketData::Empty,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(b64_data) = rest.strip_prefix('b') {
|
||||
let decoded = BASE64.decode(b64_data)?;
|
||||
return Ok(Packet {
|
||||
packet_type,
|
||||
data: PacketData::Binary(decoded),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Packet {
|
||||
packet_type,
|
||||
data: PacketData::Text(rest),
|
||||
})
|
||||
}
|
||||
|
||||
/// Decode a WebSocket binary frame into a Packet.
|
||||
/// Handles both text-encoded packets (UTF-8 payload after type byte)
|
||||
/// and binary packets (raw binary payload after type byte).
|
||||
pub fn decode_packet_ws(input: &[u8]) -> Result<Packet, PacketError> {
|
||||
if input.is_empty() {
|
||||
return Err(PacketError::Empty);
|
||||
}
|
||||
|
||||
let type_byte = input[0];
|
||||
let packet_type = PacketType::try_from(type_byte.wrapping_sub(b'0'))?;
|
||||
let rest = &input[1..];
|
||||
|
||||
if rest.is_empty() {
|
||||
return Ok(Packet {
|
||||
packet_type,
|
||||
data: PacketData::Empty,
|
||||
});
|
||||
}
|
||||
|
||||
// Try UTF-8 first; if it fails, treat as binary data
|
||||
match String::from_utf8(rest.to_vec()) {
|
||||
Ok(text) => Ok(Packet {
|
||||
packet_type,
|
||||
data: PacketData::Text(text),
|
||||
}),
|
||||
Err(_) => Ok(Packet {
|
||||
packet_type,
|
||||
data: PacketData::Binary(rest.to_vec()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode_payload(packets: &[Packet]) -> String {
|
||||
packets
|
||||
.iter()
|
||||
.map(encode_packet)
|
||||
.collect::<Vec<_>>()
|
||||
.join(&RECORD_SEPARATOR.to_string())
|
||||
}
|
||||
|
||||
pub fn decode_payload(input: &str) -> Result<Vec<Packet>, PacketError> {
|
||||
input
|
||||
.split(RECORD_SEPARATOR)
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(decode_packet)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn encode_webtransport_header(payload_len: usize, is_binary: bool) -> Vec<u8> {
|
||||
let binary_bit: u8 = if is_binary { 0x80 } else { 0x00 };
|
||||
|
||||
if payload_len <= 125 {
|
||||
vec![binary_bit | (payload_len as u8)]
|
||||
} else if payload_len <= 65535 {
|
||||
let mut header = vec![binary_bit | 126];
|
||||
header.extend_from_slice(&(payload_len as u16).to_be_bytes());
|
||||
header
|
||||
} else {
|
||||
let mut header = vec![binary_bit | 127];
|
||||
header.extend_from_slice(&(payload_len as u64).to_be_bytes());
|
||||
header
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_webtransport_header(header: &[u8]) -> Option<(usize, bool)> {
|
||||
if header.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let first = header[0];
|
||||
let is_binary = (first & 0x80) != 0;
|
||||
let len_indicator = first & 0x7f;
|
||||
|
||||
if len_indicator <= 125 {
|
||||
Some((len_indicator as usize, is_binary))
|
||||
} else if len_indicator == 126 {
|
||||
if header.len() < 3 {
|
||||
return None;
|
||||
}
|
||||
let len = u16::from_be_bytes([header[1], header[2]]) as usize;
|
||||
Some((len, is_binary))
|
||||
} else {
|
||||
if header.len() < 9 {
|
||||
return None;
|
||||
}
|
||||
let len = u64::from_be_bytes([
|
||||
header[1], header[2], header[3], header[4], header[5], header[6], header[7], header[8],
|
||||
]) as usize;
|
||||
Some((len, is_binary))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_text_packet() {
|
||||
let packet = Packet::message_text("hello");
|
||||
let encoded = encode_packet(&packet);
|
||||
assert_eq!(encoded, "4hello");
|
||||
|
||||
let decoded = decode_packet(&encoded).unwrap();
|
||||
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_binary_packet() {
|
||||
let packet = Packet::message_binary(vec![1, 2, 3, 4]);
|
||||
let encoded = encode_packet(&packet);
|
||||
assert_eq!(encoded, "4bAQIDBA==");
|
||||
|
||||
let decoded = decode_packet(&encoded).unwrap();
|
||||
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||
assert_eq!(decoded.data, PacketData::Binary(vec![1, 2, 3, 4]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_payload() {
|
||||
let packets = vec![
|
||||
Packet::message_text("hello"),
|
||||
Packet::ping(""),
|
||||
Packet::message_text("world"),
|
||||
];
|
||||
let encoded = encode_payload(&packets);
|
||||
assert_eq!(encoded, "4hello\x1e2\x1e4world");
|
||||
|
||||
let decoded = decode_payload(&encoded).unwrap();
|
||||
assert_eq!(decoded.len(), 3);
|
||||
assert_eq!(decoded[0].packet_type, PacketType::Message);
|
||||
assert_eq!(decoded[1].packet_type, PacketType::Ping);
|
||||
assert_eq!(decoded[2].packet_type, PacketType::Message);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webtransport_header() {
|
||||
let header = encode_webtransport_header(6, false);
|
||||
assert_eq!(header, vec![0x06]);
|
||||
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
|
||||
assert_eq!(len, 6);
|
||||
assert!(!is_binary);
|
||||
|
||||
let header = encode_webtransport_header(200, true);
|
||||
assert_eq!(header.len(), 3);
|
||||
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
|
||||
assert_eq!(len, 200);
|
||||
assert!(is_binary);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_packet_ws_text() {
|
||||
let input = b"4hello";
|
||||
let decoded = decode_packet_ws(input).unwrap();
|
||||
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_packet_ws_binary() {
|
||||
// Type byte 4 (Message) + raw binary payload (non-UTF-8)
|
||||
let input: Vec<u8> = vec![b'4', 0x80, 0xFF, 0x00, 0x01];
|
||||
let decoded = decode_packet_ws(&input).unwrap();
|
||||
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||
assert_eq!(decoded.data, PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_packet_ws_empty() {
|
||||
let input = b"4";
|
||||
let decoded = decode_packet_ws(input).unwrap();
|
||||
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||
assert_eq!(decoded.data, PacketData::Empty);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::engine::packet::Packet;
|
||||
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||
|
||||
pub struct HeartbeatManager {
|
||||
store: SessionStore,
|
||||
ping_interval: u64,
|
||||
ping_timeout: u64,
|
||||
}
|
||||
|
||||
impl HeartbeatManager {
|
||||
pub fn new(store: SessionStore, ping_interval: u64, ping_timeout: u64) -> Self {
|
||||
Self {
|
||||
store,
|
||||
ping_interval,
|
||||
ping_timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
|
||||
let this = self.clone();
|
||||
tokio::spawn(async move {
|
||||
this.run().await;
|
||||
})
|
||||
}
|
||||
|
||||
async fn run(&self) {
|
||||
let mut interval = tokio::time::interval(Duration::from_millis(self.ping_interval));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
self.check_sessions().await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_sessions(&self) {
|
||||
let now = std::time::Instant::now();
|
||||
let timeout_duration = Duration::from_millis(self.ping_interval + self.ping_timeout);
|
||||
|
||||
let mut to_remove = Vec::new();
|
||||
|
||||
for entry in self.store.sessions.iter() {
|
||||
let sid = entry.key().clone();
|
||||
let session = entry.value().clone();
|
||||
|
||||
let (state, last_ping, transport) = {
|
||||
let s = session.read().await;
|
||||
(s.state, s.last_ping, s.transport)
|
||||
};
|
||||
|
||||
if state == SessionState::Closed {
|
||||
to_remove.push(sid);
|
||||
continue;
|
||||
}
|
||||
|
||||
if now.duration_since(last_ping) > timeout_duration {
|
||||
tracing::warn!("Session {} timed out", sid);
|
||||
to_remove.push(sid);
|
||||
continue;
|
||||
}
|
||||
|
||||
// For polling sessions: buffer a ping packet for the next GET request.
|
||||
// WS/WT sessions rely on their own dedicated ping tasks; the timeout
|
||||
// check above already serves as the safety net for all transports.
|
||||
if state == SessionState::Open && transport == TransportType::Polling {
|
||||
let mut s = session.write().await;
|
||||
s.buffer_packet(Packet::ping(""));
|
||||
}
|
||||
}
|
||||
|
||||
for sid in to_remove {
|
||||
self.store.remove(&sid);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
pub mod codec;
|
||||
pub mod heartbeat;
|
||||
pub mod packet;
|
||||
pub mod polling;
|
||||
pub mod server;
|
||||
pub mod session;
|
||||
pub mod upgrade;
|
||||
pub mod websocket;
|
||||
pub mod webtransport;
|
||||
|
||||
pub use packet::{HandshakeData, Packet, PacketData, PacketType};
|
||||
pub use server::{EngineConfig, EngineServer};
|
||||
pub use session::{SessionState, SessionStore, TransportType};
|
||||
@@ -0,0 +1,151 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum PacketType {
|
||||
Open = 0,
|
||||
Close = 1,
|
||||
Ping = 2,
|
||||
Pong = 3,
|
||||
Message = 4,
|
||||
Upgrade = 5,
|
||||
Noop = 6,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for PacketType {
|
||||
type Error = PacketError;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0 => Ok(Self::Open),
|
||||
1 => Ok(Self::Close),
|
||||
2 => Ok(Self::Ping),
|
||||
3 => Ok(Self::Pong),
|
||||
4 => Ok(Self::Message),
|
||||
5 => Ok(Self::Upgrade),
|
||||
6 => Ok(Self::Noop),
|
||||
_ => Err(PacketError::InvalidType(value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<char> for PacketType {
|
||||
type Error = PacketError;
|
||||
|
||||
fn try_from(value: char) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
'0' => Ok(Self::Open),
|
||||
'1' => Ok(Self::Close),
|
||||
'2' => Ok(Self::Ping),
|
||||
'3' => Ok(Self::Pong),
|
||||
'4' => Ok(Self::Message),
|
||||
'5' => Ok(Self::Upgrade),
|
||||
'6' => Ok(Self::Noop),
|
||||
_ => Err(PacketError::InvalidTypeChar(value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PacketData {
|
||||
Text(String),
|
||||
Binary(Vec<u8>),
|
||||
Empty,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Packet {
|
||||
pub packet_type: PacketType,
|
||||
pub data: PacketData,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn open(handshake: &HandshakeData) -> Self {
|
||||
let data = serde_json::to_string(handshake)
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to serialize handshake data: {}", e);
|
||||
"{}".to_string()
|
||||
});
|
||||
Self {
|
||||
packet_type: PacketType::Open,
|
||||
data: PacketData::Text(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close() -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Close,
|
||||
data: PacketData::Empty,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ping(data: impl Into<String>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Ping,
|
||||
data: PacketData::Text(data.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pong(data: impl Into<String>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Pong,
|
||||
data: PacketData::Text(data.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn message_text(data: impl Into<String>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Message,
|
||||
data: PacketData::Text(data.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn message_binary(data: Vec<u8>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Message,
|
||||
data: PacketData::Binary(data),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn upgrade() -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Upgrade,
|
||||
data: PacketData::Empty,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn noop() -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Noop,
|
||||
data: PacketData::Empty,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HandshakeData {
|
||||
pub sid: String,
|
||||
pub upgrades: Vec<String>,
|
||||
#[serde(rename = "pingInterval")]
|
||||
pub ping_interval: u64,
|
||||
#[serde(rename = "pingTimeout")]
|
||||
pub ping_timeout: u64,
|
||||
#[serde(rename = "maxPayload")]
|
||||
pub max_payload: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PacketError {
|
||||
#[error("invalid packet type: {0}")]
|
||||
InvalidType(u8),
|
||||
#[error("invalid packet type char: {0}")]
|
||||
InvalidTypeChar(char),
|
||||
#[error("empty packet")]
|
||||
Empty,
|
||||
#[error("invalid base64: {0}")]
|
||||
InvalidBase64(#[from] base64::DecodeError),
|
||||
#[error("invalid utf8: {0}")]
|
||||
InvalidUtf8(#[from] std::string::FromUtf8Error),
|
||||
#[error("serialization error: {0}")]
|
||||
Serialization(String),
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::{web, HttpRequest, HttpResponse};
|
||||
|
||||
use crate::engine::codec;
|
||||
use crate::engine::packet::{Packet, PacketType};
|
||||
use crate::engine::server::EngineConfig;
|
||||
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct PollingQuery {
|
||||
#[serde(rename = "EIO")]
|
||||
pub eio: Option<String>,
|
||||
pub transport: Option<String>,
|
||||
pub sid: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn polling_get(
|
||||
_req: HttpRequest,
|
||||
query: web::Query<PollingQuery>,
|
||||
store: web::Data<SessionStore>,
|
||||
config: web::Data<EngineConfig>,
|
||||
_on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
|
||||
) -> HttpResponse {
|
||||
if query.eio.as_deref() != Some("4") {
|
||||
return HttpResponse::BadRequest().body("invalid EIO version");
|
||||
}
|
||||
|
||||
if query.transport.as_deref() != Some("polling") {
|
||||
return HttpResponse::BadRequest().body("invalid transport");
|
||||
}
|
||||
|
||||
let sid = match &query.sid {
|
||||
Some(sid) => sid.clone(),
|
||||
None => {
|
||||
return handle_handshake(&store, &config).await;
|
||||
}
|
||||
};
|
||||
|
||||
let session = match store.get(&sid) {
|
||||
Some(s) => s,
|
||||
None => return HttpResponse::BadRequest().body("unknown session"),
|
||||
};
|
||||
|
||||
// Check session state and take any buffered pending packets
|
||||
let notify = {
|
||||
let mut session_guard = session.write().await;
|
||||
|
||||
if session_guard.state == SessionState::Closed {
|
||||
return HttpResponse::BadRequest().body("session closed");
|
||||
}
|
||||
|
||||
let pending = session_guard.take_pending();
|
||||
if !pending.is_empty() {
|
||||
let payload = codec::encode_payload(&pending);
|
||||
return HttpResponse::Ok()
|
||||
.content_type("text/plain; charset=UTF-8")
|
||||
.body(payload);
|
||||
}
|
||||
|
||||
session_guard.notify.clone()
|
||||
};
|
||||
|
||||
let timeout = Duration::from_millis(config.ping_interval + config.ping_timeout);
|
||||
let _result = tokio::time::timeout(timeout, notify.notified()).await;
|
||||
|
||||
// Re-verify session still exists after wait (may have been removed by heartbeat)
|
||||
if store.get(&sid).is_none() {
|
||||
return HttpResponse::BadRequest().body("session closed");
|
||||
}
|
||||
|
||||
let mut session_guard = session.write().await;
|
||||
let packets = session_guard.take_pending();
|
||||
|
||||
if packets.is_empty() {
|
||||
let noop = codec::encode_packet(&Packet::noop());
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/plain; charset=UTF-8")
|
||||
.body(noop)
|
||||
} else {
|
||||
let payload = codec::encode_payload(&packets);
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/plain; charset=UTF-8")
|
||||
.body(payload)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn polling_post(
|
||||
_req: HttpRequest,
|
||||
body: web::Bytes,
|
||||
query: web::Query<PollingQuery>,
|
||||
store: web::Data<SessionStore>,
|
||||
config: web::Data<EngineConfig>,
|
||||
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
|
||||
) -> HttpResponse {
|
||||
if query.eio.as_deref() != Some("4") {
|
||||
return HttpResponse::BadRequest().body("invalid EIO version");
|
||||
}
|
||||
|
||||
if query.transport.as_deref() != Some("polling") {
|
||||
return HttpResponse::BadRequest().body("invalid transport");
|
||||
}
|
||||
|
||||
// Check payload size BEFORE attempting to decode
|
||||
if body.len() > config.max_payload {
|
||||
return HttpResponse::PayloadTooLarge().body("payload too large");
|
||||
}
|
||||
|
||||
let sid = match &query.sid {
|
||||
Some(sid) => sid,
|
||||
None => return HttpResponse::BadRequest().body("missing sid"),
|
||||
};
|
||||
|
||||
let session = match store.get(sid) {
|
||||
Some(s) => s,
|
||||
None => return HttpResponse::BadRequest().body("unknown session"),
|
||||
};
|
||||
|
||||
let body_str = match std::str::from_utf8(&body) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return HttpResponse::BadRequest().body("invalid utf8"),
|
||||
};
|
||||
|
||||
let packets = match codec::decode_payload(body_str) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return HttpResponse::BadRequest().body("invalid payload"),
|
||||
};
|
||||
|
||||
let mut session_guard = session.write().await;
|
||||
|
||||
for packet in packets {
|
||||
match packet.packet_type {
|
||||
PacketType::Pong => {
|
||||
session_guard.update_ping();
|
||||
}
|
||||
PacketType::Message => {
|
||||
let on_msg = on_message.get_ref().clone();
|
||||
let sid_owned = sid.clone();
|
||||
tokio::spawn(async move {
|
||||
on_msg(sid_owned, packet);
|
||||
});
|
||||
}
|
||||
PacketType::Close => {
|
||||
session_guard.set_state(SessionState::Closed);
|
||||
store.remove(sid);
|
||||
return HttpResponse::Ok().body("ok");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
HttpResponse::Ok().body("ok")
|
||||
}
|
||||
|
||||
async fn handle_handshake(store: &SessionStore, config: &EngineConfig) -> HttpResponse {
|
||||
let sid = crate::engine::session::generate_sid();
|
||||
|
||||
let handshake = crate::engine::packet::HandshakeData {
|
||||
sid: sid.clone(),
|
||||
upgrades: vec!["websocket".to_string()],
|
||||
ping_interval: config.ping_interval,
|
||||
ping_timeout: config.ping_timeout,
|
||||
max_payload: config.max_payload,
|
||||
};
|
||||
|
||||
let _rx = store.create(sid.clone(), TransportType::Polling);
|
||||
|
||||
if let Some(session) = store.get(&sid) {
|
||||
let mut session = session.write().await;
|
||||
session.set_state(SessionState::Open);
|
||||
}
|
||||
|
||||
let packet = Packet::open(&handshake);
|
||||
let payload = codec::encode_packet(&packet);
|
||||
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/plain; charset=UTF-8")
|
||||
.body(payload)
|
||||
}
|
||||
|
||||
pub fn configure_polling(cfg: &mut web::ServiceConfig) {
|
||||
cfg.route("/engine.io/", web::get().to(polling_get))
|
||||
.route("/engine.io/", web::post().to(polling_post));
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_web::{web, App, HttpServer};
|
||||
|
||||
use crate::engine::heartbeat::HeartbeatManager;
|
||||
use crate::engine::packet::Packet;
|
||||
use crate::engine::session::SessionStore;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EngineConfig {
|
||||
pub ping_interval: u64,
|
||||
pub ping_timeout: u64,
|
||||
pub max_payload: usize,
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
impl Default for EngineConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
ping_interval: 25000,
|
||||
ping_timeout: 20000,
|
||||
max_payload: 1_000_000,
|
||||
path: "/engine.io/".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EngineServer {
|
||||
pub config: EngineConfig,
|
||||
pub store: SessionStore,
|
||||
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
||||
}
|
||||
|
||||
impl EngineServer {
|
||||
pub fn new(
|
||||
config: EngineConfig,
|
||||
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
store: SessionStore::new(),
|
||||
on_message: Arc::new(on_message),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_store(
|
||||
config: EngineConfig,
|
||||
store: SessionStore,
|
||||
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
store,
|
||||
on_message: Arc::new(on_message),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
|
||||
let store = self.store.clone();
|
||||
let config = self.config.clone();
|
||||
let on_message = self.on_message.clone();
|
||||
|
||||
// Start heartbeat manager to clean up stale sessions
|
||||
let heartbeat = Arc::new(HeartbeatManager::new(
|
||||
store.clone(),
|
||||
config.ping_interval,
|
||||
config.ping_timeout,
|
||||
));
|
||||
let heartbeat_handle = heartbeat.start();
|
||||
|
||||
tracing::info!("Engine.IO HTTP server listening on {}", addr);
|
||||
|
||||
let result = HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(web::Data::new(store.clone()))
|
||||
.app_data(web::Data::new(config.clone()))
|
||||
.app_data(web::Data::new(on_message.clone()))
|
||||
.route(
|
||||
"/engine.io/",
|
||||
web::get().to(crate::engine::polling::polling_get),
|
||||
)
|
||||
.route(
|
||||
"/engine.io/",
|
||||
web::post().to(crate::engine::polling::polling_post),
|
||||
)
|
||||
.route(
|
||||
"/engine.io/",
|
||||
web::get().to(crate::engine::websocket::websocket_handler),
|
||||
)
|
||||
})
|
||||
.bind(addr)?
|
||||
.run()
|
||||
.await;
|
||||
|
||||
heartbeat_handle.abort();
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn run_webtransport(
|
||||
&self,
|
||||
port: u16,
|
||||
cert_path: &str,
|
||||
key_path: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
crate::engine::webtransport::run_webtransport_server(
|
||||
port,
|
||||
cert_path,
|
||||
key_path,
|
||||
self.store.clone(),
|
||||
self.config.clone(),
|
||||
self.on_message.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
|
||||
use crate::engine::packet::Packet;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TransportType {
|
||||
Polling,
|
||||
WebSocket,
|
||||
WebTransport,
|
||||
}
|
||||
|
||||
impl TransportType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Polling => "polling",
|
||||
Self::WebSocket => "websocket",
|
||||
Self::WebTransport => "webtransport",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SessionState {
|
||||
Connecting,
|
||||
Open,
|
||||
Upgrading,
|
||||
Closing,
|
||||
Closed,
|
||||
}
|
||||
|
||||
pub struct Session {
|
||||
pub sid: String,
|
||||
pub transport: TransportType,
|
||||
pub state: SessionState,
|
||||
pub created_at: Instant,
|
||||
pub last_ping: Instant,
|
||||
pub tx: mpsc::Sender<Packet>,
|
||||
pub pending_packets: Vec<Packet>,
|
||||
pub notify: Arc<Notify>,
|
||||
pub upgrade_tx: Option<mpsc::Sender<Packet>>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(sid: String, transport: TransportType) -> (Self, mpsc::Receiver<Packet>) {
|
||||
let (tx, rx) = mpsc::channel(256);
|
||||
let session = Self {
|
||||
sid,
|
||||
transport,
|
||||
state: SessionState::Connecting,
|
||||
created_at: Instant::now(),
|
||||
last_ping: Instant::now(),
|
||||
tx,
|
||||
pending_packets: Vec::new(),
|
||||
notify: Arc::new(Notify::new()),
|
||||
upgrade_tx: None,
|
||||
};
|
||||
(session, rx)
|
||||
}
|
||||
|
||||
/// Send a packet through the mpsc channel (for WS/WT transport consumption).
|
||||
pub fn send_packet(&self, packet: Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||
self.tx.try_send(packet)
|
||||
}
|
||||
|
||||
/// Push a packet using the appropriate mechanism for the current transport.
|
||||
/// Polling: buffer in pending_packets + notify waiting GET request.
|
||||
/// WS/WT: try mpsc channel first; if full, buffer as fallback + notify.
|
||||
pub fn push_packet(&mut self, packet: Packet) {
|
||||
if self.transport == TransportType::Polling {
|
||||
self.pending_packets.push(packet);
|
||||
self.notify.notify_one();
|
||||
} else {
|
||||
if self.tx.try_send(packet.clone()).is_err() {
|
||||
self.pending_packets.push(packet);
|
||||
self.notify.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffer a packet in pending_packets and notify any waiting polling request.
|
||||
pub fn buffer_packet(&mut self, packet: Packet) {
|
||||
self.pending_packets.push(packet);
|
||||
self.notify.notify_one();
|
||||
}
|
||||
|
||||
pub fn take_pending(&mut self) -> Vec<Packet> {
|
||||
std::mem::take(&mut self.pending_packets)
|
||||
}
|
||||
|
||||
pub fn update_ping(&mut self) {
|
||||
self.last_ping = Instant::now();
|
||||
}
|
||||
|
||||
pub fn set_transport(&mut self, transport: TransportType) {
|
||||
self.transport = transport;
|
||||
}
|
||||
|
||||
pub fn set_state(&mut self, state: SessionState) {
|
||||
self.state = state;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionStore {
|
||||
pub sessions: Arc<DashMap<String, Arc<tokio::sync::RwLock<Session>>>>,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session. Returns the mpsc receiver for transport-level packet consumption.
|
||||
/// Logs a warning if the SID collides with an existing session (extremely unlikely with crypto RNG).
|
||||
pub fn create(&self, sid: String, transport: TransportType) -> mpsc::Receiver<Packet> {
|
||||
let (session, rx) = Session::new(sid.clone(), transport);
|
||||
let old = self
|
||||
.sessions
|
||||
.insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session)));
|
||||
if old.is_some() {
|
||||
tracing::warn!("Session ID collision for SID {}, replacing existing session", sid);
|
||||
}
|
||||
rx
|
||||
}
|
||||
|
||||
pub fn get(&self, sid: &str) -> Option<Arc<tokio::sync::RwLock<Session>>> {
|
||||
self.sessions.get(sid).map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
pub fn remove(&self, sid: &str) {
|
||||
self.sessions.remove(sid);
|
||||
}
|
||||
|
||||
pub fn exists(&self, sid: &str) -> bool {
|
||||
self.sessions.contains_key(sid)
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.sessions.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.sessions.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random session ID using a cryptographically secure RNG.
|
||||
/// rand 0.9's default RNG (ChaCha8Rng seeded from OsRng) is crypto-secure.
|
||||
pub fn generate_sid() -> String {
|
||||
use rand::Rng;
|
||||
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-";
|
||||
let mut rng = rand::rng();
|
||||
(0..20)
|
||||
.map(|_| {
|
||||
let idx = rng.random_range(0..CHARSET.len());
|
||||
CHARSET[idx] as char
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
use crate::engine::packet::Packet;
|
||||
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||
|
||||
pub async fn handle_upgrade_probe(
|
||||
store: &SessionStore,
|
||||
sid: &str,
|
||||
) -> Result<Packet, UpgradeError> {
|
||||
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
|
||||
let mut session = session.write().await;
|
||||
|
||||
if session.state == SessionState::Closed {
|
||||
return Err(UpgradeError::SessionClosed);
|
||||
}
|
||||
|
||||
session.set_state(SessionState::Upgrading);
|
||||
Ok(Packet::pong("probe"))
|
||||
}
|
||||
|
||||
pub async fn handle_upgrade_complete(
|
||||
store: &SessionStore,
|
||||
sid: &str,
|
||||
new_transport: TransportType,
|
||||
) -> Result<(), UpgradeError> {
|
||||
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
|
||||
let mut session = session.write().await;
|
||||
|
||||
session.set_transport(new_transport);
|
||||
session.set_state(SessionState::Open);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_noop_to_pending_polling(
|
||||
store: &SessionStore,
|
||||
sid: &str,
|
||||
) -> Result<(), UpgradeError> {
|
||||
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
|
||||
let mut session = session.write().await;
|
||||
|
||||
session.buffer_packet(Packet::noop());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum UpgradeError {
|
||||
#[error("session not found")]
|
||||
SessionNotFound,
|
||||
#[error("session closed")]
|
||||
SessionClosed,
|
||||
#[error("invalid state for upgrade")]
|
||||
InvalidState,
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_web::{web, HttpRequest, HttpResponse};
|
||||
use actix_ws::Message;
|
||||
|
||||
use crate::engine::codec;
|
||||
use crate::engine::packet::{Packet, PacketData, PacketType};
|
||||
use crate::engine::server::EngineConfig;
|
||||
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct WsQuery {
|
||||
#[serde(rename = "EIO")]
|
||||
pub eio: Option<String>,
|
||||
pub transport: Option<String>,
|
||||
pub sid: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn websocket_handler(
|
||||
req: HttpRequest,
|
||||
body: web::Payload,
|
||||
query: web::Query<WsQuery>,
|
||||
store: web::Data<SessionStore>,
|
||||
config: web::Data<EngineConfig>,
|
||||
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
if query.eio.as_deref() != Some("4") {
|
||||
return Ok(HttpResponse::BadRequest().body("invalid EIO version"));
|
||||
}
|
||||
|
||||
if query.transport.as_deref() != Some("websocket") {
|
||||
return Ok(HttpResponse::BadRequest().body("invalid transport"));
|
||||
}
|
||||
|
||||
let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, body)?;
|
||||
|
||||
let sid = query.sid.clone();
|
||||
|
||||
let is_upgrade = sid.as_ref().map(|s| store.exists(s)).unwrap_or(false);
|
||||
|
||||
// Create or reuse session, obtaining the mpsc receiver for the forwarding task
|
||||
let (session_sid, mut session_rx) = if let Some(ref sid) = sid {
|
||||
if is_upgrade {
|
||||
// Upgrade: session already exists, replace its channel and drain pending packets
|
||||
let session_arc = store.get(sid).unwrap();
|
||||
let (new_tx, new_rx) = tokio::sync::mpsc::channel(256);
|
||||
{
|
||||
let mut s = session_arc.write().await;
|
||||
// Swap tx atomically: old_tx will be dropped, closing its channel.
|
||||
// Any packets in the old rx are consumed by the old send_handle,
|
||||
// which then exits when it sees the channel close.
|
||||
// Drain pending_packets (from polling buffering) into new channel.
|
||||
let pending = s.take_pending();
|
||||
for packet in pending {
|
||||
let _ = new_tx.try_send(packet);
|
||||
}
|
||||
s.tx = new_tx;
|
||||
s.set_transport(TransportType::WebSocket);
|
||||
}
|
||||
(sid.clone(), new_rx)
|
||||
} else {
|
||||
// Reconnect with known SID: create new session
|
||||
let rx = store.create(sid.clone(), TransportType::WebSocket);
|
||||
if let Some(s) = store.get(sid) {
|
||||
let mut s = s.write().await;
|
||||
s.set_state(SessionState::Open);
|
||||
}
|
||||
(sid.clone(), rx)
|
||||
}
|
||||
} else {
|
||||
// New connection: generate SID and create session
|
||||
let new_sid = crate::engine::session::generate_sid();
|
||||
let rx = store.create(new_sid.clone(), TransportType::WebSocket);
|
||||
if let Some(s) = store.get(&new_sid) {
|
||||
let mut s = s.write().await;
|
||||
s.set_state(SessionState::Open);
|
||||
}
|
||||
(new_sid, rx)
|
||||
};
|
||||
|
||||
let handshake = crate::engine::packet::HandshakeData {
|
||||
sid: session_sid.clone(),
|
||||
upgrades: vec![],
|
||||
ping_interval: config.ping_interval,
|
||||
ping_timeout: config.ping_timeout,
|
||||
max_payload: config.max_payload,
|
||||
};
|
||||
|
||||
let open_packet = Packet::open(&handshake);
|
||||
let open_msg = codec::encode_packet(&open_packet);
|
||||
if ws_session.text(open_msg).await.is_err() {
|
||||
tracing::warn!("Failed to send open packet to WebSocket session {}", session_sid);
|
||||
store.remove(&session_sid);
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
let store_clone = store.get_ref().clone();
|
||||
let on_message_clone = on_message.get_ref().clone();
|
||||
let sid_clone = session_sid.clone();
|
||||
let ws_session_clone = ws_session.clone();
|
||||
let max_payload = config.max_payload;
|
||||
|
||||
// Task 1: Forward engine session packets → WebSocket (reads from session mpsc rx)
|
||||
let sid_for_send = session_sid.clone();
|
||||
let store_for_send = store.get_ref().clone();
|
||||
let mut ws_for_send = ws_session.clone();
|
||||
let send_handle = actix_rt::spawn(async move {
|
||||
while let Some(packet) = session_rx.recv().await {
|
||||
let encoded = codec::encode_packet(&packet);
|
||||
if ws_for_send.text(encoded).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Session channel closed — clean up
|
||||
store_for_send.remove(&sid_for_send);
|
||||
});
|
||||
|
||||
// Task 2: Read incoming WebSocket messages → dispatch
|
||||
let recv_handle = actix_rt::spawn(async move {
|
||||
let mut ws_session = ws_session_clone;
|
||||
while let Some(Ok(msg)) = msg_stream.recv().await {
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
if let Ok(packet) = codec::decode_packet(&text) {
|
||||
match packet.packet_type {
|
||||
PacketType::Ping => {
|
||||
if let PacketData::Text(ref data) = packet.data {
|
||||
if data == "probe" {
|
||||
let pong = Packet::pong("probe");
|
||||
let pong_msg = codec::encode_packet(&pong);
|
||||
let _ = ws_session.text(pong_msg).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let pong = Packet::pong("");
|
||||
let pong_msg = codec::encode_packet(&pong);
|
||||
let _ = ws_session.text(pong_msg).await;
|
||||
}
|
||||
PacketType::Pong => {
|
||||
if let Some(s) = store_clone.get(&sid_clone) {
|
||||
let mut s = s.write().await;
|
||||
s.update_ping();
|
||||
}
|
||||
}
|
||||
PacketType::Upgrade => {
|
||||
if let Some(s) = store_clone.get(&sid_clone) {
|
||||
let mut s = s.write().await;
|
||||
s.set_transport(TransportType::WebSocket);
|
||||
s.set_state(SessionState::Open);
|
||||
}
|
||||
}
|
||||
PacketType::Message => {
|
||||
let on_msg = on_message_clone.clone();
|
||||
let sid = sid_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
on_msg(sid, packet);
|
||||
});
|
||||
}
|
||||
PacketType::Close => {
|
||||
if let Some(s) = store_clone.get(&sid_clone) {
|
||||
let mut s = s.write().await;
|
||||
s.set_state(SessionState::Closed);
|
||||
}
|
||||
store_clone.remove(&sid_clone);
|
||||
let _ = ws_session.close(None).await;
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Binary(bin) => {
|
||||
// Enforce max payload size for binary frames
|
||||
if bin.len() > max_payload {
|
||||
tracing::warn!(
|
||||
"Binary payload too large ({}) for session {}",
|
||||
bin.len(),
|
||||
sid_clone
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(packet) = codec::decode_packet_ws(&bin) {
|
||||
if packet.packet_type == PacketType::Message {
|
||||
let on_msg = on_message_clone.clone();
|
||||
let sid = sid_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
on_msg(sid, packet);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Close(_) => {
|
||||
if let Some(s) = store_clone.get(&sid_clone) {
|
||||
let mut s = s.write().await;
|
||||
s.set_state(SessionState::Closed);
|
||||
}
|
||||
store_clone.remove(&sid_clone);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Task 3: Heartbeat ping sender
|
||||
let sid_for_ping = session_sid.clone();
|
||||
let store_for_ping = store.get_ref().clone();
|
||||
let mut ws_for_ping = ws_session.clone();
|
||||
let ping_interval = config.ping_interval;
|
||||
let ping_handle = tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
if let Some(s) = store_for_ping.get(&sid_for_ping) {
|
||||
let session_state = {
|
||||
let s = s.read().await;
|
||||
s.state
|
||||
};
|
||||
|
||||
if session_state == SessionState::Closed {
|
||||
break;
|
||||
}
|
||||
|
||||
let ping = Packet::ping("");
|
||||
let ping_msg = codec::encode_packet(&ping);
|
||||
if ws_for_ping.text(ping_msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for any task to finish, then clean up
|
||||
actix_rt::spawn(async move {
|
||||
// actix_rt::spawn returns JoinHandle which is compatible with tokio::select!
|
||||
tokio::select! {
|
||||
_ = send_handle => {},
|
||||
_ = recv_handle => {},
|
||||
_ = ping_handle => {},
|
||||
}
|
||||
store.remove(&session_sid);
|
||||
});
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub fn configure_websocket(cfg: &mut web::ServiceConfig) {
|
||||
cfg.route("/engine.io/", web::get().to(websocket_handler));
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use wtransport::{Connection, Endpoint, ServerConfig, Identity};
|
||||
|
||||
use crate::engine::codec;
|
||||
use crate::engine::packet::{Packet, PacketType};
|
||||
use crate::engine::server::EngineConfig;
|
||||
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||
|
||||
pub async fn run_webtransport_server(
|
||||
port: u16,
|
||||
cert_path: &str,
|
||||
key_path: &str,
|
||||
store: SessionStore,
|
||||
config: EngineConfig,
|
||||
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let identity = Identity::load_pemfiles(cert_path, key_path).await?;
|
||||
|
||||
let server_config = ServerConfig::builder()
|
||||
.with_bind_default(port)
|
||||
.with_identity(identity)
|
||||
.build();
|
||||
|
||||
let server = Endpoint::server(server_config)?;
|
||||
|
||||
tracing::info!("WebTransport server listening on UDP port {}", port);
|
||||
|
||||
loop {
|
||||
let incoming = server.accept().await;
|
||||
|
||||
let store = store.clone();
|
||||
let config = config.clone();
|
||||
let on_message = on_message.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match handle_webtransport_session(incoming, store, config, on_message).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("WebTransport session error: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_webtransport_session(
|
||||
incoming: wtransport::endpoint::IncomingSession,
|
||||
store: SessionStore,
|
||||
config: EngineConfig,
|
||||
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let request = incoming.await?;
|
||||
let connection = request.accept().await?;
|
||||
|
||||
let sid = crate::engine::session::generate_sid();
|
||||
let mut rx = store.create(sid.clone(), TransportType::WebTransport);
|
||||
|
||||
if let Some(s) = store.get(&sid) {
|
||||
let mut s = s.write().await;
|
||||
s.set_state(SessionState::Open);
|
||||
}
|
||||
|
||||
let handshake = crate::engine::packet::HandshakeData {
|
||||
sid: sid.clone(),
|
||||
upgrades: vec![],
|
||||
ping_interval: config.ping_interval,
|
||||
ping_timeout: config.ping_timeout,
|
||||
max_payload: config.max_payload,
|
||||
};
|
||||
|
||||
let open_packet = Packet::open(&handshake);
|
||||
send_wt_packet(&connection, &open_packet).await?;
|
||||
|
||||
let store_clone = store.clone();
|
||||
let sid_clone = sid.clone();
|
||||
let on_message_clone = on_message.clone();
|
||||
let connection_recv = connection.clone();
|
||||
let max_payload = config.max_payload;
|
||||
|
||||
// Reuse buffer across recv iterations instead of allocating 65KB each time
|
||||
let recv_handle = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
match connection_recv.accept_bi().await {
|
||||
Ok((mut send, mut recv)) => {
|
||||
// Reset buffer length for the next read without deallocating
|
||||
buf.resize(65536, 0);
|
||||
match recv.read(&mut buf).await {
|
||||
Ok(Some(n)) => {
|
||||
if n > max_payload {
|
||||
tracing::warn!(
|
||||
"WebTransport payload too large ({}) for session {}",
|
||||
n,
|
||||
sid_clone
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) {
|
||||
match packet.packet_type {
|
||||
PacketType::Ping => {
|
||||
let pong = Packet::pong("");
|
||||
if send_wt_packet_on_stream(&mut send, &pong)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
PacketType::Pong => {
|
||||
if let Some(s) = store_clone.get(&sid_clone) {
|
||||
let mut s = s.write().await;
|
||||
s.update_ping();
|
||||
}
|
||||
}
|
||||
PacketType::Message => {
|
||||
let on_msg = on_message_clone.clone();
|
||||
let sid = sid_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
on_msg(sid, packet);
|
||||
});
|
||||
}
|
||||
PacketType::Close => {
|
||||
if let Some(s) = store_clone.get(&sid_clone) {
|
||||
let mut s = s.write().await;
|
||||
s.set_state(SessionState::Closed);
|
||||
}
|
||||
store_clone.remove(&sid_clone);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
|
||||
});
|
||||
|
||||
let connection_send = connection.clone();
|
||||
let send_handle = tokio::spawn(async move {
|
||||
while let Some(packet) = rx.recv().await {
|
||||
if send_wt_packet(&connection_send, &packet).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let connection_ping = connection.clone();
|
||||
let store_ping = store.clone();
|
||||
let sid_ping = sid.clone();
|
||||
let ping_interval = config.ping_interval;
|
||||
let ping_handle = tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
if let Some(s) = store_ping.get(&sid_ping) {
|
||||
let state = {
|
||||
let s = s.read().await;
|
||||
s.state
|
||||
};
|
||||
|
||||
if state == SessionState::Closed {
|
||||
break;
|
||||
}
|
||||
|
||||
let ping = Packet::ping("");
|
||||
if send_wt_packet(&connection_ping, &ping).await.is_err() {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = recv_handle => {},
|
||||
_ = send_handle => {},
|
||||
_ = ping_handle => {},
|
||||
}
|
||||
|
||||
store.remove(&sid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_wt_packet(
|
||||
connection: &Connection,
|
||||
packet: &Packet,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (mut send, _recv) = connection.open_bi().await?.await?;
|
||||
let encoded = codec::encode_packet_binary_ws(packet);
|
||||
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
|
||||
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
|
||||
|
||||
send.write_all(&header).await?;
|
||||
send.write_all(&encoded).await?;
|
||||
send.finish().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_wt_packet_on_stream(
|
||||
send: &mut wtransport::SendStream,
|
||||
packet: &Packet,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let encoded = codec::encode_packet_binary_ws(packet);
|
||||
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
|
||||
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
|
||||
|
||||
send.write_all(&header).await?;
|
||||
send.write_all(&encoded).await?;
|
||||
send.finish().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user