feat(auth): add authentication protocol definitions and build configuration

- Add TokenClaims message for JWT payload structure with user id, issuer, timestamps, and scopes
- Implement IssueTokenRequest/Response for creating access and refresh tokens with TTL support
- Create RefreshTokenRequest/Response for token rotation functionality
- Define RevokeTokenRequest/Response with support for single token or user-wide revocation
- Add VerifyTokenRequest/Response for validating JWT tokens with detailed claims information
- Implement signing key distribution system with GetSigningKeysRequest/Response
- Create TokenService gRPC service with IssueToken, RefreshToken, RevokeToken, VerifyToken, and GetSigningKeys methods
- Add build.rs configuration to compile proto files using tonic_prost_build
- Include channel, channel_settings, member, and permission protocol definitions for IM services
- Generate Rust code bindings through pb/core.rs and pb/im.rs modules
This commit is contained in:
zhenyi
2026-06-10 23:45:40 +08:00
commit 06e8ee96a5
43 changed files with 9671 additions and 0 deletions
+239
View File
@@ -0,0 +1,239 @@
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use crate::engine::packet::{Packet, PacketData, PacketError, PacketType};
const RECORD_SEPARATOR: char = '\x1e';
pub fn encode_packet(packet: &Packet) -> String {
let type_char = packet.packet_type as u8 + b'0';
let type_str = type_char as char;
match &packet.data {
PacketData::Text(s) => format!("{type_str}{s}"),
PacketData::Binary(b) => format!("{type_str}b{}", BASE64.encode(b)),
PacketData::Empty => type_str.to_string(),
}
}
pub fn encode_packet_binary_ws(packet: &Packet) -> Vec<u8> {
let type_byte = packet.packet_type as u8 + b'0';
match &packet.data {
PacketData::Text(s) => {
let mut buf = Vec::with_capacity(1 + s.len());
buf.push(type_byte);
buf.extend_from_slice(s.as_bytes());
buf
}
PacketData::Binary(b) => {
let mut buf = Vec::with_capacity(1 + b.len());
buf.push(type_byte);
buf.extend_from_slice(b);
buf
}
PacketData::Empty => vec![type_byte],
}
}
pub fn decode_packet(input: &str) -> Result<Packet, PacketError> {
let mut chars = input.chars();
let type_char = chars.next().ok_or(PacketError::Empty)?;
let packet_type = PacketType::try_from(type_char)?;
let rest: String = chars.collect();
if rest.is_empty() {
return Ok(Packet {
packet_type,
data: PacketData::Empty,
});
}
if let Some(b64_data) = rest.strip_prefix('b') {
let decoded = BASE64.decode(b64_data)?;
return Ok(Packet {
packet_type,
data: PacketData::Binary(decoded),
});
}
Ok(Packet {
packet_type,
data: PacketData::Text(rest),
})
}
/// Decode a WebSocket binary frame into a Packet.
/// Handles both text-encoded packets (UTF-8 payload after type byte)
/// and binary packets (raw binary payload after type byte).
pub fn decode_packet_ws(input: &[u8]) -> Result<Packet, PacketError> {
if input.is_empty() {
return Err(PacketError::Empty);
}
let type_byte = input[0];
let packet_type = PacketType::try_from(type_byte.wrapping_sub(b'0'))?;
let rest = &input[1..];
if rest.is_empty() {
return Ok(Packet {
packet_type,
data: PacketData::Empty,
});
}
// Try UTF-8 first; if it fails, treat as binary data
match String::from_utf8(rest.to_vec()) {
Ok(text) => Ok(Packet {
packet_type,
data: PacketData::Text(text),
}),
Err(_) => Ok(Packet {
packet_type,
data: PacketData::Binary(rest.to_vec()),
}),
}
}
pub fn encode_payload(packets: &[Packet]) -> String {
packets
.iter()
.map(encode_packet)
.collect::<Vec<_>>()
.join(&RECORD_SEPARATOR.to_string())
}
pub fn decode_payload(input: &str) -> Result<Vec<Packet>, PacketError> {
input
.split(RECORD_SEPARATOR)
.filter(|s| !s.is_empty())
.map(decode_packet)
.collect()
}
pub fn encode_webtransport_header(payload_len: usize, is_binary: bool) -> Vec<u8> {
let binary_bit: u8 = if is_binary { 0x80 } else { 0x00 };
if payload_len <= 125 {
vec![binary_bit | (payload_len as u8)]
} else if payload_len <= 65535 {
let mut header = vec![binary_bit | 126];
header.extend_from_slice(&(payload_len as u16).to_be_bytes());
header
} else {
let mut header = vec![binary_bit | 127];
header.extend_from_slice(&(payload_len as u64).to_be_bytes());
header
}
}
pub fn decode_webtransport_header(header: &[u8]) -> Option<(usize, bool)> {
if header.is_empty() {
return None;
}
let first = header[0];
let is_binary = (first & 0x80) != 0;
let len_indicator = first & 0x7f;
if len_indicator <= 125 {
Some((len_indicator as usize, is_binary))
} else if len_indicator == 126 {
if header.len() < 3 {
return None;
}
let len = u16::from_be_bytes([header[1], header[2]]) as usize;
Some((len, is_binary))
} else {
if header.len() < 9 {
return None;
}
let len = u64::from_be_bytes([
header[1], header[2], header[3], header[4], header[5], header[6], header[7], header[8],
]) as usize;
Some((len, is_binary))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encode_decode_text_packet() {
let packet = Packet::message_text("hello");
let encoded = encode_packet(&packet);
assert_eq!(encoded, "4hello");
let decoded = decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
}
#[test]
fn test_encode_decode_binary_packet() {
let packet = Packet::message_binary(vec![1, 2, 3, 4]);
let encoded = encode_packet(&packet);
assert_eq!(encoded, "4bAQIDBA==");
let decoded = decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Binary(vec![1, 2, 3, 4]));
}
#[test]
fn test_encode_decode_payload() {
let packets = vec![
Packet::message_text("hello"),
Packet::ping(""),
Packet::message_text("world"),
];
let encoded = encode_payload(&packets);
assert_eq!(encoded, "4hello\x1e2\x1e4world");
let decoded = decode_payload(&encoded).unwrap();
assert_eq!(decoded.len(), 3);
assert_eq!(decoded[0].packet_type, PacketType::Message);
assert_eq!(decoded[1].packet_type, PacketType::Ping);
assert_eq!(decoded[2].packet_type, PacketType::Message);
}
#[test]
fn test_webtransport_header() {
let header = encode_webtransport_header(6, false);
assert_eq!(header, vec![0x06]);
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
assert_eq!(len, 6);
assert!(!is_binary);
let header = encode_webtransport_header(200, true);
assert_eq!(header.len(), 3);
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
assert_eq!(len, 200);
assert!(is_binary);
}
#[test]
fn test_decode_packet_ws_text() {
let input = b"4hello";
let decoded = decode_packet_ws(input).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
}
#[test]
fn test_decode_packet_ws_binary() {
// Type byte 4 (Message) + raw binary payload (non-UTF-8)
let input: Vec<u8> = vec![b'4', 0x80, 0xFF, 0x00, 0x01];
let decoded = decode_packet_ws(&input).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01]));
}
#[test]
fn test_decode_packet_ws_empty() {
let input = b"4";
let decoded = decode_packet_ws(input).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Empty);
}
}
+77
View File
@@ -0,0 +1,77 @@
use std::sync::Arc;
use std::time::Duration;
use crate::engine::packet::Packet;
use crate::engine::session::{SessionState, SessionStore, TransportType};
pub struct HeartbeatManager {
store: SessionStore,
ping_interval: u64,
ping_timeout: u64,
}
impl HeartbeatManager {
pub fn new(store: SessionStore, ping_interval: u64, ping_timeout: u64) -> Self {
Self {
store,
ping_interval,
ping_timeout,
}
}
pub fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
let this = self.clone();
tokio::spawn(async move {
this.run().await;
})
}
async fn run(&self) {
let mut interval = tokio::time::interval(Duration::from_millis(self.ping_interval));
loop {
interval.tick().await;
self.check_sessions().await;
}
}
async fn check_sessions(&self) {
let now = std::time::Instant::now();
let timeout_duration = Duration::from_millis(self.ping_interval + self.ping_timeout);
let mut to_remove = Vec::new();
for entry in self.store.sessions.iter() {
let sid = entry.key().clone();
let session = entry.value().clone();
let (state, last_ping, transport) = {
let s = session.read().await;
(s.state, s.last_ping, s.transport)
};
if state == SessionState::Closed {
to_remove.push(sid);
continue;
}
if now.duration_since(last_ping) > timeout_duration {
tracing::warn!("Session {} timed out", sid);
to_remove.push(sid);
continue;
}
// For polling sessions: buffer a ping packet for the next GET request.
// WS/WT sessions rely on their own dedicated ping tasks; the timeout
// check above already serves as the safety net for all transports.
if state == SessionState::Open && transport == TransportType::Polling {
let mut s = session.write().await;
s.buffer_packet(Packet::ping(""));
}
}
for sid in to_remove {
self.store.remove(&sid);
}
}
}
+13
View File
@@ -0,0 +1,13 @@
pub mod codec;
pub mod heartbeat;
pub mod packet;
pub mod polling;
pub mod server;
pub mod session;
pub mod upgrade;
pub mod websocket;
pub mod webtransport;
pub use packet::{HandshakeData, Packet, PacketData, PacketType};
pub use server::{EngineConfig, EngineServer};
pub use session::{SessionState, SessionStore, TransportType};
+151
View File
@@ -0,0 +1,151 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PacketType {
Open = 0,
Close = 1,
Ping = 2,
Pong = 3,
Message = 4,
Upgrade = 5,
Noop = 6,
}
impl TryFrom<u8> for PacketType {
type Error = PacketError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::Open),
1 => Ok(Self::Close),
2 => Ok(Self::Ping),
3 => Ok(Self::Pong),
4 => Ok(Self::Message),
5 => Ok(Self::Upgrade),
6 => Ok(Self::Noop),
_ => Err(PacketError::InvalidType(value)),
}
}
}
impl TryFrom<char> for PacketType {
type Error = PacketError;
fn try_from(value: char) -> Result<Self, Self::Error> {
match value {
'0' => Ok(Self::Open),
'1' => Ok(Self::Close),
'2' => Ok(Self::Ping),
'3' => Ok(Self::Pong),
'4' => Ok(Self::Message),
'5' => Ok(Self::Upgrade),
'6' => Ok(Self::Noop),
_ => Err(PacketError::InvalidTypeChar(value)),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PacketData {
Text(String),
Binary(Vec<u8>),
Empty,
}
#[derive(Debug, Clone)]
pub struct Packet {
pub packet_type: PacketType,
pub data: PacketData,
}
impl Packet {
pub fn open(handshake: &HandshakeData) -> Self {
let data = serde_json::to_string(handshake)
.unwrap_or_else(|e| {
tracing::error!("Failed to serialize handshake data: {}", e);
"{}".to_string()
});
Self {
packet_type: PacketType::Open,
data: PacketData::Text(data),
}
}
pub fn close() -> Self {
Self {
packet_type: PacketType::Close,
data: PacketData::Empty,
}
}
pub fn ping(data: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Ping,
data: PacketData::Text(data.into()),
}
}
pub fn pong(data: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Pong,
data: PacketData::Text(data.into()),
}
}
pub fn message_text(data: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Message,
data: PacketData::Text(data.into()),
}
}
pub fn message_binary(data: Vec<u8>) -> Self {
Self {
packet_type: PacketType::Message,
data: PacketData::Binary(data),
}
}
pub fn upgrade() -> Self {
Self {
packet_type: PacketType::Upgrade,
data: PacketData::Empty,
}
}
pub fn noop() -> Self {
Self {
packet_type: PacketType::Noop,
data: PacketData::Empty,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeData {
pub sid: String,
pub upgrades: Vec<String>,
#[serde(rename = "pingInterval")]
pub ping_interval: u64,
#[serde(rename = "pingTimeout")]
pub ping_timeout: u64,
#[serde(rename = "maxPayload")]
pub max_payload: usize,
}
#[derive(Debug, thiserror::Error)]
pub enum PacketError {
#[error("invalid packet type: {0}")]
InvalidType(u8),
#[error("invalid packet type char: {0}")]
InvalidTypeChar(char),
#[error("empty packet")]
Empty,
#[error("invalid base64: {0}")]
InvalidBase64(#[from] base64::DecodeError),
#[error("invalid utf8: {0}")]
InvalidUtf8(#[from] std::string::FromUtf8Error),
#[error("serialization error: {0}")]
Serialization(String),
}
+185
View File
@@ -0,0 +1,185 @@
use std::sync::Arc;
use std::time::Duration;
use actix_web::{web, HttpRequest, HttpResponse};
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
#[derive(Debug, serde::Deserialize)]
pub struct PollingQuery {
#[serde(rename = "EIO")]
pub eio: Option<String>,
pub transport: Option<String>,
pub sid: Option<String>,
}
pub async fn polling_get(
_req: HttpRequest,
query: web::Query<PollingQuery>,
store: web::Data<SessionStore>,
config: web::Data<EngineConfig>,
_on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> HttpResponse {
if query.eio.as_deref() != Some("4") {
return HttpResponse::BadRequest().body("invalid EIO version");
}
if query.transport.as_deref() != Some("polling") {
return HttpResponse::BadRequest().body("invalid transport");
}
let sid = match &query.sid {
Some(sid) => sid.clone(),
None => {
return handle_handshake(&store, &config).await;
}
};
let session = match store.get(&sid) {
Some(s) => s,
None => return HttpResponse::BadRequest().body("unknown session"),
};
// Check session state and take any buffered pending packets
let notify = {
let mut session_guard = session.write().await;
if session_guard.state == SessionState::Closed {
return HttpResponse::BadRequest().body("session closed");
}
let pending = session_guard.take_pending();
if !pending.is_empty() {
let payload = codec::encode_payload(&pending);
return HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(payload);
}
session_guard.notify.clone()
};
let timeout = Duration::from_millis(config.ping_interval + config.ping_timeout);
let _result = tokio::time::timeout(timeout, notify.notified()).await;
// Re-verify session still exists after wait (may have been removed by heartbeat)
if store.get(&sid).is_none() {
return HttpResponse::BadRequest().body("session closed");
}
let mut session_guard = session.write().await;
let packets = session_guard.take_pending();
if packets.is_empty() {
let noop = codec::encode_packet(&Packet::noop());
HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(noop)
} else {
let payload = codec::encode_payload(&packets);
HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(payload)
}
}
pub async fn polling_post(
_req: HttpRequest,
body: web::Bytes,
query: web::Query<PollingQuery>,
store: web::Data<SessionStore>,
config: web::Data<EngineConfig>,
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> HttpResponse {
if query.eio.as_deref() != Some("4") {
return HttpResponse::BadRequest().body("invalid EIO version");
}
if query.transport.as_deref() != Some("polling") {
return HttpResponse::BadRequest().body("invalid transport");
}
// Check payload size BEFORE attempting to decode
if body.len() > config.max_payload {
return HttpResponse::PayloadTooLarge().body("payload too large");
}
let sid = match &query.sid {
Some(sid) => sid,
None => return HttpResponse::BadRequest().body("missing sid"),
};
let session = match store.get(sid) {
Some(s) => s,
None => return HttpResponse::BadRequest().body("unknown session"),
};
let body_str = match std::str::from_utf8(&body) {
Ok(s) => s,
Err(_) => return HttpResponse::BadRequest().body("invalid utf8"),
};
let packets = match codec::decode_payload(body_str) {
Ok(p) => p,
Err(_) => return HttpResponse::BadRequest().body("invalid payload"),
};
let mut session_guard = session.write().await;
for packet in packets {
match packet.packet_type {
PacketType::Pong => {
session_guard.update_ping();
}
PacketType::Message => {
let on_msg = on_message.get_ref().clone();
let sid_owned = sid.clone();
tokio::spawn(async move {
on_msg(sid_owned, packet);
});
}
PacketType::Close => {
session_guard.set_state(SessionState::Closed);
store.remove(sid);
return HttpResponse::Ok().body("ok");
}
_ => {}
}
}
HttpResponse::Ok().body("ok")
}
async fn handle_handshake(store: &SessionStore, config: &EngineConfig) -> HttpResponse {
let sid = crate::engine::session::generate_sid();
let handshake = crate::engine::packet::HandshakeData {
sid: sid.clone(),
upgrades: vec!["websocket".to_string()],
ping_interval: config.ping_interval,
ping_timeout: config.ping_timeout,
max_payload: config.max_payload,
};
let _rx = store.create(sid.clone(), TransportType::Polling);
if let Some(session) = store.get(&sid) {
let mut session = session.write().await;
session.set_state(SessionState::Open);
}
let packet = Packet::open(&handshake);
let payload = codec::encode_packet(&packet);
HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(payload)
}
pub fn configure_polling(cfg: &mut web::ServiceConfig) {
cfg.route("/engine.io/", web::get().to(polling_get))
.route("/engine.io/", web::post().to(polling_post));
}
+115
View File
@@ -0,0 +1,115 @@
use std::sync::Arc;
use actix_web::{web, App, HttpServer};
use crate::engine::heartbeat::HeartbeatManager;
use crate::engine::packet::Packet;
use crate::engine::session::SessionStore;
#[derive(Debug, Clone)]
pub struct EngineConfig {
pub ping_interval: u64,
pub ping_timeout: u64,
pub max_payload: usize,
pub path: String,
}
impl Default for EngineConfig {
fn default() -> Self {
Self {
ping_interval: 25000,
ping_timeout: 20000,
max_payload: 1_000_000,
path: "/engine.io/".to_string(),
}
}
}
pub struct EngineServer {
pub config: EngineConfig,
pub store: SessionStore,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
}
impl EngineServer {
pub fn new(
config: EngineConfig,
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
) -> Self {
Self {
config,
store: SessionStore::new(),
on_message: Arc::new(on_message),
}
}
pub fn with_store(
config: EngineConfig,
store: SessionStore,
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
) -> Self {
Self {
config,
store,
on_message: Arc::new(on_message),
}
}
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
let store = self.store.clone();
let config = self.config.clone();
let on_message = self.on_message.clone();
// Start heartbeat manager to clean up stale sessions
let heartbeat = Arc::new(HeartbeatManager::new(
store.clone(),
config.ping_interval,
config.ping_timeout,
));
let heartbeat_handle = heartbeat.start();
tracing::info!("Engine.IO HTTP server listening on {}", addr);
let result = HttpServer::new(move || {
App::new()
.app_data(web::Data::new(store.clone()))
.app_data(web::Data::new(config.clone()))
.app_data(web::Data::new(on_message.clone()))
.route(
"/engine.io/",
web::get().to(crate::engine::polling::polling_get),
)
.route(
"/engine.io/",
web::post().to(crate::engine::polling::polling_post),
)
.route(
"/engine.io/",
web::get().to(crate::engine::websocket::websocket_handler),
)
})
.bind(addr)?
.run()
.await;
heartbeat_handle.abort();
result
}
pub async fn run_webtransport(
&self,
port: u16,
cert_path: &str,
key_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
crate::engine::webtransport::run_webtransport_server(
port,
cert_path,
key_path,
self.store.clone(),
self.config.clone(),
self.on_message.clone(),
)
.await
}
}
+171
View File
@@ -0,0 +1,171 @@
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tokio::sync::{mpsc, Notify};
use crate::engine::packet::Packet;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportType {
Polling,
WebSocket,
WebTransport,
}
impl TransportType {
pub fn as_str(&self) -> &'static str {
match self {
Self::Polling => "polling",
Self::WebSocket => "websocket",
Self::WebTransport => "webtransport",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
Connecting,
Open,
Upgrading,
Closing,
Closed,
}
pub struct Session {
pub sid: String,
pub transport: TransportType,
pub state: SessionState,
pub created_at: Instant,
pub last_ping: Instant,
pub tx: mpsc::Sender<Packet>,
pub pending_packets: Vec<Packet>,
pub notify: Arc<Notify>,
pub upgrade_tx: Option<mpsc::Sender<Packet>>,
}
impl Session {
pub fn new(sid: String, transport: TransportType) -> (Self, mpsc::Receiver<Packet>) {
let (tx, rx) = mpsc::channel(256);
let session = Self {
sid,
transport,
state: SessionState::Connecting,
created_at: Instant::now(),
last_ping: Instant::now(),
tx,
pending_packets: Vec::new(),
notify: Arc::new(Notify::new()),
upgrade_tx: None,
};
(session, rx)
}
/// Send a packet through the mpsc channel (for WS/WT transport consumption).
pub fn send_packet(&self, packet: Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
self.tx.try_send(packet)
}
/// Push a packet using the appropriate mechanism for the current transport.
/// Polling: buffer in pending_packets + notify waiting GET request.
/// WS/WT: try mpsc channel first; if full, buffer as fallback + notify.
pub fn push_packet(&mut self, packet: Packet) {
if self.transport == TransportType::Polling {
self.pending_packets.push(packet);
self.notify.notify_one();
} else {
if self.tx.try_send(packet.clone()).is_err() {
self.pending_packets.push(packet);
self.notify.notify_one();
}
}
}
/// Buffer a packet in pending_packets and notify any waiting polling request.
pub fn buffer_packet(&mut self, packet: Packet) {
self.pending_packets.push(packet);
self.notify.notify_one();
}
pub fn take_pending(&mut self) -> Vec<Packet> {
std::mem::take(&mut self.pending_packets)
}
pub fn update_ping(&mut self) {
self.last_ping = Instant::now();
}
pub fn set_transport(&mut self, transport: TransportType) {
self.transport = transport;
}
pub fn set_state(&mut self, state: SessionState) {
self.state = state;
}
}
#[derive(Clone)]
pub struct SessionStore {
pub sessions: Arc<DashMap<String, Arc<tokio::sync::RwLock<Session>>>>,
}
impl SessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(DashMap::new()),
}
}
/// Create a new session. Returns the mpsc receiver for transport-level packet consumption.
/// Logs a warning if the SID collides with an existing session (extremely unlikely with crypto RNG).
pub fn create(&self, sid: String, transport: TransportType) -> mpsc::Receiver<Packet> {
let (session, rx) = Session::new(sid.clone(), transport);
let old = self
.sessions
.insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session)));
if old.is_some() {
tracing::warn!("Session ID collision for SID {}, replacing existing session", sid);
}
rx
}
pub fn get(&self, sid: &str) -> Option<Arc<tokio::sync::RwLock<Session>>> {
self.sessions.get(sid).map(|r| r.value().clone())
}
pub fn remove(&self, sid: &str) {
self.sessions.remove(sid);
}
pub fn exists(&self, sid: &str) -> bool {
self.sessions.contains_key(sid)
}
pub fn len(&self) -> usize {
self.sessions.len()
}
pub fn is_empty(&self) -> bool {
self.sessions.is_empty()
}
}
impl Default for SessionStore {
fn default() -> Self {
Self::new()
}
}
/// Generate a random session ID using a cryptographically secure RNG.
/// rand 0.9's default RNG (ChaCha8Rng seeded from OsRng) is crypto-secure.
pub fn generate_sid() -> String {
use rand::Rng;
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-";
let mut rng = rand::rng();
(0..20)
.map(|_| {
let idx = rng.random_range(0..CHARSET.len());
CHARSET[idx] as char
})
.collect()
}
+53
View File
@@ -0,0 +1,53 @@
use crate::engine::packet::Packet;
use crate::engine::session::{SessionState, SessionStore, TransportType};
pub async fn handle_upgrade_probe(
store: &SessionStore,
sid: &str,
) -> Result<Packet, UpgradeError> {
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
let mut session = session.write().await;
if session.state == SessionState::Closed {
return Err(UpgradeError::SessionClosed);
}
session.set_state(SessionState::Upgrading);
Ok(Packet::pong("probe"))
}
pub async fn handle_upgrade_complete(
store: &SessionStore,
sid: &str,
new_transport: TransportType,
) -> Result<(), UpgradeError> {
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
let mut session = session.write().await;
session.set_transport(new_transport);
session.set_state(SessionState::Open);
Ok(())
}
pub async fn send_noop_to_pending_polling(
store: &SessionStore,
sid: &str,
) -> Result<(), UpgradeError> {
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
let mut session = session.write().await;
session.buffer_packet(Packet::noop());
Ok(())
}
#[derive(Debug, thiserror::Error)]
pub enum UpgradeError {
#[error("session not found")]
SessionNotFound,
#[error("session closed")]
SessionClosed,
#[error("invalid state for upgrade")]
InvalidState,
}
+254
View File
@@ -0,0 +1,254 @@
use std::sync::Arc;
use actix_web::{web, HttpRequest, HttpResponse};
use actix_ws::Message;
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketData, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
#[derive(Debug, serde::Deserialize)]
pub struct WsQuery {
#[serde(rename = "EIO")]
pub eio: Option<String>,
pub transport: Option<String>,
pub sid: Option<String>,
}
pub async fn websocket_handler(
req: HttpRequest,
body: web::Payload,
query: web::Query<WsQuery>,
store: web::Data<SessionStore>,
config: web::Data<EngineConfig>,
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> Result<HttpResponse, actix_web::Error> {
if query.eio.as_deref() != Some("4") {
return Ok(HttpResponse::BadRequest().body("invalid EIO version"));
}
if query.transport.as_deref() != Some("websocket") {
return Ok(HttpResponse::BadRequest().body("invalid transport"));
}
let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, body)?;
let sid = query.sid.clone();
let is_upgrade = sid.as_ref().map(|s| store.exists(s)).unwrap_or(false);
// Create or reuse session, obtaining the mpsc receiver for the forwarding task
let (session_sid, mut session_rx) = if let Some(ref sid) = sid {
if is_upgrade {
// Upgrade: session already exists, replace its channel and drain pending packets
let session_arc = store.get(sid).unwrap();
let (new_tx, new_rx) = tokio::sync::mpsc::channel(256);
{
let mut s = session_arc.write().await;
// Swap tx atomically: old_tx will be dropped, closing its channel.
// Any packets in the old rx are consumed by the old send_handle,
// which then exits when it sees the channel close.
// Drain pending_packets (from polling buffering) into new channel.
let pending = s.take_pending();
for packet in pending {
let _ = new_tx.try_send(packet);
}
s.tx = new_tx;
s.set_transport(TransportType::WebSocket);
}
(sid.clone(), new_rx)
} else {
// Reconnect with known SID: create new session
let rx = store.create(sid.clone(), TransportType::WebSocket);
if let Some(s) = store.get(sid) {
let mut s = s.write().await;
s.set_state(SessionState::Open);
}
(sid.clone(), rx)
}
} else {
// New connection: generate SID and create session
let new_sid = crate::engine::session::generate_sid();
let rx = store.create(new_sid.clone(), TransportType::WebSocket);
if let Some(s) = store.get(&new_sid) {
let mut s = s.write().await;
s.set_state(SessionState::Open);
}
(new_sid, rx)
};
let handshake = crate::engine::packet::HandshakeData {
sid: session_sid.clone(),
upgrades: vec![],
ping_interval: config.ping_interval,
ping_timeout: config.ping_timeout,
max_payload: config.max_payload,
};
let open_packet = Packet::open(&handshake);
let open_msg = codec::encode_packet(&open_packet);
if ws_session.text(open_msg).await.is_err() {
tracing::warn!("Failed to send open packet to WebSocket session {}", session_sid);
store.remove(&session_sid);
return Ok(response);
}
let store_clone = store.get_ref().clone();
let on_message_clone = on_message.get_ref().clone();
let sid_clone = session_sid.clone();
let ws_session_clone = ws_session.clone();
let max_payload = config.max_payload;
// Task 1: Forward engine session packets → WebSocket (reads from session mpsc rx)
let sid_for_send = session_sid.clone();
let store_for_send = store.get_ref().clone();
let mut ws_for_send = ws_session.clone();
let send_handle = actix_rt::spawn(async move {
while let Some(packet) = session_rx.recv().await {
let encoded = codec::encode_packet(&packet);
if ws_for_send.text(encoded).await.is_err() {
break;
}
}
// Session channel closed — clean up
store_for_send.remove(&sid_for_send);
});
// Task 2: Read incoming WebSocket messages → dispatch
let recv_handle = actix_rt::spawn(async move {
let mut ws_session = ws_session_clone;
while let Some(Ok(msg)) = msg_stream.recv().await {
match msg {
Message::Text(text) => {
if let Ok(packet) = codec::decode_packet(&text) {
match packet.packet_type {
PacketType::Ping => {
if let PacketData::Text(ref data) = packet.data {
if data == "probe" {
let pong = Packet::pong("probe");
let pong_msg = codec::encode_packet(&pong);
let _ = ws_session.text(pong_msg).await;
continue;
}
}
let pong = Packet::pong("");
let pong_msg = codec::encode_packet(&pong);
let _ = ws_session.text(pong_msg).await;
}
PacketType::Pong => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.update_ping();
}
}
PacketType::Upgrade => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_transport(TransportType::WebSocket);
s.set_state(SessionState::Open);
}
}
PacketType::Message => {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
PacketType::Close => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
let _ = ws_session.close(None).await;
break;
}
_ => {}
}
}
}
Message::Binary(bin) => {
// Enforce max payload size for binary frames
if bin.len() > max_payload {
tracing::warn!(
"Binary payload too large ({}) for session {}",
bin.len(),
sid_clone
);
continue;
}
if let Ok(packet) = codec::decode_packet_ws(&bin) {
if packet.packet_type == PacketType::Message {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
}
}
Message::Close(_) => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
break;
}
_ => {}
}
}
});
// Task 3: Heartbeat ping sender
let sid_for_ping = session_sid.clone();
let store_for_ping = store.get_ref().clone();
let mut ws_for_ping = ws_session.clone();
let ping_interval = config.ping_interval;
let ping_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
loop {
interval.tick().await;
if let Some(s) = store_for_ping.get(&sid_for_ping) {
let session_state = {
let s = s.read().await;
s.state
};
if session_state == SessionState::Closed {
break;
}
let ping = Packet::ping("");
let ping_msg = codec::encode_packet(&ping);
if ws_for_ping.text(ping_msg).await.is_err() {
break;
}
} else {
break;
}
}
});
// Wait for any task to finish, then clean up
actix_rt::spawn(async move {
// actix_rt::spawn returns JoinHandle which is compatible with tokio::select!
tokio::select! {
_ = send_handle => {},
_ = recv_handle => {},
_ = ping_handle => {},
}
store.remove(&session_sid);
});
Ok(response)
}
pub fn configure_websocket(cfg: &mut web::ServiceConfig) {
cfg.route("/engine.io/", web::get().to(websocket_handler));
}
+223
View File
@@ -0,0 +1,223 @@
use std::sync::Arc;
use wtransport::{Connection, Endpoint, ServerConfig, Identity};
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
pub async fn run_webtransport_server(
port: u16,
cert_path: &str,
key_path: &str,
store: SessionStore,
config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> Result<(), Box<dyn std::error::Error>> {
let identity = Identity::load_pemfiles(cert_path, key_path).await?;
let server_config = ServerConfig::builder()
.with_bind_default(port)
.with_identity(identity)
.build();
let server = Endpoint::server(server_config)?;
tracing::info!("WebTransport server listening on UDP port {}", port);
loop {
let incoming = server.accept().await;
let store = store.clone();
let config = config.clone();
let on_message = on_message.clone();
tokio::spawn(async move {
match handle_webtransport_session(incoming, store, config, on_message).await {
Ok(_) => {}
Err(e) => {
tracing::error!("WebTransport session error: {}", e);
}
}
});
}
}
async fn handle_webtransport_session(
incoming: wtransport::endpoint::IncomingSession,
store: SessionStore,
config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> Result<(), Box<dyn std::error::Error>> {
let request = incoming.await?;
let connection = request.accept().await?;
let sid = crate::engine::session::generate_sid();
let mut rx = store.create(sid.clone(), TransportType::WebTransport);
if let Some(s) = store.get(&sid) {
let mut s = s.write().await;
s.set_state(SessionState::Open);
}
let handshake = crate::engine::packet::HandshakeData {
sid: sid.clone(),
upgrades: vec![],
ping_interval: config.ping_interval,
ping_timeout: config.ping_timeout,
max_payload: config.max_payload,
};
let open_packet = Packet::open(&handshake);
send_wt_packet(&connection, &open_packet).await?;
let store_clone = store.clone();
let sid_clone = sid.clone();
let on_message_clone = on_message.clone();
let connection_recv = connection.clone();
let max_payload = config.max_payload;
// Reuse buffer across recv iterations instead of allocating 65KB each time
let recv_handle = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
match connection_recv.accept_bi().await {
Ok((mut send, mut recv)) => {
// Reset buffer length for the next read without deallocating
buf.resize(65536, 0);
match recv.read(&mut buf).await {
Ok(Some(n)) => {
if n > max_payload {
tracing::warn!(
"WebTransport payload too large ({}) for session {}",
n,
sid_clone
);
continue;
}
if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) {
match packet.packet_type {
PacketType::Ping => {
let pong = Packet::pong("");
if send_wt_packet_on_stream(&mut send, &pong)
.await
.is_err()
{
break;
}
}
PacketType::Pong => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.update_ping();
}
}
PacketType::Message => {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
PacketType::Close => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
break;
}
_ => {}
}
}
}
Ok(None) => break,
Err(_) => break,
}
}
Err(_) => break,
}
}
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
});
let connection_send = connection.clone();
let send_handle = tokio::spawn(async move {
while let Some(packet) = rx.recv().await {
if send_wt_packet(&connection_send, &packet).await.is_err() {
break;
}
}
});
let connection_ping = connection.clone();
let store_ping = store.clone();
let sid_ping = sid.clone();
let ping_interval = config.ping_interval;
let ping_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
loop {
interval.tick().await;
if let Some(s) = store_ping.get(&sid_ping) {
let state = {
let s = s.read().await;
s.state
};
if state == SessionState::Closed {
break;
}
let ping = Packet::ping("");
if send_wt_packet(&connection_ping, &ping).await.is_err() {
break;
}
} else {
break;
}
}
});
tokio::select! {
_ = recv_handle => {},
_ = send_handle => {},
_ = ping_handle => {},
}
store.remove(&sid);
Ok(())
}
async fn send_wt_packet(
connection: &Connection,
packet: &Packet,
) -> Result<(), Box<dyn std::error::Error>> {
let (mut send, _recv) = connection.open_bi().await?.await?;
let encoded = codec::encode_packet_binary_ws(packet);
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
send.write_all(&header).await?;
send.write_all(&encoded).await?;
send.finish().await?;
Ok(())
}
async fn send_wt_packet_on_stream(
send: &mut wtransport::SendStream,
packet: &Packet,
) -> Result<(), Box<dyn std::error::Error>> {
let encoded = codec::encode_packet_binary_ws(packet);
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
send.write_all(&header).await?;
send.write_all(&encoded).await?;
send.finish().await?;
Ok(())
}