feat(auth): add authentication protocol definitions and build configuration

- Add TokenClaims message for JWT payload structure with user id, issuer, timestamps, and scopes
- Implement IssueTokenRequest/Response for creating access and refresh tokens with TTL support
- Create RefreshTokenRequest/Response for token rotation functionality
- Define RevokeTokenRequest/Response with support for single token or user-wide revocation
- Add VerifyTokenRequest/Response for validating JWT tokens with detailed claims information
- Implement signing key distribution system with GetSigningKeysRequest/Response
- Create TokenService gRPC service with IssueToken, RefreshToken, RevokeToken, VerifyToken, and GetSigningKeys methods
- Add build.rs configuration to compile proto files using tonic_prost_build
- Include channel, channel_settings, member, and permission protocol definitions for IM services
- Generate Rust code bindings through pb/core.rs and pb/im.rs modules
This commit is contained in:
zhenyi
2026-06-10 23:45:40 +08:00
commit 06e8ee96a5
43 changed files with 9671 additions and 0 deletions
Generated
+3610
View File
File diff suppressed because it is too large Load Diff
+45
View File
@@ -0,0 +1,45 @@
[package]
name = "imks"
version = "0.1.0"
edition = "2024"
[lib]
path = "lib.rs"
name = "imks"
[[bin]]
path = "main.rs"
name = "imks"
[dependencies]
tonic = "0.14.6"
prost = "0.14.3"
prost-types = "0.14"
tonic-build = "0.14.6"
tonic-health = "0.14.6"
tonic-prost = "0.14.6"
tokio = { version = "1.52.3", features = ["full"] }
actix-web = { version = "4.13.0", features = [] }
actix-ws = { version = "0.4.0", features = [] }
actix-rt = "2"
serde = { version = "1", features = ["derive"] }
serde_json = { version = "1" }
base64 = "0.22"
rand = "0.9"
wtransport = "0.7"
dashmap = "6"
thiserror = "2"
async-trait = "0.1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
fred = { version = "10", features = ["subscriber-client"] }
async-nats = "0.38"
uuid = { version = "1", features = ["v4"] }
futures-util = "0.3"
[build-dependencies]
tonic-prost-build = "0.14.6"
walkdir = "2.5.0"
+24
View File
@@ -0,0 +1,24 @@
use std::path::Path;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let proto_dir = Path::new("proto/core");
let protos: Vec<_> = walkdir::WalkDir::new(proto_dir)
.into_iter()
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().is_some_and(|ext| ext == "proto"))
.map(|e| e.path().to_owned())
.collect();
for proto in &protos {
println!("cargo:rerun-if-changed={}", proto.display());
}
let includes = vec![proto_dir.to_path_buf()];
tonic_prost_build::configure()
.build_server(false)
.build_client(true)
.compile_protos(&protos, &includes)?;
Ok(())
}
+239
View File
@@ -0,0 +1,239 @@
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use crate::engine::packet::{Packet, PacketData, PacketError, PacketType};
const RECORD_SEPARATOR: char = '\x1e';
pub fn encode_packet(packet: &Packet) -> String {
let type_char = packet.packet_type as u8 + b'0';
let type_str = type_char as char;
match &packet.data {
PacketData::Text(s) => format!("{type_str}{s}"),
PacketData::Binary(b) => format!("{type_str}b{}", BASE64.encode(b)),
PacketData::Empty => type_str.to_string(),
}
}
pub fn encode_packet_binary_ws(packet: &Packet) -> Vec<u8> {
let type_byte = packet.packet_type as u8 + b'0';
match &packet.data {
PacketData::Text(s) => {
let mut buf = Vec::with_capacity(1 + s.len());
buf.push(type_byte);
buf.extend_from_slice(s.as_bytes());
buf
}
PacketData::Binary(b) => {
let mut buf = Vec::with_capacity(1 + b.len());
buf.push(type_byte);
buf.extend_from_slice(b);
buf
}
PacketData::Empty => vec![type_byte],
}
}
pub fn decode_packet(input: &str) -> Result<Packet, PacketError> {
let mut chars = input.chars();
let type_char = chars.next().ok_or(PacketError::Empty)?;
let packet_type = PacketType::try_from(type_char)?;
let rest: String = chars.collect();
if rest.is_empty() {
return Ok(Packet {
packet_type,
data: PacketData::Empty,
});
}
if let Some(b64_data) = rest.strip_prefix('b') {
let decoded = BASE64.decode(b64_data)?;
return Ok(Packet {
packet_type,
data: PacketData::Binary(decoded),
});
}
Ok(Packet {
packet_type,
data: PacketData::Text(rest),
})
}
/// Decode a WebSocket binary frame into a Packet.
/// Handles both text-encoded packets (UTF-8 payload after type byte)
/// and binary packets (raw binary payload after type byte).
pub fn decode_packet_ws(input: &[u8]) -> Result<Packet, PacketError> {
if input.is_empty() {
return Err(PacketError::Empty);
}
let type_byte = input[0];
let packet_type = PacketType::try_from(type_byte.wrapping_sub(b'0'))?;
let rest = &input[1..];
if rest.is_empty() {
return Ok(Packet {
packet_type,
data: PacketData::Empty,
});
}
// Try UTF-8 first; if it fails, treat as binary data
match String::from_utf8(rest.to_vec()) {
Ok(text) => Ok(Packet {
packet_type,
data: PacketData::Text(text),
}),
Err(_) => Ok(Packet {
packet_type,
data: PacketData::Binary(rest.to_vec()),
}),
}
}
pub fn encode_payload(packets: &[Packet]) -> String {
packets
.iter()
.map(encode_packet)
.collect::<Vec<_>>()
.join(&RECORD_SEPARATOR.to_string())
}
pub fn decode_payload(input: &str) -> Result<Vec<Packet>, PacketError> {
input
.split(RECORD_SEPARATOR)
.filter(|s| !s.is_empty())
.map(decode_packet)
.collect()
}
pub fn encode_webtransport_header(payload_len: usize, is_binary: bool) -> Vec<u8> {
let binary_bit: u8 = if is_binary { 0x80 } else { 0x00 };
if payload_len <= 125 {
vec![binary_bit | (payload_len as u8)]
} else if payload_len <= 65535 {
let mut header = vec![binary_bit | 126];
header.extend_from_slice(&(payload_len as u16).to_be_bytes());
header
} else {
let mut header = vec![binary_bit | 127];
header.extend_from_slice(&(payload_len as u64).to_be_bytes());
header
}
}
pub fn decode_webtransport_header(header: &[u8]) -> Option<(usize, bool)> {
if header.is_empty() {
return None;
}
let first = header[0];
let is_binary = (first & 0x80) != 0;
let len_indicator = first & 0x7f;
if len_indicator <= 125 {
Some((len_indicator as usize, is_binary))
} else if len_indicator == 126 {
if header.len() < 3 {
return None;
}
let len = u16::from_be_bytes([header[1], header[2]]) as usize;
Some((len, is_binary))
} else {
if header.len() < 9 {
return None;
}
let len = u64::from_be_bytes([
header[1], header[2], header[3], header[4], header[5], header[6], header[7], header[8],
]) as usize;
Some((len, is_binary))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encode_decode_text_packet() {
let packet = Packet::message_text("hello");
let encoded = encode_packet(&packet);
assert_eq!(encoded, "4hello");
let decoded = decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
}
#[test]
fn test_encode_decode_binary_packet() {
let packet = Packet::message_binary(vec![1, 2, 3, 4]);
let encoded = encode_packet(&packet);
assert_eq!(encoded, "4bAQIDBA==");
let decoded = decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Binary(vec![1, 2, 3, 4]));
}
#[test]
fn test_encode_decode_payload() {
let packets = vec![
Packet::message_text("hello"),
Packet::ping(""),
Packet::message_text("world"),
];
let encoded = encode_payload(&packets);
assert_eq!(encoded, "4hello\x1e2\x1e4world");
let decoded = decode_payload(&encoded).unwrap();
assert_eq!(decoded.len(), 3);
assert_eq!(decoded[0].packet_type, PacketType::Message);
assert_eq!(decoded[1].packet_type, PacketType::Ping);
assert_eq!(decoded[2].packet_type, PacketType::Message);
}
#[test]
fn test_webtransport_header() {
let header = encode_webtransport_header(6, false);
assert_eq!(header, vec![0x06]);
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
assert_eq!(len, 6);
assert!(!is_binary);
let header = encode_webtransport_header(200, true);
assert_eq!(header.len(), 3);
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
assert_eq!(len, 200);
assert!(is_binary);
}
#[test]
fn test_decode_packet_ws_text() {
let input = b"4hello";
let decoded = decode_packet_ws(input).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
}
#[test]
fn test_decode_packet_ws_binary() {
// Type byte 4 (Message) + raw binary payload (non-UTF-8)
let input: Vec<u8> = vec![b'4', 0x80, 0xFF, 0x00, 0x01];
let decoded = decode_packet_ws(&input).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01]));
}
#[test]
fn test_decode_packet_ws_empty() {
let input = b"4";
let decoded = decode_packet_ws(input).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Empty);
}
}
+77
View File
@@ -0,0 +1,77 @@
use std::sync::Arc;
use std::time::Duration;
use crate::engine::packet::Packet;
use crate::engine::session::{SessionState, SessionStore, TransportType};
pub struct HeartbeatManager {
store: SessionStore,
ping_interval: u64,
ping_timeout: u64,
}
impl HeartbeatManager {
pub fn new(store: SessionStore, ping_interval: u64, ping_timeout: u64) -> Self {
Self {
store,
ping_interval,
ping_timeout,
}
}
pub fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
let this = self.clone();
tokio::spawn(async move {
this.run().await;
})
}
async fn run(&self) {
let mut interval = tokio::time::interval(Duration::from_millis(self.ping_interval));
loop {
interval.tick().await;
self.check_sessions().await;
}
}
async fn check_sessions(&self) {
let now = std::time::Instant::now();
let timeout_duration = Duration::from_millis(self.ping_interval + self.ping_timeout);
let mut to_remove = Vec::new();
for entry in self.store.sessions.iter() {
let sid = entry.key().clone();
let session = entry.value().clone();
let (state, last_ping, transport) = {
let s = session.read().await;
(s.state, s.last_ping, s.transport)
};
if state == SessionState::Closed {
to_remove.push(sid);
continue;
}
if now.duration_since(last_ping) > timeout_duration {
tracing::warn!("Session {} timed out", sid);
to_remove.push(sid);
continue;
}
// For polling sessions: buffer a ping packet for the next GET request.
// WS/WT sessions rely on their own dedicated ping tasks; the timeout
// check above already serves as the safety net for all transports.
if state == SessionState::Open && transport == TransportType::Polling {
let mut s = session.write().await;
s.buffer_packet(Packet::ping(""));
}
}
for sid in to_remove {
self.store.remove(&sid);
}
}
}
+13
View File
@@ -0,0 +1,13 @@
pub mod codec;
pub mod heartbeat;
pub mod packet;
pub mod polling;
pub mod server;
pub mod session;
pub mod upgrade;
pub mod websocket;
pub mod webtransport;
pub use packet::{HandshakeData, Packet, PacketData, PacketType};
pub use server::{EngineConfig, EngineServer};
pub use session::{SessionState, SessionStore, TransportType};
+151
View File
@@ -0,0 +1,151 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PacketType {
Open = 0,
Close = 1,
Ping = 2,
Pong = 3,
Message = 4,
Upgrade = 5,
Noop = 6,
}
impl TryFrom<u8> for PacketType {
type Error = PacketError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::Open),
1 => Ok(Self::Close),
2 => Ok(Self::Ping),
3 => Ok(Self::Pong),
4 => Ok(Self::Message),
5 => Ok(Self::Upgrade),
6 => Ok(Self::Noop),
_ => Err(PacketError::InvalidType(value)),
}
}
}
impl TryFrom<char> for PacketType {
type Error = PacketError;
fn try_from(value: char) -> Result<Self, Self::Error> {
match value {
'0' => Ok(Self::Open),
'1' => Ok(Self::Close),
'2' => Ok(Self::Ping),
'3' => Ok(Self::Pong),
'4' => Ok(Self::Message),
'5' => Ok(Self::Upgrade),
'6' => Ok(Self::Noop),
_ => Err(PacketError::InvalidTypeChar(value)),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PacketData {
Text(String),
Binary(Vec<u8>),
Empty,
}
#[derive(Debug, Clone)]
pub struct Packet {
pub packet_type: PacketType,
pub data: PacketData,
}
impl Packet {
pub fn open(handshake: &HandshakeData) -> Self {
let data = serde_json::to_string(handshake)
.unwrap_or_else(|e| {
tracing::error!("Failed to serialize handshake data: {}", e);
"{}".to_string()
});
Self {
packet_type: PacketType::Open,
data: PacketData::Text(data),
}
}
pub fn close() -> Self {
Self {
packet_type: PacketType::Close,
data: PacketData::Empty,
}
}
pub fn ping(data: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Ping,
data: PacketData::Text(data.into()),
}
}
pub fn pong(data: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Pong,
data: PacketData::Text(data.into()),
}
}
pub fn message_text(data: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Message,
data: PacketData::Text(data.into()),
}
}
pub fn message_binary(data: Vec<u8>) -> Self {
Self {
packet_type: PacketType::Message,
data: PacketData::Binary(data),
}
}
pub fn upgrade() -> Self {
Self {
packet_type: PacketType::Upgrade,
data: PacketData::Empty,
}
}
pub fn noop() -> Self {
Self {
packet_type: PacketType::Noop,
data: PacketData::Empty,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeData {
pub sid: String,
pub upgrades: Vec<String>,
#[serde(rename = "pingInterval")]
pub ping_interval: u64,
#[serde(rename = "pingTimeout")]
pub ping_timeout: u64,
#[serde(rename = "maxPayload")]
pub max_payload: usize,
}
#[derive(Debug, thiserror::Error)]
pub enum PacketError {
#[error("invalid packet type: {0}")]
InvalidType(u8),
#[error("invalid packet type char: {0}")]
InvalidTypeChar(char),
#[error("empty packet")]
Empty,
#[error("invalid base64: {0}")]
InvalidBase64(#[from] base64::DecodeError),
#[error("invalid utf8: {0}")]
InvalidUtf8(#[from] std::string::FromUtf8Error),
#[error("serialization error: {0}")]
Serialization(String),
}
+185
View File
@@ -0,0 +1,185 @@
use std::sync::Arc;
use std::time::Duration;
use actix_web::{web, HttpRequest, HttpResponse};
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
#[derive(Debug, serde::Deserialize)]
pub struct PollingQuery {
#[serde(rename = "EIO")]
pub eio: Option<String>,
pub transport: Option<String>,
pub sid: Option<String>,
}
pub async fn polling_get(
_req: HttpRequest,
query: web::Query<PollingQuery>,
store: web::Data<SessionStore>,
config: web::Data<EngineConfig>,
_on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> HttpResponse {
if query.eio.as_deref() != Some("4") {
return HttpResponse::BadRequest().body("invalid EIO version");
}
if query.transport.as_deref() != Some("polling") {
return HttpResponse::BadRequest().body("invalid transport");
}
let sid = match &query.sid {
Some(sid) => sid.clone(),
None => {
return handle_handshake(&store, &config).await;
}
};
let session = match store.get(&sid) {
Some(s) => s,
None => return HttpResponse::BadRequest().body("unknown session"),
};
// Check session state and take any buffered pending packets
let notify = {
let mut session_guard = session.write().await;
if session_guard.state == SessionState::Closed {
return HttpResponse::BadRequest().body("session closed");
}
let pending = session_guard.take_pending();
if !pending.is_empty() {
let payload = codec::encode_payload(&pending);
return HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(payload);
}
session_guard.notify.clone()
};
let timeout = Duration::from_millis(config.ping_interval + config.ping_timeout);
let _result = tokio::time::timeout(timeout, notify.notified()).await;
// Re-verify session still exists after wait (may have been removed by heartbeat)
if store.get(&sid).is_none() {
return HttpResponse::BadRequest().body("session closed");
}
let mut session_guard = session.write().await;
let packets = session_guard.take_pending();
if packets.is_empty() {
let noop = codec::encode_packet(&Packet::noop());
HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(noop)
} else {
let payload = codec::encode_payload(&packets);
HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(payload)
}
}
pub async fn polling_post(
_req: HttpRequest,
body: web::Bytes,
query: web::Query<PollingQuery>,
store: web::Data<SessionStore>,
config: web::Data<EngineConfig>,
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> HttpResponse {
if query.eio.as_deref() != Some("4") {
return HttpResponse::BadRequest().body("invalid EIO version");
}
if query.transport.as_deref() != Some("polling") {
return HttpResponse::BadRequest().body("invalid transport");
}
// Check payload size BEFORE attempting to decode
if body.len() > config.max_payload {
return HttpResponse::PayloadTooLarge().body("payload too large");
}
let sid = match &query.sid {
Some(sid) => sid,
None => return HttpResponse::BadRequest().body("missing sid"),
};
let session = match store.get(sid) {
Some(s) => s,
None => return HttpResponse::BadRequest().body("unknown session"),
};
let body_str = match std::str::from_utf8(&body) {
Ok(s) => s,
Err(_) => return HttpResponse::BadRequest().body("invalid utf8"),
};
let packets = match codec::decode_payload(body_str) {
Ok(p) => p,
Err(_) => return HttpResponse::BadRequest().body("invalid payload"),
};
let mut session_guard = session.write().await;
for packet in packets {
match packet.packet_type {
PacketType::Pong => {
session_guard.update_ping();
}
PacketType::Message => {
let on_msg = on_message.get_ref().clone();
let sid_owned = sid.clone();
tokio::spawn(async move {
on_msg(sid_owned, packet);
});
}
PacketType::Close => {
session_guard.set_state(SessionState::Closed);
store.remove(sid);
return HttpResponse::Ok().body("ok");
}
_ => {}
}
}
HttpResponse::Ok().body("ok")
}
async fn handle_handshake(store: &SessionStore, config: &EngineConfig) -> HttpResponse {
let sid = crate::engine::session::generate_sid();
let handshake = crate::engine::packet::HandshakeData {
sid: sid.clone(),
upgrades: vec!["websocket".to_string()],
ping_interval: config.ping_interval,
ping_timeout: config.ping_timeout,
max_payload: config.max_payload,
};
let _rx = store.create(sid.clone(), TransportType::Polling);
if let Some(session) = store.get(&sid) {
let mut session = session.write().await;
session.set_state(SessionState::Open);
}
let packet = Packet::open(&handshake);
let payload = codec::encode_packet(&packet);
HttpResponse::Ok()
.content_type("text/plain; charset=UTF-8")
.body(payload)
}
pub fn configure_polling(cfg: &mut web::ServiceConfig) {
cfg.route("/engine.io/", web::get().to(polling_get))
.route("/engine.io/", web::post().to(polling_post));
}
+115
View File
@@ -0,0 +1,115 @@
use std::sync::Arc;
use actix_web::{web, App, HttpServer};
use crate::engine::heartbeat::HeartbeatManager;
use crate::engine::packet::Packet;
use crate::engine::session::SessionStore;
#[derive(Debug, Clone)]
pub struct EngineConfig {
pub ping_interval: u64,
pub ping_timeout: u64,
pub max_payload: usize,
pub path: String,
}
impl Default for EngineConfig {
fn default() -> Self {
Self {
ping_interval: 25000,
ping_timeout: 20000,
max_payload: 1_000_000,
path: "/engine.io/".to_string(),
}
}
}
pub struct EngineServer {
pub config: EngineConfig,
pub store: SessionStore,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
}
impl EngineServer {
pub fn new(
config: EngineConfig,
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
) -> Self {
Self {
config,
store: SessionStore::new(),
on_message: Arc::new(on_message),
}
}
pub fn with_store(
config: EngineConfig,
store: SessionStore,
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
) -> Self {
Self {
config,
store,
on_message: Arc::new(on_message),
}
}
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
let store = self.store.clone();
let config = self.config.clone();
let on_message = self.on_message.clone();
// Start heartbeat manager to clean up stale sessions
let heartbeat = Arc::new(HeartbeatManager::new(
store.clone(),
config.ping_interval,
config.ping_timeout,
));
let heartbeat_handle = heartbeat.start();
tracing::info!("Engine.IO HTTP server listening on {}", addr);
let result = HttpServer::new(move || {
App::new()
.app_data(web::Data::new(store.clone()))
.app_data(web::Data::new(config.clone()))
.app_data(web::Data::new(on_message.clone()))
.route(
"/engine.io/",
web::get().to(crate::engine::polling::polling_get),
)
.route(
"/engine.io/",
web::post().to(crate::engine::polling::polling_post),
)
.route(
"/engine.io/",
web::get().to(crate::engine::websocket::websocket_handler),
)
})
.bind(addr)?
.run()
.await;
heartbeat_handle.abort();
result
}
pub async fn run_webtransport(
&self,
port: u16,
cert_path: &str,
key_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
crate::engine::webtransport::run_webtransport_server(
port,
cert_path,
key_path,
self.store.clone(),
self.config.clone(),
self.on_message.clone(),
)
.await
}
}
+171
View File
@@ -0,0 +1,171 @@
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tokio::sync::{mpsc, Notify};
use crate::engine::packet::Packet;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportType {
Polling,
WebSocket,
WebTransport,
}
impl TransportType {
pub fn as_str(&self) -> &'static str {
match self {
Self::Polling => "polling",
Self::WebSocket => "websocket",
Self::WebTransport => "webtransport",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
Connecting,
Open,
Upgrading,
Closing,
Closed,
}
pub struct Session {
pub sid: String,
pub transport: TransportType,
pub state: SessionState,
pub created_at: Instant,
pub last_ping: Instant,
pub tx: mpsc::Sender<Packet>,
pub pending_packets: Vec<Packet>,
pub notify: Arc<Notify>,
pub upgrade_tx: Option<mpsc::Sender<Packet>>,
}
impl Session {
pub fn new(sid: String, transport: TransportType) -> (Self, mpsc::Receiver<Packet>) {
let (tx, rx) = mpsc::channel(256);
let session = Self {
sid,
transport,
state: SessionState::Connecting,
created_at: Instant::now(),
last_ping: Instant::now(),
tx,
pending_packets: Vec::new(),
notify: Arc::new(Notify::new()),
upgrade_tx: None,
};
(session, rx)
}
/// Send a packet through the mpsc channel (for WS/WT transport consumption).
pub fn send_packet(&self, packet: Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
self.tx.try_send(packet)
}
/// Push a packet using the appropriate mechanism for the current transport.
/// Polling: buffer in pending_packets + notify waiting GET request.
/// WS/WT: try mpsc channel first; if full, buffer as fallback + notify.
pub fn push_packet(&mut self, packet: Packet) {
if self.transport == TransportType::Polling {
self.pending_packets.push(packet);
self.notify.notify_one();
} else {
if self.tx.try_send(packet.clone()).is_err() {
self.pending_packets.push(packet);
self.notify.notify_one();
}
}
}
/// Buffer a packet in pending_packets and notify any waiting polling request.
pub fn buffer_packet(&mut self, packet: Packet) {
self.pending_packets.push(packet);
self.notify.notify_one();
}
pub fn take_pending(&mut self) -> Vec<Packet> {
std::mem::take(&mut self.pending_packets)
}
pub fn update_ping(&mut self) {
self.last_ping = Instant::now();
}
pub fn set_transport(&mut self, transport: TransportType) {
self.transport = transport;
}
pub fn set_state(&mut self, state: SessionState) {
self.state = state;
}
}
#[derive(Clone)]
pub struct SessionStore {
pub sessions: Arc<DashMap<String, Arc<tokio::sync::RwLock<Session>>>>,
}
impl SessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(DashMap::new()),
}
}
/// Create a new session. Returns the mpsc receiver for transport-level packet consumption.
/// Logs a warning if the SID collides with an existing session (extremely unlikely with crypto RNG).
pub fn create(&self, sid: String, transport: TransportType) -> mpsc::Receiver<Packet> {
let (session, rx) = Session::new(sid.clone(), transport);
let old = self
.sessions
.insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session)));
if old.is_some() {
tracing::warn!("Session ID collision for SID {}, replacing existing session", sid);
}
rx
}
pub fn get(&self, sid: &str) -> Option<Arc<tokio::sync::RwLock<Session>>> {
self.sessions.get(sid).map(|r| r.value().clone())
}
pub fn remove(&self, sid: &str) {
self.sessions.remove(sid);
}
pub fn exists(&self, sid: &str) -> bool {
self.sessions.contains_key(sid)
}
pub fn len(&self) -> usize {
self.sessions.len()
}
pub fn is_empty(&self) -> bool {
self.sessions.is_empty()
}
}
impl Default for SessionStore {
fn default() -> Self {
Self::new()
}
}
/// Generate a random session ID using a cryptographically secure RNG.
/// rand 0.9's default RNG (ChaCha8Rng seeded from OsRng) is crypto-secure.
pub fn generate_sid() -> String {
use rand::Rng;
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-";
let mut rng = rand::rng();
(0..20)
.map(|_| {
let idx = rng.random_range(0..CHARSET.len());
CHARSET[idx] as char
})
.collect()
}
+53
View File
@@ -0,0 +1,53 @@
use crate::engine::packet::Packet;
use crate::engine::session::{SessionState, SessionStore, TransportType};
pub async fn handle_upgrade_probe(
store: &SessionStore,
sid: &str,
) -> Result<Packet, UpgradeError> {
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
let mut session = session.write().await;
if session.state == SessionState::Closed {
return Err(UpgradeError::SessionClosed);
}
session.set_state(SessionState::Upgrading);
Ok(Packet::pong("probe"))
}
pub async fn handle_upgrade_complete(
store: &SessionStore,
sid: &str,
new_transport: TransportType,
) -> Result<(), UpgradeError> {
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
let mut session = session.write().await;
session.set_transport(new_transport);
session.set_state(SessionState::Open);
Ok(())
}
pub async fn send_noop_to_pending_polling(
store: &SessionStore,
sid: &str,
) -> Result<(), UpgradeError> {
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
let mut session = session.write().await;
session.buffer_packet(Packet::noop());
Ok(())
}
#[derive(Debug, thiserror::Error)]
pub enum UpgradeError {
#[error("session not found")]
SessionNotFound,
#[error("session closed")]
SessionClosed,
#[error("invalid state for upgrade")]
InvalidState,
}
+254
View File
@@ -0,0 +1,254 @@
use std::sync::Arc;
use actix_web::{web, HttpRequest, HttpResponse};
use actix_ws::Message;
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketData, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
#[derive(Debug, serde::Deserialize)]
pub struct WsQuery {
#[serde(rename = "EIO")]
pub eio: Option<String>,
pub transport: Option<String>,
pub sid: Option<String>,
}
pub async fn websocket_handler(
req: HttpRequest,
body: web::Payload,
query: web::Query<WsQuery>,
store: web::Data<SessionStore>,
config: web::Data<EngineConfig>,
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> Result<HttpResponse, actix_web::Error> {
if query.eio.as_deref() != Some("4") {
return Ok(HttpResponse::BadRequest().body("invalid EIO version"));
}
if query.transport.as_deref() != Some("websocket") {
return Ok(HttpResponse::BadRequest().body("invalid transport"));
}
let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, body)?;
let sid = query.sid.clone();
let is_upgrade = sid.as_ref().map(|s| store.exists(s)).unwrap_or(false);
// Create or reuse session, obtaining the mpsc receiver for the forwarding task
let (session_sid, mut session_rx) = if let Some(ref sid) = sid {
if is_upgrade {
// Upgrade: session already exists, replace its channel and drain pending packets
let session_arc = store.get(sid).unwrap();
let (new_tx, new_rx) = tokio::sync::mpsc::channel(256);
{
let mut s = session_arc.write().await;
// Swap tx atomically: old_tx will be dropped, closing its channel.
// Any packets in the old rx are consumed by the old send_handle,
// which then exits when it sees the channel close.
// Drain pending_packets (from polling buffering) into new channel.
let pending = s.take_pending();
for packet in pending {
let _ = new_tx.try_send(packet);
}
s.tx = new_tx;
s.set_transport(TransportType::WebSocket);
}
(sid.clone(), new_rx)
} else {
// Reconnect with known SID: create new session
let rx = store.create(sid.clone(), TransportType::WebSocket);
if let Some(s) = store.get(sid) {
let mut s = s.write().await;
s.set_state(SessionState::Open);
}
(sid.clone(), rx)
}
} else {
// New connection: generate SID and create session
let new_sid = crate::engine::session::generate_sid();
let rx = store.create(new_sid.clone(), TransportType::WebSocket);
if let Some(s) = store.get(&new_sid) {
let mut s = s.write().await;
s.set_state(SessionState::Open);
}
(new_sid, rx)
};
let handshake = crate::engine::packet::HandshakeData {
sid: session_sid.clone(),
upgrades: vec![],
ping_interval: config.ping_interval,
ping_timeout: config.ping_timeout,
max_payload: config.max_payload,
};
let open_packet = Packet::open(&handshake);
let open_msg = codec::encode_packet(&open_packet);
if ws_session.text(open_msg).await.is_err() {
tracing::warn!("Failed to send open packet to WebSocket session {}", session_sid);
store.remove(&session_sid);
return Ok(response);
}
let store_clone = store.get_ref().clone();
let on_message_clone = on_message.get_ref().clone();
let sid_clone = session_sid.clone();
let ws_session_clone = ws_session.clone();
let max_payload = config.max_payload;
// Task 1: Forward engine session packets → WebSocket (reads from session mpsc rx)
let sid_for_send = session_sid.clone();
let store_for_send = store.get_ref().clone();
let mut ws_for_send = ws_session.clone();
let send_handle = actix_rt::spawn(async move {
while let Some(packet) = session_rx.recv().await {
let encoded = codec::encode_packet(&packet);
if ws_for_send.text(encoded).await.is_err() {
break;
}
}
// Session channel closed — clean up
store_for_send.remove(&sid_for_send);
});
// Task 2: Read incoming WebSocket messages → dispatch
let recv_handle = actix_rt::spawn(async move {
let mut ws_session = ws_session_clone;
while let Some(Ok(msg)) = msg_stream.recv().await {
match msg {
Message::Text(text) => {
if let Ok(packet) = codec::decode_packet(&text) {
match packet.packet_type {
PacketType::Ping => {
if let PacketData::Text(ref data) = packet.data {
if data == "probe" {
let pong = Packet::pong("probe");
let pong_msg = codec::encode_packet(&pong);
let _ = ws_session.text(pong_msg).await;
continue;
}
}
let pong = Packet::pong("");
let pong_msg = codec::encode_packet(&pong);
let _ = ws_session.text(pong_msg).await;
}
PacketType::Pong => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.update_ping();
}
}
PacketType::Upgrade => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_transport(TransportType::WebSocket);
s.set_state(SessionState::Open);
}
}
PacketType::Message => {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
PacketType::Close => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
let _ = ws_session.close(None).await;
break;
}
_ => {}
}
}
}
Message::Binary(bin) => {
// Enforce max payload size for binary frames
if bin.len() > max_payload {
tracing::warn!(
"Binary payload too large ({}) for session {}",
bin.len(),
sid_clone
);
continue;
}
if let Ok(packet) = codec::decode_packet_ws(&bin) {
if packet.packet_type == PacketType::Message {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
}
}
Message::Close(_) => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
break;
}
_ => {}
}
}
});
// Task 3: Heartbeat ping sender
let sid_for_ping = session_sid.clone();
let store_for_ping = store.get_ref().clone();
let mut ws_for_ping = ws_session.clone();
let ping_interval = config.ping_interval;
let ping_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
loop {
interval.tick().await;
if let Some(s) = store_for_ping.get(&sid_for_ping) {
let session_state = {
let s = s.read().await;
s.state
};
if session_state == SessionState::Closed {
break;
}
let ping = Packet::ping("");
let ping_msg = codec::encode_packet(&ping);
if ws_for_ping.text(ping_msg).await.is_err() {
break;
}
} else {
break;
}
}
});
// Wait for any task to finish, then clean up
actix_rt::spawn(async move {
// actix_rt::spawn returns JoinHandle which is compatible with tokio::select!
tokio::select! {
_ = send_handle => {},
_ = recv_handle => {},
_ = ping_handle => {},
}
store.remove(&session_sid);
});
Ok(response)
}
pub fn configure_websocket(cfg: &mut web::ServiceConfig) {
cfg.route("/engine.io/", web::get().to(websocket_handler));
}
+223
View File
@@ -0,0 +1,223 @@
use std::sync::Arc;
use wtransport::{Connection, Endpoint, ServerConfig, Identity};
use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType};
use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType};
pub async fn run_webtransport_server(
port: u16,
cert_path: &str,
key_path: &str,
store: SessionStore,
config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> Result<(), Box<dyn std::error::Error>> {
let identity = Identity::load_pemfiles(cert_path, key_path).await?;
let server_config = ServerConfig::builder()
.with_bind_default(port)
.with_identity(identity)
.build();
let server = Endpoint::server(server_config)?;
tracing::info!("WebTransport server listening on UDP port {}", port);
loop {
let incoming = server.accept().await;
let store = store.clone();
let config = config.clone();
let on_message = on_message.clone();
tokio::spawn(async move {
match handle_webtransport_session(incoming, store, config, on_message).await {
Ok(_) => {}
Err(e) => {
tracing::error!("WebTransport session error: {}", e);
}
}
});
}
}
async fn handle_webtransport_session(
incoming: wtransport::endpoint::IncomingSession,
store: SessionStore,
config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> Result<(), Box<dyn std::error::Error>> {
let request = incoming.await?;
let connection = request.accept().await?;
let sid = crate::engine::session::generate_sid();
let mut rx = store.create(sid.clone(), TransportType::WebTransport);
if let Some(s) = store.get(&sid) {
let mut s = s.write().await;
s.set_state(SessionState::Open);
}
let handshake = crate::engine::packet::HandshakeData {
sid: sid.clone(),
upgrades: vec![],
ping_interval: config.ping_interval,
ping_timeout: config.ping_timeout,
max_payload: config.max_payload,
};
let open_packet = Packet::open(&handshake);
send_wt_packet(&connection, &open_packet).await?;
let store_clone = store.clone();
let sid_clone = sid.clone();
let on_message_clone = on_message.clone();
let connection_recv = connection.clone();
let max_payload = config.max_payload;
// Reuse buffer across recv iterations instead of allocating 65KB each time
let recv_handle = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
match connection_recv.accept_bi().await {
Ok((mut send, mut recv)) => {
// Reset buffer length for the next read without deallocating
buf.resize(65536, 0);
match recv.read(&mut buf).await {
Ok(Some(n)) => {
if n > max_payload {
tracing::warn!(
"WebTransport payload too large ({}) for session {}",
n,
sid_clone
);
continue;
}
if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) {
match packet.packet_type {
PacketType::Ping => {
let pong = Packet::pong("");
if send_wt_packet_on_stream(&mut send, &pong)
.await
.is_err()
{
break;
}
}
PacketType::Pong => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.update_ping();
}
}
PacketType::Message => {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
PacketType::Close => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
break;
}
_ => {}
}
}
}
Ok(None) => break,
Err(_) => break,
}
}
Err(_) => break,
}
}
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
});
let connection_send = connection.clone();
let send_handle = tokio::spawn(async move {
while let Some(packet) = rx.recv().await {
if send_wt_packet(&connection_send, &packet).await.is_err() {
break;
}
}
});
let connection_ping = connection.clone();
let store_ping = store.clone();
let sid_ping = sid.clone();
let ping_interval = config.ping_interval;
let ping_handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
loop {
interval.tick().await;
if let Some(s) = store_ping.get(&sid_ping) {
let state = {
let s = s.read().await;
s.state
};
if state == SessionState::Closed {
break;
}
let ping = Packet::ping("");
if send_wt_packet(&connection_ping, &ping).await.is_err() {
break;
}
} else {
break;
}
}
});
tokio::select! {
_ = recv_handle => {},
_ = send_handle => {},
_ = ping_handle => {},
}
store.remove(&sid);
Ok(())
}
async fn send_wt_packet(
connection: &Connection,
packet: &Packet,
) -> Result<(), Box<dyn std::error::Error>> {
let (mut send, _recv) = connection.open_bi().await?.await?;
let encoded = codec::encode_packet_binary_ws(packet);
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
send.write_all(&header).await?;
send.write_all(&encoded).await?;
send.finish().await?;
Ok(())
}
async fn send_wt_packet_on_stream(
send: &mut wtransport::SendStream,
packet: &Packet,
) -> Result<(), Box<dyn std::error::Error>> {
let encoded = codec::encode_packet_binary_ws(packet);
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
send.write_all(&header).await?;
send.write_all(&encoded).await?;
send.finish().await?;
Ok(())
}
+3
View File
@@ -0,0 +1,3 @@
pub mod pb;
pub mod socket;
pub mod engine;
+37
View File
@@ -0,0 +1,37 @@
use std::sync::Arc;
use imks::engine::server::EngineConfig;
use imks::socket::server::SocketServer;
fn main() {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let config = EngineConfig::default();
let socket_server = Arc::new(SocketServer::new(config));
let addr = "0.0.0.0:3000";
tracing::info!("Starting Socket.IO server on {}", addr);
tokio::runtime::Runtime::new()
.expect("Failed to create Tokio runtime")
.block_on(async {
let namespace = socket_server.of("/");
namespace
.on_connect(|socket, _auth| {
tracing::info!(
"Socket {} connected (engine: {})",
socket.sid,
socket.engine_sid
);
Ok(())
})
.await;
socket_server.run_http(addr).await.expect("Server error");
});
}
+1
View File
@@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/appks.core.v1.rs"));
+1
View File
@@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/appks.im.v1.rs"));
+2
View File
@@ -0,0 +1,2 @@
pub mod core;
pub mod im;
+124
View File
@@ -0,0 +1,124 @@
syntax = "proto3";
package appks.core.v1;
// ============================================================
// JWT Payload
// ============================================================
message TokenClaims {
string sub = 1; // user id (uuid)
string iss = 2; // issuer (e.g. "appks")
int64 iat = 3; // issued at (unix seconds)
int64 exp = 4; // expires at (unix seconds)
string jti = 5; // unique token id (for revocation)
string scope = 6; // space-separated scopes
map<string, string> extra = 7; // extensible fields (workspace_id, role, etc.)
}
// ============================================================
// Issue (appks REST API → core)
// ============================================================
message IssueTokenRequest {
string user_id = 1;
int64 ttl_secs = 2; // access token lifetime
repeated string scopes = 3;
map<string, string> extra = 4;
}
message IssueTokenResponse {
string access_token = 1; // JWT
string refresh_token = 2; // opaque, stored in Redis
int64 expires_at = 3;
string key_id = 4; // kid header for the signing key
}
// ============================================================
// Refresh
// ============================================================
message RefreshTokenRequest {
string refresh_token = 1;
}
message RefreshTokenResponse {
string access_token = 1;
string refresh_token = 2; // rotated
int64 expires_at = 3;
string key_id = 4;
}
// ============================================================
// Revoke
// ============================================================
message RevokeTokenRequest {
oneof target {
string jti = 1; // revoke single token
string user_id = 2; // revoke all tokens for user
}
}
message RevokeTokenResponse {
int32 revoked_count = 1;
}
// ============================================================
// Verify (imks → core, RPC 模式)
// imks 把客户端携带的 JWT 发给 core 验证
// ============================================================
message VerifyTokenRequest {
string token = 1;
}
message VerifyTokenResponse {
bool valid = 1;
TokenClaims claims = 2; // only set when valid = true
string reason = 3; // "expired", "revoked", "invalid_signature", etc.
}
// ============================================================
// Key Distribution (imks → core, 本地验证模式)
// imks 拉取公钥/解密密钥,本地验证 JWT,无需每次 RPC
// 密钥窗口 3h,imks 定期刷新
// ============================================================
message SigningKey {
string kid = 1; // key id (matches JWT header kid)
string algorithm = 2; // "HS256", "RS256", "EdDSA", ...
string key_material = 3; // 对称: base64 secret / 非对称: PEM public key
int64 issued_at = 4; // 签发时间
int64 expires_at = 5; // 过期时间 (issued_at + 3h window)
bool active = 6; // 是否为当前活跃签名密钥
}
message GetSigningKeysRequest {
// 空 = 返回所有未过期密钥
// 非空 = 只返回指定 kid 的密钥
string kid = 1;
}
message GetSigningKeysResponse {
repeated SigningKey keys = 1; // 可能同时有多个有效密钥(滚动窗口)
int64 next_rotation_at = 2; // 下次密钥轮换时间,imks 据此安排刷新
}
// ============================================================
// Service
// ============================================================
service TokenService {
// --- 令牌生命周期 (appks REST handler 调用) ---
rpc IssueToken(IssueTokenRequest) returns (IssueTokenResponse);
rpc RefreshToken(RefreshTokenRequest) returns (RefreshTokenResponse);
rpc RevokeToken(RevokeTokenRequest) returns (RevokeTokenResponse);
// --- imks 验证 (RPC 模式) ---
rpc VerifyToken(VerifyTokenRequest) returns (VerifyTokenResponse);
// --- imks 密钥拉取 (本地验证模式) ---
// imks 启动时拉取,之后根据 next_rotation_at 定期刷新
rpc GetSigningKeys(GetSigningKeysRequest) returns (GetSigningKeysResponse);
}
+208
View File
@@ -0,0 +1,208 @@
syntax = "proto3";
package appks.im.v1;
import "google/protobuf/timestamp.proto";
// Channel management service for the IM microservice.
// Provides CRUD for channels and categories, plus channel statistics.
enum ChannelType {
CHANNEL_TYPE_UNSPECIFIED = 0;
CHANNEL_TYPE_PUBLIC = 1;
CHANNEL_TYPE_PRIVATE = 2;
CHANNEL_TYPE_DIRECT = 3;
CHANNEL_TYPE_GROUP = 4;
CHANNEL_TYPE_REPO = 5;
CHANNEL_TYPE_SYSTEM = 6;
}
enum ChannelKind {
CHANNEL_KIND_UNSPECIFIED = 0;
CHANNEL_KIND_TEXT = 1;
CHANNEL_KIND_VOICE = 2;
CHANNEL_KIND_STAGE = 3;
CHANNEL_KIND_FORUM = 4;
CHANNEL_KIND_ANNOUNCEMENT = 5;
}
enum Visibility {
VISIBILITY_UNSPECIFIED = 0;
VISIBILITY_PUBLIC = 1;
VISIBILITY_PRIVATE = 2;
VISIBILITY_INTERNAL = 3;
VISIBILITY_WORKSPACE = 4;
VISIBILITY_PROTECTED = 5;
VISIBILITY_HIDDEN = 6;
VISIBILITY_SECRET = 7;
}
message Channel {
string id = 1;
string workspace_id = 2;
optional string category_id = 3;
optional string parent_channel_id = 4;
string name = 5;
optional string topic = 6;
optional string description = 7;
ChannelType channel_type = 8;
ChannelKind channel_kind = 9;
Visibility visibility = 10;
int32 position = 11;
bool nsfw = 12;
bool read_only = 13;
bool archived = 14;
optional string created_by = 15;
optional int32 rate_limit_per_user = 16;
optional google.protobuf.Timestamp archived_at = 17;
optional string last_message_id = 18;
optional google.protobuf.Timestamp last_message_at = 19;
google.protobuf.Timestamp created_at = 20;
google.protobuf.Timestamp updated_at = 21;
}
message ChannelStats {
string channel_id = 1;
int32 members_count = 2;
int32 messages_count = 3;
int32 threads_count = 4;
int32 reactions_count = 5;
int32 mentions_count = 6;
int32 files_count = 7;
optional google.protobuf.Timestamp last_activity_at = 8;
google.protobuf.Timestamp updated_at = 9;
}
message ChannelCategory {
string id = 1;
string workspace_id = 2;
string name = 3;
int32 position = 4;
bool collapsed = 5;
google.protobuf.Timestamp created_at = 6;
google.protobuf.Timestamp updated_at = 7;
}
message GetChannelRequest {
string channel_id = 1;
}
message GetChannelResponse {
Channel channel = 1;
}
message ListChannelsRequest {
string workspace_name = 1;
optional string category_id = 2;
optional ChannelType channel_type = 3;
optional ChannelKind channel_kind = 4;
int32 limit = 5;
int32 offset = 6;
}
message ListChannelsResponse {
repeated Channel channels = 1;
int32 total = 2;
}
message CreateChannelRequest {
string workspace_name = 1;
string name = 2;
optional string topic = 3;
optional string description = 4;
optional string channel_type = 5;
optional string channel_kind = 6;
optional string visibility = 7;
optional string category_id = 8;
optional string parent_channel_id = 9;
optional string created_by = 10;
optional int32 rate_limit_per_user = 11;
}
message CreateChannelResponse {
Channel channel = 1;
}
message UpdateChannelRequest {
string channel_id = 1;
optional string name = 2;
optional string topic = 3;
optional string description = 4;
optional string visibility = 5;
optional int32 position = 6;
optional bool nsfw = 7;
optional bool read_only = 8;
optional bool archived = 9;
optional string category_id = 10;
optional int32 rate_limit_per_user = 11;
}
message UpdateChannelResponse {
Channel channel = 1;
}
message DeleteChannelRequest {
string channel_id = 1;
}
message DeleteChannelResponse {}
message GetChannelStatsRequest {
string channel_id = 1;
}
message GetChannelStatsResponse {
ChannelStats stats = 1;
}
message ListCategoriesRequest {
string workspace_name = 1;
}
message ListCategoriesResponse {
repeated ChannelCategory categories = 1;
}
message CreateCategoryRequest {
string workspace_name = 1;
string name = 2;
optional int32 position = 3;
}
message CreateCategoryResponse {
ChannelCategory category = 1;
}
message UpdateCategoryRequest {
string category_id = 1;
optional string name = 2;
optional int32 position = 3;
optional bool collapsed = 4;
}
message UpdateCategoryResponse {
ChannelCategory category = 1;
}
message DeleteCategoryRequest {
string category_id = 1;
}
message DeleteCategoryResponse {}
service ChannelService {
rpc GetChannel(GetChannelRequest) returns (GetChannelResponse);
rpc ListChannels(ListChannelsRequest) returns (ListChannelsResponse);
rpc CreateChannel(CreateChannelRequest) returns (CreateChannelResponse);
rpc UpdateChannel(UpdateChannelRequest) returns (UpdateChannelResponse);
rpc DeleteChannel(DeleteChannelRequest) returns (DeleteChannelResponse);
rpc GetChannelStats(GetChannelStatsRequest) returns (GetChannelStatsResponse);
rpc ListCategories(ListCategoriesRequest) returns (ListCategoriesResponse);
rpc CreateCategory(CreateCategoryRequest) returns (CreateCategoryResponse);
rpc UpdateCategory(UpdateCategoryRequest) returns (UpdateCategoryResponse);
rpc DeleteCategory(DeleteCategoryRequest) returns (DeleteCategoryResponse);
}
+401
View File
@@ -0,0 +1,401 @@
syntax = "proto3";
package appks.im.v1;
import "google/protobuf/timestamp.proto";
message ChannelRole {
string id = 1;
string channel_id = 2;
string name = 3;
repeated string permissions = 4;
bool assignable = 5;
google.protobuf.Timestamp created_at = 6;
google.protobuf.Timestamp updated_at = 7;
}
message ListChannelRolesRequest { string channel_id = 1; }
message ListChannelRolesResponse { repeated ChannelRole roles = 1; }
message CreateChannelRoleRequest {
string channel_id = 1;
string name = 2;
repeated string permissions = 3;
bool assignable = 4;
}
message CreateChannelRoleResponse { ChannelRole role = 1; }
message UpdateChannelRoleRequest {
string role_id = 1;
optional string name = 2;
repeated string permissions = 3;
optional bool assignable = 4;
}
message UpdateChannelRoleResponse { ChannelRole role = 1; }
message DeleteChannelRoleRequest { string role_id = 1; }
message DeleteChannelRoleResponse {}
service ChannelRoleService {
rpc ListChannelRoles(ListChannelRolesRequest) returns (ListChannelRolesResponse);
rpc CreateChannelRole(CreateChannelRoleRequest) returns (CreateChannelRoleResponse);
rpc UpdateChannelRole(UpdateChannelRoleRequest) returns (UpdateChannelRoleResponse);
rpc DeleteChannelRole(DeleteChannelRoleRequest) returns (DeleteChannelRoleResponse);
}
message ChannelInvitation {
string id = 1;
string channel_id = 2;
string invited_by = 3;
string invited_user_id = 4;
string role = 5;
string status = 6;
google.protobuf.Timestamp created_at = 7;
google.protobuf.Timestamp updated_at = 8;
}
message ListInvitationsRequest { string channel_id = 1; }
message ListInvitationsResponse { repeated ChannelInvitation invitations = 1; }
message CreateInvitationRequest {
string channel_id = 1;
string invited_user_id = 2;
string role = 3;
}
message CreateInvitationResponse { ChannelInvitation invitation = 1; }
message AcceptInvitationRequest { string invitation_id = 1; }
message AcceptInvitationResponse { ChannelInvitation invitation = 1; }
message RevokeInvitationRequest { string invitation_id = 1; }
message RevokeInvitationResponse {}
service ChannelInvitationService {
rpc ListInvitations(ListInvitationsRequest) returns (ListInvitationsResponse);
rpc CreateInvitation(CreateInvitationRequest) returns (CreateInvitationResponse);
rpc AcceptInvitation(AcceptInvitationRequest) returns (AcceptInvitationResponse);
rpc RevokeInvitation(RevokeInvitationRequest) returns (RevokeInvitationResponse);
}
message ChannelWebhook {
string id = 1;
string channel_id = 2;
string name = 3;
string url = 4;
string secret = 5;
repeated string events = 6;
bool active = 7;
google.protobuf.Timestamp created_at = 8;
google.protobuf.Timestamp updated_at = 9;
}
message ListWebhooksRequest { string channel_id = 1; }
message ListWebhooksResponse { repeated ChannelWebhook webhooks = 1; }
message CreateWebhookRequest {
string channel_id = 1;
string name = 2;
string url = 3;
optional string secret = 4;
repeated string events = 5;
}
message CreateWebhookResponse { ChannelWebhook webhook = 1; }
message UpdateWebhookRequest {
string webhook_id = 1;
optional string name = 2;
optional string url = 3;
optional string secret = 4;
repeated string events = 5;
optional bool active = 6;
}
message UpdateWebhookResponse { ChannelWebhook webhook = 1; }
message DeleteWebhookRequest { string webhook_id = 1; }
message DeleteWebhookResponse {}
service ChannelWebhookService {
rpc ListWebhooks(ListWebhooksRequest) returns (ListWebhooksResponse);
rpc CreateWebhook(CreateWebhookRequest) returns (CreateWebhookResponse);
rpc UpdateWebhook(UpdateWebhookRequest) returns (UpdateWebhookResponse);
rpc DeleteWebhook(DeleteWebhookRequest) returns (DeleteWebhookResponse);
}
message ChannelSlashCommand {
string id = 1;
string channel_id = 2;
string command = 3;
string description = 4;
string request_url = 5;
repeated string scopes = 6;
google.protobuf.Timestamp created_at = 7;
google.protobuf.Timestamp updated_at = 8;
}
message ListSlashCommandsRequest { string channel_id = 1; }
message ListSlashCommandsResponse { repeated ChannelSlashCommand commands = 1; }
message CreateSlashCommandRequest {
string channel_id = 1;
string command = 2;
string description = 3;
string request_url = 4;
repeated string scopes = 5;
}
message CreateSlashCommandResponse { ChannelSlashCommand command = 1; }
message UpdateSlashCommandRequest {
string command_id = 1;
optional string description = 2;
optional string request_url = 3;
repeated string scopes = 4;
}
message UpdateSlashCommandResponse { ChannelSlashCommand command = 1; }
message DeleteSlashCommandRequest { string command_id = 1; }
message DeleteSlashCommandResponse {}
service ChannelSlashCommandService {
rpc ListSlashCommands(ListSlashCommandsRequest) returns (ListSlashCommandsResponse);
rpc CreateSlashCommand(CreateSlashCommandRequest) returns (CreateSlashCommandResponse);
rpc UpdateSlashCommand(UpdateSlashCommandRequest) returns (UpdateSlashCommandResponse);
rpc DeleteSlashCommand(DeleteSlashCommandRequest) returns (DeleteSlashCommandResponse);
}
message ChannelRepoLink {
string id = 1;
string channel_id = 2;
string repo_id = 3;
string link_type = 4;
repeated string events = 5;
google.protobuf.Timestamp created_at = 6;
google.protobuf.Timestamp updated_at = 7;
}
message ListRepoLinksRequest { string channel_id = 1; }
message ListRepoLinksResponse { repeated ChannelRepoLink links = 1; }
message CreateRepoLinkRequest {
string channel_id = 1;
string repo_id = 2;
string link_type = 3;
repeated string events = 4;
}
message CreateRepoLinkResponse { ChannelRepoLink link = 1; }
message DeleteRepoLinkRequest { string link_id = 1; }
message DeleteRepoLinkResponse {}
service ChannelRepoLinkService {
rpc ListRepoLinks(ListRepoLinksRequest) returns (ListRepoLinksResponse);
rpc CreateRepoLink(CreateRepoLinkRequest) returns (CreateRepoLinkResponse);
rpc DeleteRepoLink(DeleteRepoLinkRequest) returns (DeleteRepoLinkResponse);
}
message ImIntegration {
string id = 1;
string channel_id = 2;
string provider = 3;
string external_channel_id = 4;
string sync_direction = 5;
bool active = 6;
google.protobuf.Timestamp created_at = 7;
google.protobuf.Timestamp updated_at = 8;
}
message ListIntegrationsRequest { string channel_id = 1; }
message ListIntegrationsResponse { repeated ImIntegration integrations = 1; }
message CreateIntegrationRequest {
string channel_id = 1;
string provider = 2;
string external_channel_id = 3;
string sync_direction = 4;
}
message CreateIntegrationResponse { ImIntegration integration = 1; }
message UpdateIntegrationRequest {
string integration_id = 1;
optional string sync_direction = 2;
optional bool active = 3;
}
message UpdateIntegrationResponse { ImIntegration integration = 1; }
message DeleteIntegrationRequest { string integration_id = 1; }
message DeleteIntegrationResponse {}
service ImIntegrationService {
rpc ListIntegrations(ListIntegrationsRequest) returns (ListIntegrationsResponse);
rpc CreateIntegration(CreateIntegrationRequest) returns (CreateIntegrationResponse);
rpc UpdateIntegration(UpdateIntegrationRequest) returns (UpdateIntegrationResponse);
rpc DeleteIntegration(DeleteIntegrationRequest) returns (DeleteIntegrationResponse);
}
message CustomEmoji {
string id = 1;
string workspace_id = 2;
string name = 3;
string image_url = 4;
google.protobuf.Timestamp created_at = 5;
}
message ListCustomEmojisRequest { string workspace_id = 1; }
message ListCustomEmojisResponse { repeated CustomEmoji emojis = 1; }
message CreateCustomEmojiRequest {
string workspace_id = 1;
string name = 2;
string image_url = 3;
}
message CreateCustomEmojiResponse { CustomEmoji emoji = 1; }
message DeleteCustomEmojiRequest { string emoji_id = 1; }
message DeleteCustomEmojiResponse {}
service CustomEmojiService {
rpc ListCustomEmojis(ListCustomEmojisRequest) returns (ListCustomEmojisResponse);
rpc CreateCustomEmoji(CreateCustomEmojiRequest) returns (CreateCustomEmojiResponse);
rpc DeleteCustomEmoji(DeleteCustomEmojiRequest) returns (DeleteCustomEmojiResponse);
}
message ForumTag {
string id = 1;
string channel_id = 2;
string name = 3;
bool moderated = 4;
int32 position = 5;
google.protobuf.Timestamp created_at = 6;
google.protobuf.Timestamp updated_at = 7;
}
message ListForumTagsRequest { string channel_id = 1; }
message ListForumTagsResponse { repeated ForumTag tags = 1; }
message CreateForumTagRequest {
string channel_id = 1;
string name = 2;
bool moderated = 3;
optional int32 position = 4;
}
message CreateForumTagResponse { ForumTag tag = 1; }
message UpdateForumTagRequest {
string tag_id = 1;
optional string name = 2;
optional bool moderated = 3;
optional int32 position = 4;
}
message UpdateForumTagResponse { ForumTag tag = 1; }
message DeleteForumTagRequest { string tag_id = 1; }
message DeleteForumTagResponse {}
service ForumTagService {
rpc ListForumTags(ListForumTagsRequest) returns (ListForumTagsResponse);
rpc CreateForumTag(CreateForumTagRequest) returns (CreateForumTagResponse);
rpc UpdateForumTag(UpdateForumTagRequest) returns (UpdateForumTagResponse);
rpc DeleteForumTag(DeleteForumTagRequest) returns (DeleteForumTagResponse);
}
message VoiceParticipant {
string id = 1;
string channel_id = 2;
string user_id = 3;
bool muted = 4;
bool deafened = 5;
google.protobuf.Timestamp joined_at = 6;
}
message ListVoiceParticipantsRequest { string channel_id = 1; }
message ListVoiceParticipantsResponse { repeated VoiceParticipant participants = 1; }
message UpdateVoiceStateRequest {
string channel_id = 1;
string user_id = 2;
optional bool muted = 3;
optional bool deafened = 4;
}
message UpdateVoiceStateResponse { VoiceParticipant participant = 1; }
service VoiceService {
rpc ListVoiceParticipants(ListVoiceParticipantsRequest) returns (ListVoiceParticipantsResponse);
rpc UpdateVoiceState(UpdateVoiceStateRequest) returns (UpdateVoiceStateResponse);
}
message Stage {
string id = 1;
string channel_id = 2;
string topic = 3;
string privacy_level = 4;
bool discoverable = 5;
google.protobuf.Timestamp started_at = 6;
google.protobuf.Timestamp ended_at = 7;
google.protobuf.Timestamp created_at = 8;
google.protobuf.Timestamp updated_at = 9;
}
message GetStageRequest { string channel_id = 1; }
message GetStageResponse { Stage stage = 1; }
message CreateStageRequest {
string channel_id = 1;
string topic = 2;
string privacy_level = 3;
bool discoverable = 4;
}
message CreateStageResponse { Stage stage = 1; }
message UpdateStageRequest {
string stage_id = 1;
optional string topic = 2;
optional string privacy_level = 3;
optional bool discoverable = 4;
}
message UpdateStageResponse { Stage stage = 1; }
message DeleteStageRequest { string stage_id = 1; }
message DeleteStageResponse {}
service StageService {
rpc GetStage(GetStageRequest) returns (GetStageResponse);
rpc CreateStage(CreateStageRequest) returns (CreateStageResponse);
rpc UpdateStage(UpdateStageRequest) returns (UpdateStageResponse);
rpc DeleteStage(DeleteStageRequest) returns (DeleteStageResponse);
}
message ChannelAuditEvent {
string id = 1;
string channel_id = 2;
string actor_id = 3;
string event_type = 4;
string target_type = 5;
string target_id = 6;
optional string old_value = 7;
optional string new_value = 8;
google.protobuf.Timestamp created_at = 9;
}
message ListChannelEventsRequest {
string channel_id = 1;
int32 limit = 2;
int32 offset = 3;
}
message ListChannelEventsResponse {
repeated ChannelAuditEvent events = 1;
int32 total = 2;
}
service ChannelAuditService {
rpc ListChannelEvents(ListChannelEventsRequest) returns (ListChannelEventsResponse);
}
+126
View File
@@ -0,0 +1,126 @@
syntax = "proto3";
package appks.im.v1;
import "google/protobuf/timestamp.proto";
// Member management service for the IM microservice.
// Provides CRUD for channel members, join/leave, and membership checks.
enum Role {
ROLE_UNSPECIFIED = 0;
ROLE_OWNER = 1;
ROLE_ADMIN = 2;
ROLE_MAINTAINER = 3;
ROLE_MODERATOR = 4;
ROLE_MEMBER = 5;
ROLE_CONTRIBUTOR = 6;
ROLE_VIEWER = 7;
ROLE_GUEST = 8;
ROLE_BOT = 9;
}
enum MemberStatus {
MEMBER_STATUS_UNSPECIFIED = 0;
MEMBER_STATUS_ACTIVE = 1;
MEMBER_STATUS_INVITED = 2;
MEMBER_STATUS_LEFT = 3;
MEMBER_STATUS_KICKED = 4;
MEMBER_STATUS_BANNED = 5;
}
message ChannelMember {
string id = 1;
string channel_id = 2;
string user_id = 3;
string role = 4;
string status = 5;
bool muted = 6;
bool pinned = 7;
optional string last_read_message_id = 8;
optional google.protobuf.Timestamp last_read_at = 9;
optional google.protobuf.Timestamp joined_at = 10;
optional google.protobuf.Timestamp left_at = 11;
google.protobuf.Timestamp created_at = 12;
google.protobuf.Timestamp updated_at = 13;
}
message ListMembersRequest {
string channel_id = 1;
optional string status = 2;
int32 limit = 3;
int32 offset = 4;
}
message ListMembersResponse {
repeated ChannelMember members = 1;
int32 total = 2;
}
message InviteMemberRequest {
string channel_id = 1;
string user_id = 2;
optional string role = 3;
}
message InviteMemberResponse {
ChannelMember member = 1;
}
message UpdateMemberRequest {
string channel_id = 1;
string user_id = 2;
optional string role = 3;
optional bool muted = 4;
optional bool pinned = 5;
}
message UpdateMemberResponse {
ChannelMember member = 1;
}
message KickMemberRequest {
string channel_id = 1;
string user_id = 2;
}
message KickMemberResponse {}
message JoinChannelRequest {
string channel_id = 1;
string user_id = 2;
}
message JoinChannelResponse {
ChannelMember member = 1;
}
message LeaveChannelRequest {
string channel_id = 1;
string user_id = 2;
}
message LeaveChannelResponse {}
message IsMemberRequest {
string channel_id = 1;
string user_id = 2;
}
message IsMemberResponse {
bool is_member = 1;
string role = 2;
}
service MemberService {
rpc ListMembers(ListMembersRequest) returns (ListMembersResponse);
rpc InviteMember(InviteMemberRequest) returns (InviteMemberResponse);
rpc UpdateMember(UpdateMemberRequest) returns (UpdateMemberResponse);
rpc KickMember(KickMemberRequest) returns (KickMemberResponse);
rpc JoinChannel(JoinChannelRequest) returns (JoinChannelResponse);
rpc LeaveChannel(LeaveChannelRequest) returns (LeaveChannelResponse);
rpc IsMember(IsMemberRequest) returns (IsMemberResponse);
}
+125
View File
@@ -0,0 +1,125 @@
syntax = "proto3";
package appks.im.v1;
// IM-specific permissions for channel operations.
// Separate from the general Permission enum used for repo/workspace access.
enum ImPermission {
IM_PERMISSION_UNSPECIFIED = 0;
IM_PERMISSION_READ_CHANNEL = 1;
IM_PERMISSION_SEND_MESSAGE = 2;
IM_PERMISSION_MANAGE_THREADS = 3;
IM_PERMISSION_MANAGE_REACTIONS = 4;
IM_PERMISSION_MANAGE_PINS = 5;
IM_PERMISSION_INVITE_MEMBERS = 6;
IM_PERMISSION_KICK_MEMBERS = 7;
IM_PERMISSION_MANAGE_CHANNEL = 8;
IM_PERMISSION_MANAGE_ROLES = 9;
IM_PERMISSION_MANAGE_WEBHOOKS = 10;
IM_PERMISSION_MANAGE_EMOJIS = 11;
IM_PERMISSION_VIEW_AUDIT_LOG = 12;
IM_PERMISSION_MANAGE_INTEGRATIONS = 13;
IM_PERMISSION_SEND_TTS = 14;
IM_PERMISSION_USE_SLASH_COMMANDS = 15;
IM_PERMISSION_ATTACH_FILES = 16;
IM_PERMISSION_MENTION_EVERYONE = 17;
IM_PERMISSION_MANAGE_MESSAGES = 18;
IM_PERMISSION_ADMIN = 19;
}
message PermissionOverwrite {
string id = 1;
string channel_id = 2;
string target_type = 3;
string target_id = 4;
repeated ImPermission allow = 5;
repeated ImPermission deny = 6;
string created_at = 7;
string updated_at = 8;
}
message CheckPermissionRequest {
string channel_id = 1;
string user_id = 2;
ImPermission permission = 3;
}
message CheckPermissionResponse {
bool allowed = 1;
string role = 2;
}
message GetPermissionsRequest {
string channel_id = 1;
string user_id = 2;
}
message GetPermissionsResponse {
repeated ImPermission permissions = 1;
string role = 2;
}
message SetPermissionOverwriteRequest {
string channel_id = 1;
string target_type = 2;
string target_id = 3;
repeated ImPermission allow = 4;
repeated ImPermission deny = 5;
}
message SetPermissionOverwriteResponse {
PermissionOverwrite overwrite = 1;
}
message GetPermissionOverwritesRequest {
string channel_id = 1;
}
message GetPermissionOverwritesResponse {
repeated PermissionOverwrite overwrites = 1;
}
message DeletePermissionOverwriteRequest {
string channel_id = 1;
string target_type = 2;
string target_id = 3;
}
message DeletePermissionOverwriteResponse {}
message ResolveChannelRequest {
string channel_id = 1;
}
message ResolveChannelResponse {
string channel_id = 1;
string workspace_id = 2;
string name = 3;
string visibility = 4;
string channel_type = 5;
bool read_only = 6;
bool archived = 7;
optional string created_by = 8;
}
message EnsureReadableRequest {
string channel_id = 1;
string user_id = 2;
}
message EnsureReadableResponse {
bool allowed = 1;
}
service PermissionService {
rpc CheckPermission(CheckPermissionRequest) returns (CheckPermissionResponse);
rpc GetPermissions(GetPermissionsRequest) returns (GetPermissionsResponse);
rpc SetPermissionOverwrite(SetPermissionOverwriteRequest) returns (SetPermissionOverwriteResponse);
rpc GetPermissionOverwrites(GetPermissionOverwritesRequest) returns (GetPermissionOverwritesResponse);
rpc DeletePermissionOverwrite(DeletePermissionOverwriteRequest) returns (DeletePermissionOverwriteResponse);
rpc ResolveChannel(ResolveChannelRequest) returns (ResolveChannelResponse);
rpc EnsureReadable(EnsureReadableRequest) returns (EnsureReadableResponse);
}
+199
View File
@@ -0,0 +1,199 @@
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use uuid::Uuid;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, SocketInfo};
use crate::socket::packet::Packet;
pub struct LocalAdapter {
server_id: String,
rooms: Arc<DashMap<String, HashSet<String>>>,
socket_rooms: Arc<DashMap<String, HashSet<String>>>,
/// socket_sid → engine_sid
pub socket_sids: Arc<DashMap<String, String>>,
/// socket_sid → namespace path
socket_namespace: Arc<DashMap<String, String>>,
send_fn: Arc<dyn Fn(&str, &Packet) -> Result<(), String> + Send + Sync>,
}
impl LocalAdapter {
pub fn new(
send_fn: impl Fn(&str, &Packet) -> Result<(), String> + Send + Sync + 'static,
) -> Self {
Self {
server_id: Uuid::new_v4().to_string(),
rooms: Arc::new(DashMap::new()),
socket_rooms: Arc::new(DashMap::new()),
socket_sids: Arc::new(DashMap::new()),
socket_namespace: Arc::new(DashMap::new()),
send_fn: Arc::new(send_fn),
}
}
fn room_key(ns: &str, room: &str) -> String {
format!("{}:{}", ns, room)
}
/// Collect socket SIDs matching the broadcast options, scoped to the given namespace.
fn collect_matching_sids(&self, opts: &BroadcastOptions, namespace: &str) -> Vec<String> {
if opts.rooms.is_empty() {
// Broadcast to all sockets in this namespace only
self.socket_sids
.iter()
.filter(|e| {
self.socket_namespace
.get(e.key())
.map(|ns| ns.value() == namespace)
.unwrap_or(false)
})
.map(|e| e.key().clone())
.collect()
} else {
let mut sids = HashSet::new();
for room in &opts.rooms {
let key = Self::room_key(namespace, room);
if let Some(entry) = self.rooms.get(&key) {
for sid in entry.value() {
sids.insert(sid.clone());
}
}
}
sids.into_iter().collect()
}
}
}
#[async_trait]
impl Adapter for LocalAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
let namespace = &packet.namespace;
let sids = self.collect_matching_sids(opts, namespace);
for sid in &sids {
if opts.except.contains(sid) {
continue;
}
// socket_sids maps socket SID -> engine SID
if let Some(entry) = self.socket_sids.get(sid) {
let engine_sid = entry.value();
let result = (self.send_fn)(engine_sid, packet);
if let Err(e) = result {
tracing::warn!("Failed to broadcast to {}: {}", sid, e);
}
}
}
Ok(())
}
async fn register(&self, socket_sid: &str, engine_sid: &str, ns: &str) -> Result<(), AdapterError> {
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
self.socket_namespace.insert(socket_sid.to_string(), ns.to_string());
Ok(())
}
async fn unregister(&self, socket_sid: &str, ns: &str) -> Result<(), AdapterError> {
self.del_all(socket_sid, ns).await
}
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let key = Self::room_key(ns, room);
self.rooms.entry(key).or_insert_with(HashSet::new).value_mut().insert(sid.to_string());
self.socket_rooms.entry(sid.to_string()).or_insert_with(HashSet::new).value_mut().insert(room.to_string());
Ok(())
}
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let key = Self::room_key(ns, room);
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
room_sids.value_mut().remove(sid);
if room_sids.value_mut().is_empty() {
drop(room_sids);
self.rooms.remove(&key);
}
}
if let Some(mut rooms) = self.socket_rooms.get_mut(sid) {
rooms.value_mut().remove(room);
if rooms.value_mut().is_empty() {
drop(rooms);
self.socket_rooms.remove(sid);
}
}
Ok(())
}
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
for room in &rooms {
let key = Self::room_key(ns, room);
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
room_sids.value_mut().remove(sid);
if room_sids.value_mut().is_empty() {
drop(room_sids);
self.rooms.remove(&key);
}
}
}
}
self.socket_sids.remove(sid);
Ok(())
}
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
// fetch_sockets needs namespace context; use an empty namespace to match all
// (this method is typically called for inspection, not delivery)
let sids: Vec<String> = if opts.rooms.is_empty() {
self.socket_sids.iter().map(|e| e.key().clone()).collect()
} else {
let mut sids_set = HashSet::new();
for room in &opts.rooms {
for entry in self.rooms.iter() {
if entry.key().ends_with(&format!(":{}", room)) {
for sid in entry.value() {
sids_set.insert(sid.clone());
}
}
}
}
sids_set.into_iter().collect()
};
let mut result = Vec::new();
for sid in &sids {
if opts.except.contains(sid) {
continue;
}
if self.socket_sids.contains_key(sid) {
let namespace = self.socket_namespace
.get(sid)
.map(|r| r.value().clone())
.unwrap_or_default();
let rooms = self.socket_rooms
.get(sid)
.map(|r| r.value().clone())
.unwrap_or_default();
result.push(SocketInfo {
sid: sid.clone(),
namespace,
rooms,
});
}
}
Ok(result)
}
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms
.get(sid)
.map(|r| r.value().clone())
.unwrap_or_default())
}
fn server_id(&self) -> &str {
&self.server_id
}
async fn close(&self) -> Result<(), AdapterError> {
Ok(())
}
}
+99
View File
@@ -0,0 +1,99 @@
pub mod local;
pub mod redis;
pub mod nats;
use std::collections::HashSet;
use async_trait::async_trait;
use thiserror::Error;
use crate::socket::packet::Packet;
#[derive(Error, Debug)]
pub enum AdapterError {
#[error("Redis error: {0}")]
Redis(String),
#[error("NATS error: {0}")]
Nats(String),
#[error("Message bus error: {0}")]
MessageBus(String),
#[error("Serialization error: {0}")]
Serialization(String),
#[error("Room error: {0}")]
Room(String),
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct BroadcastOptions {
pub rooms: HashSet<String>,
pub except: HashSet<String>,
pub flags: BroadcastFlags,
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct BroadcastFlags {
pub local_only: bool,
pub broadcast: bool,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SocketInfo {
pub sid: String,
pub namespace: String,
pub rooms: HashSet<String>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
pub enum BusMessage {
Broadcast {
namespace: String,
packet: String,
opts: BroadcastOptions,
server_id: String,
},
SocketJoin {
namespace: String,
sid: String,
room: String,
server_id: String,
},
SocketLeave {
namespace: String,
sid: String,
room: String,
server_id: String,
},
SocketDisconnect {
namespace: String,
sid: String,
server_id: String,
},
}
#[async_trait]
pub trait Adapter: Send + Sync + 'static {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError>;
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError>;
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError>;
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError>;
fn server_id(&self) -> &str;
async fn close(&self) -> Result<(), AdapterError>;
/// Register a socket SID → engine SID mapping in the adapter.
/// Must be called when a socket first connects, before any room operations.
/// The `ns` parameter is the namespace path this socket belongs to.
async fn register(&self, _socket_sid: &str, _engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
Ok(())
}
/// Unregister a socket from the adapter, removing all local mappings.
async fn unregister(&self, _socket_sid: &str, _ns: &str) -> Result<(), AdapterError> {
Ok(())
}
}
pub use local::LocalAdapter;
pub use redis::RedisAdapter;
pub use nats::NatsAdapter;
+302
View File
@@ -0,0 +1,302 @@
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::mpsc;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
use crate::socket::message_bus::MessageBus;
use crate::socket::packet::Packet;
use crate::socket::parser;
use crate::socket::socket::Socket;
/// Handle incoming bus messages from other servers.
/// Only performs local dispatch — no remote state writes needed.
async fn handle_bus_message(
msg: BusMessage,
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
server_id: &str,
) {
match msg {
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
if sender_id == server_id {
return;
}
if let Ok(decoded_packet) = parser::decode(&packet) {
on_local_broadcast(&decoded_packet, &opts);
}
}
// NATS adapter manages room state locally; cross-server join/leave/disconnect
// are informational only and don't require duplicate state writes.
BusMessage::SocketJoin { server_id: sender_id, .. }
| BusMessage::SocketLeave { server_id: sender_id, .. }
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
if sender_id == server_id {
return;
}
}
}
}
/// NATS-based adapter that manages room state locally and uses NATS
/// for cross-server broadcast only. Does NOT depend on Redis.
pub struct NatsAdapter {
message_bus: Arc<dyn MessageBus>,
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
socket_rooms: DashMap<String, HashSet<String>>,
rooms: DashMap<String, HashSet<String>>,
/// socket_sid → engine_sid mapping for local delivery
socket_sids: DashMap<String, String>,
sockets: DashMap<String, Arc<Socket>>,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
}
impl NatsAdapter {
pub fn new(
message_bus: Arc<dyn MessageBus>,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
) -> Self {
Self {
message_bus,
server_id,
namespace,
on_local_broadcast,
room_subscribers: DashMap::new(),
socket_rooms: DashMap::new(),
rooms: DashMap::new(),
socket_sids: DashMap::new(),
sockets: DashMap::new(),
}
}
pub async fn init(&self) -> Result<(), AdapterError> {
let channels = ["broadcast", "join", "leave", "disconnect"];
let prefix = format!("socket.io:{}:", self.namespace);
for channel_type in channels {
let subject = format!("{}{}", prefix, channel_type);
match self.message_bus.subscribe(&subject).await {
Ok(rx) => {
self.room_subscribers.insert(channel_type.to_string(), rx);
}
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
}
}
self.spawn_listener();
Ok(())
}
fn spawn_listener(&self) {
let server_id = self.server_id.clone();
let on_local_broadcast = self.on_local_broadcast.clone();
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
tokio::spawn(async move {
loop {
tokio::select! {
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { join_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
else => break,
}
}
});
}
}
#[async_trait]
impl Adapter for NatsAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
if opts.flags.local_only {
(self.on_local_broadcast)(packet, opts);
return Ok(());
}
let msg = BusMessage::Broadcast {
namespace: self.namespace.clone(),
packet: parser::encode(packet),
opts: opts.clone(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:broadcast", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
(self.on_local_broadcast)(packet, opts);
Ok(())
}
async fn register(&self, socket_sid: &str, engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
Ok(())
}
async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
self.socket_rooms
.entry(sid.to_string())
.and_modify(|set| { set.insert(room.to_string()); })
.or_insert_with(|| HashSet::from([room.to_string()]));
self.rooms
.entry(room.to_string())
.and_modify(|set| { set.insert(sid.to_string()); })
.or_insert_with(|| HashSet::from([sid.to_string()]));
let msg = BusMessage::SocketJoin {
namespace: self.namespace.clone(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:join", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
entry.value_mut().remove(room);
}
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
self.socket_rooms.remove(sid);
}
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
let msg = BusMessage::SocketLeave {
namespace: self.namespace.clone(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:leave", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del_all(&self, sid: &str, _ns: &str) -> Result<(), AdapterError> {
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
for room in &rooms {
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
}
}
self.socket_sids.remove(sid);
self.sockets.remove(sid);
let msg = BusMessage::SocketDisconnect {
namespace: self.namespace.clone(),
sid: sid.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:disconnect", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
let mut result = Vec::new();
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
self.socket_sids.iter().map(|e| e.key().clone()).collect()
} else {
let mut sids = HashSet::new();
for room in &opts.rooms {
if let Some(entry) = self.rooms.get(room) {
sids.extend(entry.value().iter().cloned());
}
}
sids
};
for sid in target_sids {
if opts.except.contains(&sid) {
continue;
}
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
result.push(SocketInfo {
sid: sid.clone(),
namespace: self.namespace.clone(),
rooms,
});
}
Ok(result)
}
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
}
fn server_id(&self) -> &str {
&self.server_id
}
async fn close(&self) -> Result<(), AdapterError> {
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
}
+344
View File
@@ -0,0 +1,344 @@
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use fred::clients::Client;
use fred::interfaces::{KeysInterface, SetsInterface};
use tokio::sync::mpsc;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
use crate::socket::message_bus::MessageBus;
use crate::socket::packet::Packet;
use crate::socket::parser;
use crate::socket::socket::Socket;
const KEY_PREFIX_ROOMS: &str = "socket.io:rooms";
const KEY_PREFIX_SOCKET_ROOMS: &str = "socket.io:socket_rooms";
fn room_key(ns: &str, room: &str) -> String {
format!("{}:{}:{}", KEY_PREFIX_ROOMS, ns, room)
}
fn socket_rooms_key(ns: &str, sid: &str) -> String {
format!("{}:{}:{}", KEY_PREFIX_SOCKET_ROOMS, ns, sid)
}
/// Handle incoming bus messages from other servers.
/// Only performs local state updates — the remote server already wrote to Redis.
async fn handle_bus_message(
msg: BusMessage,
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
server_id: &str,
) {
match msg {
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
if sender_id == server_id {
return;
}
if let Ok(decoded_packet) = parser::decode(&packet) {
on_local_broadcast(&decoded_packet, &opts);
}
}
BusMessage::SocketJoin { server_id: sender_id, .. }
| BusMessage::SocketLeave { server_id: sender_id, .. }
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
// Skip messages from this server; remote server already updated Redis
if sender_id == server_id {
return;
}
// No duplicate Redis writes — the sender already persisted the state change
}
}
}
pub struct RedisAdapter {
message_bus: Arc<dyn MessageBus>,
redis_client: Client,
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
socket_rooms: DashMap<String, HashSet<String>>,
rooms: DashMap<String, HashSet<String>>,
sockets: DashMap<String, Arc<Socket>>,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
}
impl RedisAdapter {
pub fn new(
message_bus: Arc<dyn MessageBus>,
redis_client: Client,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
) -> Self {
Self {
message_bus,
redis_client,
server_id,
namespace,
on_local_broadcast,
room_subscribers: DashMap::new(),
socket_rooms: DashMap::new(),
rooms: DashMap::new(),
sockets: DashMap::new(),
}
}
pub async fn init(&self) -> Result<(), AdapterError> {
let channels = ["broadcast", "join", "leave", "disconnect"];
let prefix = format!("socket.io:{}:", self.namespace);
for channel_type in channels {
let channel = format!("{}{}", prefix, channel_type);
match self.message_bus.subscribe(&channel).await {
Ok(rx) => {
self.room_subscribers.insert(channel_type.to_string(), rx);
}
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
}
}
self.spawn_listener();
Ok(())
}
fn spawn_listener(&self) {
let server_id = self.server_id.clone();
let on_local_broadcast = self.on_local_broadcast.clone();
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
tokio::spawn(async move {
loop {
tokio::select! {
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { join_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
else => break,
}
}
});
}
}
#[async_trait]
impl Adapter for RedisAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
if opts.flags.local_only {
(self.on_local_broadcast)(packet, opts);
return Ok(());
}
let msg = BusMessage::Broadcast {
namespace: packet.namespace.clone(),
packet: parser::encode(packet),
opts: opts.clone(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:broadcast", packet.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
(self.on_local_broadcast)(packet, opts);
Ok(())
}
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let rk = room_key(ns, room);
let srk = socket_rooms_key(ns, sid);
self.redis_client
.sadd::<(), _, _>(&rk, sid)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.redis_client
.sadd::<(), _, _>(&srk, room)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.socket_rooms
.entry(sid.to_string())
.and_modify(|set| { set.insert(room.to_string()); })
.or_insert_with(|| HashSet::from([room.to_string()]));
self.rooms
.entry(room.to_string())
.and_modify(|set| { set.insert(sid.to_string()); })
.or_insert_with(|| HashSet::from([sid.to_string()]));
let msg = BusMessage::SocketJoin {
namespace: ns.to_string(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:join", ns), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let rk = room_key(ns, room);
let srk = socket_rooms_key(ns, sid);
self.redis_client
.srem::<(), _, _>(&rk, sid)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.redis_client
.srem::<(), _, _>(&srk, room)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
entry.value_mut().remove(room);
}
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
self.socket_rooms.remove(sid);
}
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
let msg = BusMessage::SocketLeave {
namespace: ns.to_string(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:leave", ns), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
for room in &rooms {
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
let rk = room_key(ns, room);
if let Err(e) = self.redis_client.srem::<(), _, _>(&rk, sid).await {
tracing::warn!("Redis SREM room error: {}", e);
}
}
}
let srk = socket_rooms_key(ns, sid);
self.redis_client
.del::<(), _>(&srk)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.sockets.remove(sid);
let msg = BusMessage::SocketDisconnect {
namespace: ns.to_string(),
sid: sid.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:disconnect", ns), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
let mut result = Vec::new();
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
self.sockets.iter().map(|e| e.key().clone()).collect()
} else {
let mut sids = HashSet::new();
for room in &opts.rooms {
if let Some(entry) = self.rooms.get(room) {
sids.extend(entry.value().iter().cloned());
}
}
sids
};
for sid in target_sids {
if opts.except.contains(&sid) {
continue;
}
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
result.push(SocketInfo {
sid: sid.clone(),
namespace: self.namespace.clone(),
rooms,
});
}
Ok(result)
}
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
}
fn server_id(&self) -> &str {
&self.server_id
}
async fn close(&self) -> Result<(), AdapterError> {
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
}
+31
View File
@@ -0,0 +1,31 @@
pub mod redis;
pub mod nats;
use async_trait::async_trait;
use thiserror::Error;
use tokio::sync::mpsc;
#[derive(Error, Debug)]
pub enum MessageBusError {
#[error("Redis error: {0}")]
Redis(String),
#[error("NATS error: {0}")]
Nats(String),
#[error("Connection closed")]
ConnectionClosed,
#[error("Channel not found: {0}")]
ChannelNotFound(String),
#[error("Serialization error: {0}")]
Serialization(String),
}
#[async_trait]
pub trait MessageBus: Send + Sync + 'static {
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError>;
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError>;
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError>;
async fn close(&self) -> Result<(), MessageBusError>;
}
pub use redis::RedisMessageBus;
pub use nats::NatsMessageBus;
+88
View File
@@ -0,0 +1,88 @@
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::{mpsc, watch};
use crate::socket::message_bus::{MessageBus, MessageBusError};
pub struct NatsMessageBus {
client: async_nats::Client,
shutdowns: DashMap<String, watch::Sender<bool>>,
}
impl NatsMessageBus {
pub async fn new(nats_url: &str) -> Result<Self, MessageBusError> {
let client = async_nats::connect(nats_url)
.await
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
Ok(Self {
client,
shutdowns: DashMap::new(),
})
}
}
#[async_trait]
impl MessageBus for NatsMessageBus {
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
self.client
.publish(channel.to_string(), message.to_vec().into())
.await
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
Ok(())
}
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
let mut subscriber = self.client
.subscribe(channel.to_string())
.await
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
self.shutdowns.insert(channel.to_string(), shutdown_tx);
tokio::spawn(async move {
use futures_util::StreamExt;
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
break;
}
message = subscriber.next() => {
match message {
Some(msg) => {
let data = msg.payload.to_vec();
if tx.send(data).await.is_err() {
break;
}
}
None => break,
}
}
}
}
if let Err(e) = subscriber.unsubscribe().await {
tracing::warn!("NATS unsubscribe error: {}", e);
}
});
Ok(rx)
}
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
if let Some((_, tx)) = self.shutdowns.remove(channel) {
let _ = tx.send(true);
}
Ok(())
}
async fn close(&self) -> Result<(), MessageBusError> {
// Signal all subscribers to shutdown
self.shutdowns.iter().for_each(|entry| {
let _ = entry.value().send(true);
});
self.shutdowns.clear();
Ok(())
}
}
+99
View File
@@ -0,0 +1,99 @@
use async_trait::async_trait;
use fred::clients::{Client, SubscriberClient};
use fred::interfaces::{ClientLike, EventInterface, PubsubInterface};
use fred::prelude::*;
use tokio::sync::mpsc;
use crate::socket::message_bus::{MessageBus, MessageBusError};
pub struct RedisMessageBus {
client: Client,
subscriber: SubscriberClient,
}
impl RedisMessageBus {
pub async fn new(redis_url: &str) -> Result<Self, MessageBusError> {
let config = Config::from_url(redis_url)
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
let client = Client::new(config.clone(), None, None, None);
let subscriber = SubscriberClient::new(config, None, None, None);
// connect() starts the connection task; result is checked by wait_for_connect()
let _ = client.connect().await;
let _ = subscriber.connect().await;
client
.wait_for_connect()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
subscriber
.wait_for_connect()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(Self { client, subscriber })
}
pub fn client(&self) -> &Client {
&self.client
}
}
#[async_trait]
impl MessageBus for RedisMessageBus {
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
self.client
.publish::<(), _, Vec<u8>>(channel, message.to_vec())
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(())
}
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
self.subscriber
.subscribe(channel.to_string())
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
let subscriber = self.subscriber.clone();
let channel_owned = channel.to_string();
let mut message_rx = subscriber.message_rx();
tokio::spawn(async move {
while let Ok(message) = message_rx.recv().await {
if &message.channel == &channel_owned {
let data: Vec<u8> = FromValue::from_value(message.value)
.unwrap_or_default();
if tx.send(data).await.is_err() {
break;
}
}
}
});
Ok(rx)
}
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
self.subscriber
.unsubscribe(channel.to_string())
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(())
}
async fn close(&self) -> Result<(), MessageBusError> {
self.client
.quit()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
self.subscriber
.quit()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(())
}
}
+16
View File
@@ -0,0 +1,16 @@
pub mod adapter;
pub mod message_bus;
pub mod namespace;
pub mod packet;
pub mod parser;
pub mod server;
pub mod session_store;
pub mod socket;
pub use adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, RedisAdapter, NatsAdapter, SocketInfo};
pub use message_bus::{MessageBus, MessageBusError, RedisMessageBus, NatsMessageBus};
pub use namespace::{is_valid_namespace, Namespace, NamespaceManager};
pub use packet::{Packet, PacketType};
pub use server::{SocketServer, SocketServerBuilder};
pub use session_store::{InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait};
pub use socket::Socket;
+239
View File
@@ -0,0 +1,239 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::RwLock;
use crate::socket::adapter::{Adapter, BroadcastOptions, BroadcastFlags};
use crate::socket::packet::Packet;
use crate::socket::socket::Socket;
pub type EventHandler = Arc<dyn Fn(&Socket, &serde_json::Value) + Send + Sync>;
type ConnectHandler = Arc<dyn Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync>;
pub struct Namespace {
pub path: String,
/// Primary storage: socket_sid → Socket
sockets: DashMap<String, Arc<Socket>>,
/// Reverse index: engine_sid → socket_sid (for engine-level lookups)
engine_to_socket: DashMap<String, String>,
handlers: RwLock<HashMap<String, Vec<EventHandler>>>,
connect_handler: RwLock<Option<ConnectHandler>>,
pub(crate) adapter: RwLock<Option<Arc<dyn Adapter>>>,
}
impl Namespace {
pub fn new(path: impl Into<String>) -> Self {
Self {
path: path.into(),
sockets: DashMap::new(),
engine_to_socket: DashMap::new(),
handlers: RwLock::new(HashMap::new()),
connect_handler: RwLock::new(None),
adapter: RwLock::new(None),
}
}
pub async fn set_adapter(&self, adapter: Arc<dyn Adapter>) {
let mut guard = self.adapter.write().await;
*guard = Some(adapter);
}
/// Add a socket to this namespace. Returns Err if the connect handler rejects.
pub async fn add_socket(&self, socket: Arc<Socket>) -> Result<(), String> {
// Run connect handler before adding to storage
let handler = self.connect_handler.read().await;
if let Some(ref h) = *handler {
h(&socket, None)?;
}
drop(handler);
let socket_sid = socket.sid.clone();
let engine_sid = socket.engine_sid.clone();
// Register with adapter (socket_sid → engine_sid mapping)
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
if let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await {
tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e);
}
}
// Store socket by socket_sid, plus reverse index
self.sockets.insert(socket_sid.clone(), socket);
self.engine_to_socket.insert(engine_sid, socket_sid);
Ok(())
}
/// Remove a socket by its socket SID.
pub async fn remove_socket_by_sid(&self, socket_sid: &str) {
if let Some((_, socket)) = self.sockets.remove(socket_sid) {
self.engine_to_socket.remove(&socket.engine_sid);
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
if let Err(e) = adapter.del_all(socket_sid, &self.path).await {
tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e);
}
}
}
}
/// Remove a socket by its engine SID (for engine-level disconnections).
pub async fn remove_socket(&self, engine_sid: &str) {
if let Some((_, socket_sid)) = self.engine_to_socket.remove(engine_sid) {
self.remove_socket_by_sid(&socket_sid).await;
}
}
/// Look up a socket by its socket SID.
pub fn get_socket(&self, socket_sid: &str) -> Option<Arc<Socket>> {
self.sockets.get(socket_sid).map(|r| r.value().clone())
}
/// Look up a socket by its engine SID (reverse lookup).
pub fn get_socket_by_engine_sid(&self, engine_sid: &str) -> Option<Arc<Socket>> {
self.engine_to_socket
.get(engine_sid)
.and_then(|entry| self.sockets.get(entry.value()).map(|r| r.value().clone()))
}
pub fn socket_count(&self) -> usize {
self.sockets.len()
}
pub async fn on_event(&self, event: impl Into<String>, handler: EventHandler) {
let mut handlers = self.handlers.write().await;
handlers.entry(event.into()).or_default().push(handler);
}
pub async fn on_connect<F>(&self, handler: F)
where
F: Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync + 'static,
{
let mut connect_handler = self.connect_handler.write().await;
*connect_handler = Some(Arc::new(handler));
}
pub async fn emit(&self, event: impl Into<String>, data: serde_json::Value) {
let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
let opts = BroadcastOptions::default();
if let Err(e) = adapter.broadcast(&packet, &opts).await {
tracing::warn!("Adapter broadcast error: {}", e);
}
} else {
self.emit_local(&packet);
}
}
pub async fn emit_to_room(&self, room: &str, event: impl Into<String>, data: serde_json::Value) {
let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
let opts = BroadcastOptions {
rooms: HashSet::from([room.to_string()]),
except: HashSet::new(),
flags: BroadcastFlags::default(),
};
if let Err(e) = adapter.broadcast(&packet, &opts).await {
tracing::warn!("Adapter broadcast to room error: {}", e);
}
} else {
self.emit_local(&packet);
}
}
pub fn emit_local(&self, packet: &Packet) {
for entry in self.sockets.iter() {
let socket = entry.value();
if socket.send_packet(packet).is_err() {
tracing::warn!("Failed to send event to socket {}", socket.sid);
}
}
}
pub async fn emit_to(&self, socket_sid: &str, event: impl Into<String>, data: serde_json::Value) {
if let Some(socket) = self.get_socket(socket_sid) {
let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
if socket.send_packet(&packet).is_err() {
tracing::warn!("Failed to send event to socket {}", socket.sid);
}
}
}
pub async fn handle_event(&self, socket: &Socket, event: &str, data: &serde_json::Value) {
let handlers = self.handlers.read().await;
if let Some(event_handlers) = handlers.get(event) {
for handler in event_handlers {
handler(socket, data);
}
}
}
}
pub struct NamespaceManager {
namespaces: DashMap<String, Arc<Namespace>>,
}
impl NamespaceManager {
pub fn new() -> Self {
let manager = Self {
namespaces: DashMap::new(),
};
manager.create_namespace("/");
manager
}
pub fn create_namespace(&self, path: impl Into<String>) -> Arc<Namespace> {
let path = path.into();
let namespace = Arc::new(Namespace::new(&path));
self.namespaces.insert(path.clone(), namespace.clone());
namespace
}
pub fn get_namespace(&self, path: &str) -> Option<Arc<Namespace>> {
self.namespaces.get(path).map(|r| r.value().clone())
}
pub fn get_or_create_namespace(&self, path: &str) -> Arc<Namespace> {
if let Some(ns) = self.get_namespace(path) {
ns
} else {
self.create_namespace(path)
}
}
pub fn remove_namespace(&self, path: &str) {
self.namespaces.remove(path);
}
pub fn namespace_count(&self) -> usize {
self.namespaces.len()
}
pub fn all_namespaces(&self) -> Vec<Arc<Namespace>> {
self.namespaces.iter().map(|e| e.value().clone()).collect()
}
}
impl Default for NamespaceManager {
fn default() -> Self {
Self::new()
}
}
/// Validate a namespace path. Returns true if the path is valid.
/// Rules: must start with '/', max 256 chars, no control characters.
pub fn is_valid_namespace(path: &str) -> bool {
!path.is_empty()
&& path.starts_with('/')
&& path.len() <= 256
&& !path.chars().any(|c| c.is_control())
}
+174
View File
@@ -0,0 +1,174 @@
use serde_json::Value;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PacketType {
Connect = 0,
Disconnect = 1,
Event = 2,
Ack = 3,
ConnectError = 4,
BinaryEvent = 5,
BinaryAck = 6,
}
impl TryFrom<u8> for PacketType {
type Error = PacketError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::Connect),
1 => Ok(Self::Disconnect),
2 => Ok(Self::Event),
3 => Ok(Self::Ack),
4 => Ok(Self::ConnectError),
5 => Ok(Self::BinaryEvent),
6 => Ok(Self::BinaryAck),
_ => Err(PacketError::InvalidType(value)),
}
}
}
impl TryFrom<char> for PacketType {
type Error = PacketError;
fn try_from(value: char) -> Result<Self, Self::Error> {
match value {
'0' => Ok(Self::Connect),
'1' => Ok(Self::Disconnect),
'2' => Ok(Self::Event),
'3' => Ok(Self::Ack),
'4' => Ok(Self::ConnectError),
'5' => Ok(Self::BinaryEvent),
'6' => Ok(Self::BinaryAck),
_ => Err(PacketError::InvalidTypeChar(value)),
}
}
}
#[derive(Debug, Clone)]
pub struct Packet {
pub packet_type: PacketType,
pub namespace: String,
pub data: Option<Value>,
pub id: Option<u64>,
pub attachments: Vec<Vec<u8>>,
/// Expected number of binary attachments (set during decode for binary packets).
/// Used to validate attachment count before assembling the full packet.
pub expected_attachments: Option<usize>,
}
impl Packet {
pub fn connect(namespace: impl Into<String>, data: Option<Value>) -> Self {
Self {
packet_type: PacketType::Connect,
namespace: namespace.into(),
data,
id: None,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn disconnect(namespace: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Disconnect,
namespace: namespace.into(),
data: None,
id: None,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn event(namespace: impl Into<String>, data: Value, id: Option<u64>) -> Self {
Self {
packet_type: PacketType::Event,
namespace: namespace.into(),
data: Some(data),
id,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn ack(namespace: impl Into<String>, data: Value, id: u64) -> Self {
Self {
packet_type: PacketType::Ack,
namespace: namespace.into(),
data: Some(data),
id: Some(id),
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn connect_error(namespace: impl Into<String>, message: impl Into<String>) -> Self {
Self {
packet_type: PacketType::ConnectError,
namespace: namespace.into(),
data: Some(serde_json::json!({ "message": message.into() })),
id: None,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn binary_event(
namespace: impl Into<String>,
data: Value,
id: Option<u64>,
attachments: Vec<Vec<u8>>,
) -> Self {
Self {
packet_type: PacketType::BinaryEvent,
namespace: namespace.into(),
data: Some(data),
id,
attachments,
expected_attachments: None,
}
}
pub fn binary_ack(
namespace: impl Into<String>,
data: Value,
id: u64,
attachments: Vec<Vec<u8>>,
) -> Self {
Self {
packet_type: PacketType::BinaryAck,
namespace: namespace.into(),
data: Some(data),
id: Some(id),
attachments,
expected_attachments: None,
}
}
pub fn has_binary(&self) -> bool {
!self.attachments.is_empty()
}
pub fn attachment_count(&self) -> usize {
self.attachments.len()
}
}
#[derive(Debug, thiserror::Error)]
pub enum PacketError {
#[error("invalid packet type: {0}")]
InvalidType(u8),
#[error("invalid packet type char: {0}")]
InvalidTypeChar(char),
#[error("empty packet")]
Empty,
#[error("invalid format: {0}")]
InvalidFormat(String),
#[error("json error: {0}")]
Json(#[from] serde_json::Error),
#[error("missing namespace")]
MissingNamespace,
#[error("invalid attachment count")]
InvalidAttachmentCount,
}
+392
View File
@@ -0,0 +1,392 @@
use serde_json::Value;
use crate::socket::packet::{Packet, PacketError, PacketType};
pub fn encode(packet: &Packet) -> String {
let type_char = packet.packet_type as u8 + b'0';
let mut result = String::new();
result.push(type_char as char);
if packet.has_binary() {
result.push_str(&packet.attachment_count().to_string());
result.push('-');
}
if packet.namespace != "/" {
result.push_str(&packet.namespace);
result.push(',');
}
if let Some(id) = packet.id {
result.push_str(&id.to_string());
}
if let Some(ref data) = packet.data {
if packet.has_binary() {
let data_with_placeholders = replace_binary_with_placeholders(data, packet.attachment_count());
let encoded_data = serde_json::to_string(&data_with_placeholders)
.unwrap_or_else(|e| {
tracing::error!("Failed to serialize socket packet data: {}", e);
"null".to_string()
});
result.push_str(&encoded_data);
} else {
let encoded_data = serde_json::to_string(data)
.unwrap_or_else(|e| {
tracing::error!("Failed to serialize socket packet data: {}", e);
"null".to_string()
});
result.push_str(&encoded_data);
}
}
result
}
pub fn encode_with_attachments(packet: &Packet) -> Vec<Vec<u8>> {
let mut result = Vec::new();
let encoded = encode(packet);
result.push(encoded.into_bytes());
for attachment in &packet.attachments {
result.push(attachment.clone());
}
result
}
pub fn decode(input: &str) -> Result<Packet, PacketError> {
if input.is_empty() {
return Err(PacketError::Empty);
}
let mut chars = input.chars().peekable();
let type_char = chars.next().ok_or(PacketError::Empty)?;
let packet_type = PacketType::try_from(type_char)?;
let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck) {
let mut count_str = String::new();
while let Some(&c) = chars.peek() {
if c == '-' {
chars.next();
break;
}
if c.is_ascii_digit() {
count_str.push(c);
chars.next();
} else {
break;
}
}
count_str.parse::<usize>().unwrap_or(0)
} else {
0
};
let remaining: String = chars.collect();
let (namespace, rest) = if let Some(after_slash) = remaining.strip_prefix('/') {
// Check if this is a custom namespace (has a comma separating namespace from data/id)
// or if '/' is just the root namespace prefix followed immediately by data
if let Some(comma_pos) = after_slash.find(',') {
let ns = format!("/{}", &after_slash[..comma_pos]);
let rest = after_slash[comma_pos + 1..].to_string();
(ns, rest)
} else if after_slash.starts_with('[')
|| after_slash.starts_with(|c: char| c.is_ascii_digit())
|| after_slash.is_empty()
{
// '/[' means '/' is the root namespace and '[' starts the data
// '/<digits>' means root namespace followed by ack id
// '/' alone means disconnect on root namespace
("/".to_string(), after_slash.to_string())
} else {
// Non-root namespace without data (e.g., disconnect on custom namespace)
(remaining, String::new())
}
} else {
("/".to_string(), remaining)
};
let (id, data_str) = parse_id_and_data(&rest);
let data = if data_str.is_empty() {
None
} else {
Some(serde_json::from_str(&data_str)?)
};
Ok(Packet {
packet_type,
namespace,
data,
id,
attachments: Vec::new(),
// Store attachment_count for binary packets; actual attachments come via decode_with_attachments
expected_attachments: if attachment_count > 0 { Some(attachment_count) } else { None },
})
}
pub fn decode_with_attachments(
main_packet: &str,
attachments: Vec<Vec<u8>>,
) -> Result<Packet, PacketError> {
let mut packet = decode(main_packet)?;
let expected = packet.expected_attachments.unwrap_or(0);
if expected != attachments.len() {
return Err(PacketError::InvalidAttachmentCount);
}
packet.attachments = attachments;
packet.expected_attachments = None;
if packet.has_binary() {
if let Some(ref data) = packet.data {
packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments));
}
}
Ok(packet)
}
fn parse_id_and_data(input: &str) -> (Option<u64>, String) {
let mut id_str = String::new();
let mut chars = input.chars().peekable();
while let Some(&c) = chars.peek() {
if c.is_ascii_digit() {
id_str.push(c);
chars.next();
} else {
break;
}
}
let id = if id_str.is_empty() {
None
} else {
id_str.parse::<u64>().ok()
};
let data: String = chars.collect();
(id, data)
}
/// Replace binary values in the data with { "_placeholder": true, "num": N } placeholders.
/// This is used when encoding binary events/acks for transmission over text-based transports.
fn replace_binary_with_placeholders(value: &Value, total_attachments: usize) -> Value {
match value {
Value::Array(arr) => {
let mut placeholder_idx = total_attachments; // Start from known count
let new_arr: Vec<Value> = arr
.iter()
.map(|v| replace_binary_with_placeholders_inner(v, &mut placeholder_idx))
.collect();
Value::Array(new_arr)
}
Value::Object(map) => {
let mut placeholder_idx = total_attachments;
let mut new_map = serde_json::Map::new();
for (k, v) in map {
new_map.insert(
k.clone(),
replace_binary_with_placeholders_inner(v, &mut placeholder_idx),
);
}
Value::Object(new_map)
}
_ => value.clone(),
}
}
fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut usize) -> Value {
match value {
Value::Array(arr) => {
let new_arr: Vec<Value> = arr
.iter()
.map(|v| replace_binary_with_placeholders_inner(v, placeholder_idx))
.collect();
Value::Array(new_arr)
}
Value::Object(map) => {
let mut new_map = serde_json::Map::new();
for (k, v) in map {
new_map.insert(
k.clone(),
replace_binary_with_placeholders_inner(v, placeholder_idx),
);
}
Value::Object(new_map)
}
// Binary data would be represented as base64 strings in the initial data;
// in the Socket.IO protocol, binary attachments are separate and referenced by placeholder.
// This function handles the case where the data structure itself contains placeholder markers.
_ => value.clone(),
}
}
fn replace_placeholders_with_binary(value: &Value, attachments: &[Vec<u8>]) -> Value {
match value {
Value::Object(map) => {
// Check if this is a placeholder object: { "_placeholder": true, "num": N }
if let (Some(Value::Bool(true)), Some(Value::Number(num))) =
(map.get("_placeholder"), map.get("num"))
{
if let Some(idx) = num.as_u64() {
if let Some(attachment) = attachments.get(idx as usize) {
return Value::String(base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
attachment,
));
}
}
}
let mut new_map = serde_json::Map::new();
for (k, v) in map {
new_map.insert(k.clone(), replace_placeholders_with_binary(v, attachments));
}
Value::Object(new_map)
}
Value::Array(arr) => Value::Array(
arr.iter()
.map(|v| replace_placeholders_with_binary(v, attachments))
.collect(),
),
_ => value.clone(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_encode_connect() {
let packet = Packet::connect("/", None);
let encoded = encode(&packet);
assert_eq!(encoded, "0");
let packet = Packet::connect("/admin", Some(json!({"sid": "abc"})));
let encoded = encode(&packet);
assert_eq!(encoded, "0/admin,{\"sid\":\"abc\"}");
}
#[test]
fn test_encode_event() {
let packet = Packet::event("/", json!(["foo"]), None);
let encoded = encode(&packet);
assert_eq!(encoded, "2[\"foo\"]");
let packet = Packet::event("/admin", json!(["bar"]), None);
let encoded = encode(&packet);
assert_eq!(encoded, "2/admin,[\"bar\"]");
}
#[test]
fn test_encode_event_with_ack() {
let packet = Packet::event("/", json!(["foo"]), Some(12));
let encoded = encode(&packet);
assert_eq!(encoded, "212[\"foo\"]");
}
#[test]
fn test_encode_ack() {
let packet = Packet::ack("/", json!([]), 12);
let encoded = encode(&packet);
assert_eq!(encoded, "312[]");
}
#[test]
fn test_encode_disconnect() {
let packet = Packet::disconnect("/");
let encoded = encode(&packet);
assert_eq!(encoded, "1");
let packet = Packet::disconnect("/admin");
let encoded = encode(&packet);
assert_eq!(encoded, "1/admin,");
}
#[test]
fn test_encode_connect_error() {
let packet = Packet::connect_error("/", "Not authorized");
let encoded = encode(&packet);
assert_eq!(encoded, "4{\"message\":\"Not authorized\"}");
}
#[test]
fn test_decode_connect() {
let packet = decode("0").unwrap();
assert_eq!(packet.packet_type, PacketType::Connect);
assert_eq!(packet.namespace, "/");
assert!(packet.data.is_none());
}
#[test]
fn test_decode_connect_with_namespace() {
let packet = decode("0/admin,{\"sid\":\"abc\"}").unwrap();
assert_eq!(packet.packet_type, PacketType::Connect);
assert_eq!(packet.namespace, "/admin");
assert!(packet.data.is_some());
}
#[test]
fn test_decode_event() {
let packet = decode("2[\"foo\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.namespace, "/");
assert_eq!(packet.data, Some(json!(["foo"])));
}
#[test]
fn test_decode_event_with_namespace() {
let packet = decode("2/admin,[\"bar\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.namespace, "/admin");
assert_eq!(packet.data, Some(json!(["bar"])));
}
#[test]
fn test_decode_event_with_ack() {
let packet = decode("212[\"foo\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.id, Some(12));
assert_eq!(packet.data, Some(json!(["foo"])));
}
#[test]
fn test_decode_ack() {
let packet = decode("312[]").unwrap();
assert_eq!(packet.packet_type, PacketType::Ack);
assert_eq!(packet.id, Some(12));
}
#[test]
fn test_decode_disconnect() {
let packet = decode("1").unwrap();
assert_eq!(packet.packet_type, PacketType::Disconnect);
assert_eq!(packet.namespace, "/");
}
#[test]
fn test_decode_disconnect_with_namespace() {
let packet = decode("1/admin,").unwrap();
assert_eq!(packet.packet_type, PacketType::Disconnect);
assert_eq!(packet.namespace, "/admin");
}
#[test]
fn test_decode_binary_event_attachment_count() {
let packet = decode("51-[\"baz\",{\"_placeholder\":true,\"num\":0}]").unwrap();
assert_eq!(packet.packet_type, PacketType::BinaryEvent);
assert_eq!(packet.expected_attachments, Some(1));
assert_eq!(packet.namespace, "/");
}
}
+301
View File
@@ -0,0 +1,301 @@
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::mpsc;
use crate::engine::packet::Packet as EnginePacket;
use crate::engine::packet::PacketData as EnginePacketData;
use crate::engine::server::{EngineConfig, EngineServer};
use crate::engine::session::SessionStore;
use crate::socket::adapter::{Adapter, LocalAdapter};
use crate::socket::namespace::NamespaceManager;
use crate::socket::packet::{Packet, PacketType};
use crate::socket::parser;
use crate::socket::socket::Socket;
pub struct SocketServer {
pub engine: Arc<EngineServer>,
pub namespaces: Arc<NamespaceManager>,
pub adapter: Arc<dyn Adapter>,
socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>>,
}
impl SocketServer {
pub fn new(config: EngineConfig) -> Self {
SocketServerBuilder::new(config).build()
}
pub fn builder(config: EngineConfig) -> SocketServerBuilder {
SocketServerBuilder::new(config)
}
pub fn of(&self, path: impl Into<String>) -> Arc<crate::socket::namespace::Namespace> {
self.namespaces.get_or_create_namespace(&path.into())
}
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
self.engine.clone().run_http(addr).await
}
pub fn register_socket(&self, sid: String, tx: mpsc::Sender<Packet>) {
self.socket_txs.insert(sid, tx);
}
pub fn unregister_socket(&self, sid: &str) {
self.socket_txs.remove(sid);
}
}
pub struct SocketServerBuilder {
config: EngineConfig,
adapter: Option<Arc<dyn Adapter>>,
}
impl SocketServerBuilder {
pub fn new(config: EngineConfig) -> Self {
Self {
config,
adapter: None,
}
}
pub fn adapter(mut self, adapter: Arc<dyn Adapter>) -> Self {
self.adapter = Some(adapter);
self
}
pub fn build(self) -> SocketServer {
let namespaces = Arc::new(NamespaceManager::new());
let socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>> = Arc::new(DashMap::new());
let engine_store = SessionStore::new();
let namespaces_clone = namespaces.clone();
let socket_txs_clone = socket_txs.clone();
let engine_store_clone = engine_store.clone();
let adapter: Arc<dyn Adapter> = self.adapter.unwrap_or_else(|| {
let ns_clone = namespaces.clone();
let send_fn = move |engine_sid: &str, packet: &Packet| {
if let Some(ns) = ns_clone.get_namespace(&packet.namespace) {
if let Some(socket) = ns.get_socket_by_engine_sid(engine_sid) {
socket.send_packet(packet).map_err(|e| e.to_string())
} else {
Err(format!(
"Socket with engine_sid {} not found in namespace {}",
engine_sid, packet.namespace
))
}
} else {
Err(format!("Namespace {} not found", packet.namespace))
}
};
Arc::new(LocalAdapter::new(send_fn))
});
let adapter_clone = adapter.clone();
let engine = Arc::new(EngineServer::with_store(
self.config,
engine_store,
move |sid, engine_packet| {
let namespaces = namespaces_clone.clone();
let socket_txs = socket_txs_clone.clone();
let engine_store = engine_store_clone.clone();
let adapter = adapter_clone.clone();
tokio::spawn(async move {
handle_engine_message(
sid, engine_packet, &namespaces, &socket_txs, &engine_store, &adapter,
).await;
});
},
));
let server = SocketServer {
engine,
namespaces,
adapter,
socket_txs,
};
for ns in server.namespaces.all_namespaces() {
let adapter_ref = server.adapter.clone();
tokio::spawn(async move {
ns.set_adapter(adapter_ref).await;
});
}
server
}
}
async fn handle_engine_message(
engine_sid: String,
engine_packet: EnginePacket,
namespaces: &Arc<NamespaceManager>,
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
engine_store: &SessionStore,
adapter: &Arc<dyn Adapter>,
) {
if let EnginePacketData::Text(ref text) = engine_packet.data {
if let Ok(socket_packet) = parser::decode(text) {
match socket_packet.packet_type {
PacketType::Connect => {
handle_connect(&engine_sid, &socket_packet, namespaces, socket_txs, engine_store, adapter).await;
}
PacketType::Disconnect => {
handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs);
}
PacketType::Event => {
handle_event(&engine_sid, &socket_packet, namespaces);
}
PacketType::Ack => {
handle_ack(&engine_sid, &socket_packet);
}
_ => {}
}
}
}
}
async fn handle_connect(
engine_sid: &str,
packet: &Packet,
namespaces: &Arc<NamespaceManager>,
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
engine_store: &SessionStore,
adapter: &Arc<dyn Adapter>,
) {
// Validate namespace path to prevent DoS via arbitrary namespace creation
if !crate::socket::namespace::is_valid_namespace(&packet.namespace) {
tracing::warn!("Rejected connect with invalid namespace: {}", packet.namespace);
return;
}
let namespace = namespaces.get_or_create_namespace(&packet.namespace);
// Ensure newly created namespaces get the shared adapter
{
let ns_adapter = namespace.adapter.read().await;
if ns_adapter.is_none() {
drop(ns_adapter);
let adapter_ref = adapter.clone();
let ns_clone = namespace.clone();
tokio::spawn(async move {
ns_clone.set_adapter(adapter_ref).await;
});
}
}
let socket_sid = crate::engine::session::generate_sid();
let (tx, mut rx) = mpsc::channel::<Packet>(256);
socket_txs.insert(socket_sid.clone(), tx.clone());
let socket = Arc::new(Socket::new(
socket_sid.clone(),
packet.namespace.clone(),
engine_sid.to_string(),
tx,
));
// Run connect handler and add to namespace.
// If the handler rejects, clean up and do NOT send a Connect response.
if let Err(msg) = namespace.add_socket(socket.clone()).await {
tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg);
socket_txs.remove(&socket_sid);
return;
}
// Connect handler passed — spawn forwarding task
let engine_store_clone = engine_store.clone();
let engine_sid_clone = engine_sid.to_string();
let socket_sid_clone = socket_sid.clone();
let socket_txs_clone = socket_txs.clone();
let namespace_clone = namespace.clone();
tokio::spawn(async move {
while let Some(socket_packet) = rx.recv().await {
let encoded = parser::encode(&socket_packet);
let engine_packet = EnginePacket::message_text(encoded);
if let Some(session) = engine_store_clone.get(&engine_sid_clone) {
let mut s = session.write().await;
if s.state == crate::engine::session::SessionState::Closed {
break;
}
s.push_packet(engine_packet);
} else {
break;
}
}
// Forwarding task ended — ensure socket is cleaned up from namespace
socket_txs_clone.remove(&socket_sid_clone);
namespace_clone.remove_socket_by_sid(&socket_sid_clone).await;
});
// Send Connect response (only after handler passed)
let response = Packet::connect(
&socket.namespace,
Some(serde_json::json!({ "sid": &socket.sid })),
);
if socket.send_packet(&response).is_err() {
tracing::warn!("Failed to send connect response to socket {}", socket.sid);
}
}
fn handle_disconnect(
engine_sid: &str,
packet: &Packet,
namespaces: &Arc<NamespaceManager>,
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
) {
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
// Look up socket by engine_sid, then remove by socket_sid
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
socket_txs.remove(&socket.sid);
let socket_sid = socket.sid.clone();
let ns_clone = namespace.clone();
tokio::spawn(async move {
ns_clone.remove_socket_by_sid(&socket_sid).await;
});
}
}
}
fn handle_event(
engine_sid: &str,
packet: &Packet,
namespaces: &Arc<NamespaceManager>,
) {
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
if let Some(ref data) = packet.data {
if let Some(arr) = data.as_array() {
if let Some(event) = arr.first().and_then(|v| v.as_str()) {
let event_data = if arr.len() > 1 {
serde_json::Value::Array(arr[1..].to_vec())
} else {
serde_json::Value::Null
};
let namespace_clone = namespace.clone();
let event = event.to_string();
let socket_clone = socket.clone();
tokio::spawn(async move {
namespace_clone
.handle_event(&socket_clone, &event, &event_data)
.await;
});
}
}
}
}
}
}
fn handle_ack(engine_sid: &str, packet: &Packet) {
tracing::debug!(
"Received ACK from {} for namespace {} with id {:?}",
engine_sid,
packet.namespace,
packet.id
);
}
+88
View File
@@ -0,0 +1,88 @@
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use dashmap::DashMap;
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
pub struct InMemorySessionStore {
sessions: Arc<DashMap<String, SessionInfo>>,
}
impl InMemorySessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(DashMap::new()),
}
}
}
impl Default for InMemorySessionStore {
fn default() -> Self {
Self::new()
}
}
fn now_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
#[async_trait]
impl SessionStoreTrait for InMemorySessionStore {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
let info = SessionInfo {
sid: sid.to_string(),
transport: transport.to_string(),
state: "connecting".to_string(),
server_id: server_id.to_string(),
created_at: now_millis(),
last_ping: now_millis(),
};
self.sessions.insert(sid.to_string(), info);
Ok(())
}
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
Ok(self.sessions.get(sid).map(|r| r.value().clone()))
}
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
if let Some(mut entry) = self.sessions.get_mut(sid) {
entry.value_mut().state = state.to_string();
Ok(())
} else {
Err(SessionError::NotFound(sid.to_string()))
}
}
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
if let Some(mut entry) = self.sessions.get_mut(sid) {
entry.value_mut().transport = transport.to_string();
Ok(())
} else {
Err(SessionError::NotFound(sid.to_string()))
}
}
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
if let Some(mut entry) = self.sessions.get_mut(sid) {
entry.value_mut().last_ping = now_millis();
Ok(())
} else {
Err(SessionError::NotFound(sid.to_string()))
}
}
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
self.sessions.remove(sid);
Ok(())
}
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
Ok(self.sessions.contains_key(sid))
}
}
+41
View File
@@ -0,0 +1,41 @@
pub mod memory;
pub mod redis;
use async_trait::async_trait;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum SessionError {
#[error("Redis error: {0}")]
Redis(String),
#[error("Session not found: {0}")]
NotFound(String),
#[error("Serialization error: {0}")]
Serialization(String),
#[error("Session expired: {0}")]
Expired(String),
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SessionInfo {
pub sid: String,
pub transport: String,
pub state: String,
pub server_id: String,
pub created_at: u64,
pub last_ping: u64,
}
#[async_trait]
pub trait SessionStoreTrait: Send + Sync + 'static {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError>;
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError>;
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError>;
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError>;
async fn update_ping(&self, sid: &str) -> Result<(), SessionError>;
async fn remove(&self, sid: &str) -> Result<(), SessionError>;
async fn exists(&self, sid: &str) -> Result<bool, SessionError>;
}
pub use memory::InMemorySessionStore;
pub use redis::RedisSessionStore;
+164
View File
@@ -0,0 +1,164 @@
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use fred::prelude::*;
use crate::socket::message_bus::redis::RedisMessageBus;
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
fn now_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
const DEFAULT_TTL_SECS: u64 = 60;
const KEY_PREFIX: &str = "socket.io:session";
pub struct RedisSessionStore {
client: Client,
ttl_secs: u64,
}
impl RedisSessionStore {
pub fn new(bus: &RedisMessageBus, ttl_secs: Option<u64>) -> Self {
Self {
client: bus.client().clone(),
ttl_secs: ttl_secs.unwrap_or(DEFAULT_TTL_SECS),
}
}
fn key(&self, sid: &str) -> String {
format!("{}:{}", KEY_PREFIX, sid)
}
}
#[async_trait]
impl SessionStoreTrait for RedisSessionStore {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
let key = self.key(sid);
let now = now_millis();
// Batch all fields in a single HSET call for efficiency
let fields: Vec<(&str, String)> = vec![
("sid", sid.to_string()),
("transport", transport.to_string()),
("state", "connecting".to_string()),
("server_id", server_id.to_string()),
("created_at", now.to_string()),
("last_ping", now.to_string()),
];
self.client
.hset::<(), _, _>(&key, fields)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
let key = self.key(sid);
// Use hgetall directly — if the key doesn't exist Redis returns an empty map.
// This avoids the TOCTOU race between EXISTS and HGETALL.
let values: std::collections::HashMap<String, String> = self.client
.hgetall::<std::collections::HashMap<String, String>, _>(&key)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
if values.is_empty() {
return Ok(None);
}
let info = SessionInfo {
sid: values.get("sid").cloned().unwrap_or_default(),
transport: values.get("transport").cloned().unwrap_or_default(),
state: values.get("state").cloned().unwrap_or_default(),
server_id: values.get("server_id").cloned().unwrap_or_default(),
created_at: values.get("created_at").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
last_ping: values.get("last_ping").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
};
Ok(Some(info))
}
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
let key = self.key(sid);
// Use HSET (not HSETNX) to overwrite existing fields
self.client
.hset::<(), _, _>(&key, ("state", state))
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
let key = self.key(sid);
// Use HSET (not HSETNX) to overwrite existing fields
self.client
.hset::<(), _, _>(&key, ("transport", transport))
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
let key = self.key(sid);
let now = now_millis();
// Use HSET (not HSETNX) to overwrite existing fields
self.client
.hset::<(), _, _>(&key, ("last_ping", now.to_string()))
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
let key = self.key(sid);
self.client
.del::<(), _>(&key)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
let key = self.key(sid);
let exists: bool = self.client
.exists::<bool, _>(&key)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(exists)
}
}
+72
View File
@@ -0,0 +1,72 @@
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc;
use crate::socket::packet::Packet;
pub struct Socket {
pub sid: String,
pub namespace: String,
pub engine_sid: String,
ack_id: AtomicU64,
tx: mpsc::Sender<Packet>,
}
impl Socket {
pub fn new(
sid: String,
namespace: String,
engine_sid: String,
tx: mpsc::Sender<Packet>,
) -> Self {
Self {
sid,
namespace,
engine_sid,
ack_id: AtomicU64::new(0),
tx,
}
}
pub fn next_ack_id(&self) -> u64 {
self.ack_id.fetch_add(1, Ordering::SeqCst)
}
pub fn send_packet(&self, packet: &Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
self.tx.try_send(packet.clone())
}
pub fn emit(&self, event: impl Into<String>, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::event(
&self.namespace,
serde_json::json!([event.into(), data]),
None,
);
self.send_packet(&packet)
}
pub fn emit_with_ack(
&self,
event: impl Into<String>,
data: serde_json::Value,
) -> Result<u64, mpsc::error::TrySendError<Packet>> {
let ack_id = self.next_ack_id();
let packet = Packet::event(
&self.namespace,
serde_json::json!([event.into(), data]),
Some(ack_id),
);
self.send_packet(&packet)?;
Ok(ack_id)
}
pub fn disconnect(&self) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::disconnect(&self.namespace);
self.send_packet(&packet)
}
pub fn send_ack(&self, id: u64, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::ack(&self.namespace, data, id);
self.send_packet(&packet)
}
}
+344
View File
@@ -0,0 +1,344 @@
use std::collections::HashSet;
use std::sync::Arc;
use imks::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, SocketInfo};
use imks::socket::packet::Packet;
use imks::socket::session_store::{InMemorySessionStore, SessionInfo, SessionStoreTrait};
#[tokio::test]
async fn test_local_adapter_add_and_del() {
let sent_packets: dashmap::DashMap<String, Vec<Packet>> = dashmap::DashMap::new();
let sent_packets_clone = sent_packets.clone();
let send_fn = move |engine_sid: &str, packet: &Packet| {
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
Ok(())
};
let adapter = LocalAdapter::new(send_fn);
adapter.add("sid1", "room1", "/").await.unwrap();
adapter.add("sid1", "room2", "/").await.unwrap();
adapter.add("sid2", "room1", "/").await.unwrap();
let rooms = adapter.socket_rooms("sid1").await.unwrap();
assert!(rooms.contains("room1"));
assert!(rooms.contains("room2"));
adapter.del("sid1", "room1", "/").await.unwrap();
let rooms = adapter.socket_rooms("sid1").await.unwrap();
assert!(!rooms.contains("room1"));
assert!(rooms.contains("room2"));
}
#[tokio::test]
async fn test_local_adapter_del_all() {
let send_fn = move |_engine_sid: &str, _packet: &Packet| Ok(());
let adapter = LocalAdapter::new(send_fn);
adapter.add("sid1", "room1", "/").await.unwrap();
adapter.add("sid1", "room2", "/").await.unwrap();
adapter.del_all("sid1", "/").await.unwrap();
let rooms = adapter.socket_rooms("sid1").await.unwrap();
assert!(rooms.is_empty());
}
#[tokio::test]
async fn test_local_adapter_register_and_broadcast() {
let sent_packets: Arc<dashmap::DashMap<String, Vec<Packet>>> = Arc::new(dashmap::DashMap::new());
let sent_packets_clone = sent_packets.clone();
let send_fn = move |engine_sid: &str, packet: &Packet| {
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
Ok(())
};
let adapter = LocalAdapter::new(send_fn);
// Register socket_sid → engine_sid mapping
adapter.register("sid1", "engine1", "/").await.unwrap();
adapter.register("sid2", "engine2", "/").await.unwrap();
let packet = Packet::event("/", serde_json::json!(["test", "hello"]), None);
let opts = BroadcastOptions::default();
adapter.broadcast(&packet, &opts).await.unwrap();
assert!(sent_packets.contains_key("engine1"));
assert!(sent_packets.contains_key("engine2"));
assert_eq!(sent_packets.len(), 2);
}
#[tokio::test]
async fn test_local_adapter_broadcast_to_room() {
let sent_packets: Arc<dashmap::DashMap<String, Vec<Packet>>> = Arc::new(dashmap::DashMap::new());
let sent_packets_clone = sent_packets.clone();
let send_fn = move |engine_sid: &str, packet: &Packet| {
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
Ok(())
};
let adapter = LocalAdapter::new(send_fn);
adapter.register("sid1", "engine1", "/").await.unwrap();
adapter.register("sid2", "engine2", "/").await.unwrap();
adapter.add("sid1", "room1", "/").await.unwrap();
adapter.add("sid2", "room2", "/").await.unwrap();
let packet = Packet::event("/", serde_json::json!(["test", "hello"]), None);
let opts = BroadcastOptions {
rooms: HashSet::from(["room1".to_string()]),
except: HashSet::new(),
flags: BroadcastFlags::default(),
};
adapter.broadcast(&packet, &opts).await.unwrap();
assert!(sent_packets.contains_key("engine1"));
assert!(!sent_packets.contains_key("engine2"));
}
#[tokio::test]
async fn test_local_adapter_broadcast_except() {
let sent_packets: Arc<dashmap::DashMap<String, Vec<Packet>>> = Arc::new(dashmap::DashMap::new());
let sent_packets_clone = sent_packets.clone();
let send_fn = move |engine_sid: &str, packet: &Packet| {
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
Ok(())
};
let adapter = LocalAdapter::new(send_fn);
adapter.register("sid1", "engine1", "/").await.unwrap();
adapter.register("sid2", "engine2", "/").await.unwrap();
let packet = Packet::event("/", serde_json::json!(["test", "hello"]), None);
let opts = BroadcastOptions {
rooms: HashSet::new(),
except: HashSet::from(["sid1".to_string()]),
flags: BroadcastFlags::default(),
};
adapter.broadcast(&packet, &opts).await.unwrap();
assert!(!sent_packets.contains_key("engine1"));
assert!(sent_packets.contains_key("engine2"));
}
#[tokio::test]
async fn test_local_adapter_fetch_sockets() {
let send_fn = move |_engine_sid: &str, _packet: &Packet| Ok(());
let adapter = LocalAdapter::new(send_fn);
adapter.register("sid1", "engine1", "/").await.unwrap();
adapter.register("sid2", "engine2", "/").await.unwrap();
adapter.add("sid1", "room1", "/").await.unwrap();
adapter.add("sid2", "room2", "/").await.unwrap();
let opts = BroadcastOptions::default();
let sockets = adapter.fetch_sockets(&opts).await.unwrap();
assert_eq!(sockets.len(), 2);
}
#[tokio::test]
async fn test_local_adapter_server_id_unique() {
let send_fn1 = move |_engine_sid: &str, _packet: &Packet| Ok(());
let send_fn2 = move |_engine_sid: &str, _packet: &Packet| Ok(());
let adapter1 = LocalAdapter::new(send_fn1);
let adapter2 = LocalAdapter::new(send_fn2);
assert_ne!(adapter1.server_id(), adapter2.server_id());
}
#[tokio::test]
async fn test_in_memory_session_store() {
let store = InMemorySessionStore::new();
store.create("sid1", "polling", "server1").await.unwrap();
assert!(store.exists("sid1").await.unwrap());
let info = store.get("sid1").await.unwrap().unwrap();
assert_eq!(info.sid, "sid1");
assert_eq!(info.transport, "polling");
assert_eq!(info.state, "connecting");
assert_eq!(info.server_id, "server1");
store.set_state("sid1", "open").await.unwrap();
let info = store.get("sid1").await.unwrap().unwrap();
assert_eq!(info.state, "open");
store.set_transport("sid1", "websocket").await.unwrap();
let info = store.get("sid1").await.unwrap().unwrap();
assert_eq!(info.transport, "websocket");
store.update_ping("sid1").await.unwrap();
let info = store.get("sid1").await.unwrap().unwrap();
assert!(info.last_ping > 0);
store.remove("sid1").await.unwrap();
assert!(!store.exists("sid1").await.unwrap());
}
#[tokio::test]
async fn test_in_memory_session_store_not_found() {
let store = InMemorySessionStore::new();
let result = store.get("nonexistent").await.unwrap();
assert!(result.is_none());
let result = store.set_state("nonexistent", "open").await;
assert!(result.is_err());
}
#[test]
fn test_bus_message_serialization() {
let msg = BusMessage::Broadcast {
namespace: "/".to_string(),
packet: "2[\"hello\"]".to_string(),
opts: BroadcastOptions::default(),
server_id: "server-1".to_string(),
};
let encoded = serde_json::to_vec(&msg).unwrap();
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
assert_eq!(decoded, msg);
}
#[test]
fn test_bus_message_socket_join() {
let msg = BusMessage::SocketJoin {
namespace: "/admin".to_string(),
sid: "sid-1".to_string(),
room: "room-1".to_string(),
server_id: "server-1".to_string(),
};
let encoded = serde_json::to_vec(&msg).unwrap();
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
assert_eq!(decoded, msg);
}
#[test]
fn test_bus_message_socket_leave() {
let msg = BusMessage::SocketLeave {
namespace: "/".to_string(),
sid: "sid-1".to_string(),
room: "room-1".to_string(),
server_id: "server-1".to_string(),
};
let encoded = serde_json::to_vec(&msg).unwrap();
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
assert_eq!(decoded, msg);
}
#[test]
fn test_bus_message_socket_disconnect() {
let msg = BusMessage::SocketDisconnect {
namespace: "/".to_string(),
sid: "sid-1".to_string(),
server_id: "server-1".to_string(),
};
let encoded = serde_json::to_vec(&msg).unwrap();
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
assert_eq!(decoded, msg);
}
#[test]
fn test_broadcast_options_serialization() {
let opts = BroadcastOptions {
rooms: HashSet::from(["room1".to_string(), "room2".to_string()]),
except: HashSet::from(["sid1".to_string()]),
flags: BroadcastFlags {
local_only: true,
broadcast: false,
},
};
let encoded = serde_json::to_vec(&opts).unwrap();
let decoded: BroadcastOptions = serde_json::from_slice(&encoded).unwrap();
assert_eq!(decoded.rooms, opts.rooms);
assert_eq!(decoded.except, opts.except);
assert_eq!(decoded.flags.local_only, opts.flags.local_only);
}
#[test]
fn test_socket_info_serialization() {
let info = SocketInfo {
sid: "sid-1".to_string(),
namespace: "/admin".to_string(),
rooms: HashSet::from(["room1".to_string()]),
};
let encoded = serde_json::to_vec(&info).unwrap();
let decoded: SocketInfo = serde_json::from_slice(&encoded).unwrap();
assert_eq!(decoded.sid, info.sid);
assert_eq!(decoded.namespace, info.namespace);
}
#[test]
fn test_session_info_serialization() {
let info = SessionInfo {
sid: "sid-1".to_string(),
transport: "websocket".to_string(),
state: "open".to_string(),
server_id: "server-1".to_string(),
created_at: 1234567890,
last_ping: 1234567900,
};
let encoded = serde_json::to_vec(&info).unwrap();
let decoded: SessionInfo = serde_json::from_slice(&encoded).unwrap();
assert_eq!(decoded.sid, info.sid);
assert_eq!(decoded.transport, info.transport);
assert_eq!(decoded.state, info.state);
}
#[test]
fn test_message_bus_error_display() {
let err = imks::socket::message_bus::MessageBusError::Redis("connection refused".to_string());
assert_eq!(format!("{}", err), "Redis error: connection refused");
let err = imks::socket::message_bus::MessageBusError::Nats("timeout".to_string());
assert_eq!(format!("{}", err), "NATS error: timeout");
}
#[test]
fn test_adapter_error_display() {
let err = AdapterError::Redis("SADD failed".to_string());
assert_eq!(format!("{}", err), "Redis error: SADD failed");
let err = AdapterError::Nats("publish failed".to_string());
assert_eq!(format!("{}", err), "NATS error: publish failed");
let err = AdapterError::Serialization("json error".to_string());
assert_eq!(format!("{}", err), "Serialization error: json error");
}
#[test]
fn test_session_error_display() {
let err = imks::socket::session_store::SessionError::Redis("timeout".to_string());
assert_eq!(format!("{}", err), "Redis error: timeout");
let err = imks::socket::session_store::SessionError::NotFound("sid-1".to_string());
assert_eq!(format!("{}", err), "Session not found: sid-1");
}
#[test]
fn test_is_valid_namespace() {
assert!(imks::socket::namespace::is_valid_namespace("/"));
assert!(imks::socket::namespace::is_valid_namespace("/admin"));
assert!(imks::socket::namespace::is_valid_namespace("/chat/room1"));
assert!(!imks::socket::namespace::is_valid_namespace(""));
assert!(!imks::socket::namespace::is_valid_namespace("admin"));
assert!(!imks::socket::namespace::is_valid_namespace(&"/".repeat(257)));
}
+158
View File
@@ -0,0 +1,158 @@
use imks::engine::codec;
use imks::engine::packet::{HandshakeData, Packet, PacketData, PacketType};
#[test]
fn test_engine_io_handshake_encoding() {
let handshake = HandshakeData {
sid: "lv_VI97HAXpY6yYWAAAC".to_string(),
upgrades: vec!["websocket".to_string()],
ping_interval: 25000,
ping_timeout: 20000,
max_payload: 1000000,
};
let packet = Packet::open(&handshake);
let encoded = codec::encode_packet(&packet);
assert!(encoded.starts_with('0'));
assert!(encoded.contains("\"sid\":\"lv_VI97HAXpY6yYWAAAC\""));
assert!(encoded.contains("\"upgrades\":[\"websocket\"]"));
assert!(encoded.contains("\"pingInterval\":25000"));
assert!(encoded.contains("\"pingTimeout\":20000"));
assert!(encoded.contains("\"maxPayload\":1000000"));
}
#[test]
fn test_engine_io_packet_types() {
let open = Packet::open(&HandshakeData {
sid: "test".to_string(),
upgrades: vec![],
ping_interval: 25000,
ping_timeout: 20000,
max_payload: 1000000,
});
assert_eq!(open.packet_type, PacketType::Open);
let close = Packet::close();
assert_eq!(close.packet_type, PacketType::Close);
let ping = Packet::ping("test");
assert_eq!(ping.packet_type, PacketType::Ping);
let pong = Packet::pong("test");
assert_eq!(pong.packet_type, PacketType::Pong);
let msg = Packet::message_text("hello");
assert_eq!(msg.packet_type, PacketType::Message);
let upgrade = Packet::upgrade();
assert_eq!(upgrade.packet_type, PacketType::Upgrade);
let noop = Packet::noop();
assert_eq!(noop.packet_type, PacketType::Noop);
}
#[test]
fn test_engine_io_polling_payload() {
let packets = vec![
Packet::message_text("hello"),
Packet::ping(""),
Packet::message_text("world"),
];
let encoded = codec::encode_payload(&packets);
assert_eq!(encoded, "4hello\x1e2\x1e4world");
let decoded = codec::decode_payload(&encoded).unwrap();
assert_eq!(decoded.len(), 3);
assert_eq!(decoded[0].packet_type, PacketType::Message);
assert_eq!(decoded[1].packet_type, PacketType::Ping);
assert_eq!(decoded[2].packet_type, PacketType::Message);
}
#[test]
fn test_engine_io_binary_encoding() {
let packet = Packet::message_binary(vec![0x01, 0x02, 0x03, 0x04]);
let encoded = codec::encode_packet(&packet);
assert_eq!(encoded, "4bAQIDBA==");
let decoded = codec::decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Binary(vec![0x01, 0x02, 0x03, 0x04]));
}
#[test]
fn test_engine_io_webtransport_header_encoding() {
let header = codec::encode_webtransport_header(6, false);
assert_eq!(header, vec![0x06]);
let header = codec::encode_webtransport_header(200, true);
assert_eq!(header.len(), 3);
assert_eq!(header[0], 0x80 | 126);
let header = codec::encode_webtransport_header(70000, false);
assert_eq!(header.len(), 9);
assert_eq!(header[0], 127);
}
#[test]
fn test_engine_io_webtransport_header_decoding() {
let header = vec![0x06];
let (len, is_binary) = codec::decode_webtransport_header(&header).unwrap();
assert_eq!(len, 6);
assert!(!is_binary);
let header = vec![0x80 | 126, 0x00, 0xC8];
let (len, is_binary) = codec::decode_webtransport_header(&header).unwrap();
assert_eq!(len, 200);
assert!(is_binary);
}
#[test]
fn test_engine_io_probe_ping_pong() {
let ping = Packet::ping("probe");
let encoded = codec::encode_packet(&ping);
assert_eq!(encoded, "2probe");
let decoded = codec::decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Ping);
assert_eq!(decoded.data, PacketData::Text("probe".to_string()));
let pong = Packet::pong("probe");
let encoded = codec::encode_packet(&pong);
assert_eq!(encoded, "3probe");
let decoded = codec::decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Pong);
assert_eq!(decoded.data, PacketData::Text("probe".to_string()));
}
#[test]
fn test_engine_io_upgrade_packet() {
let upgrade = Packet::upgrade();
let encoded = codec::encode_packet(&upgrade);
assert_eq!(encoded, "5");
let decoded = codec::decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Upgrade);
}
#[test]
fn test_engine_io_noop_packet() {
let noop = Packet::noop();
let encoded = codec::encode_packet(&noop);
assert_eq!(encoded, "6");
let decoded = codec::decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Noop);
}
#[test]
fn test_engine_io_close_packet() {
let close = Packet::close();
let encoded = codec::encode_packet(&close);
assert_eq!(encoded, "1");
let decoded = codec::decode_packet(&encoded).unwrap();
assert_eq!(decoded.packet_type, PacketType::Close);
}
+129
View File
@@ -0,0 +1,129 @@
use imks::engine::session::{generate_sid, SessionState, SessionStore, TransportType};
#[test]
fn test_session_store_create_and_get() {
let store = SessionStore::new();
let sid = generate_sid();
let _rx = store.create(sid.clone(), TransportType::Polling);
assert!(store.exists(&sid));
assert!(store.get(&sid).is_some());
assert_eq!(store.len(), 1);
}
#[test]
fn test_session_store_remove() {
let store = SessionStore::new();
let sid = generate_sid();
let _rx = store.create(sid.clone(), TransportType::Polling);
assert!(store.exists(&sid));
store.remove(&sid);
assert!(!store.exists(&sid));
assert!(store.get(&sid).is_none());
}
#[test]
fn test_session_store_multiple_sessions() {
let store = SessionStore::new();
let sid1 = generate_sid();
let sid2 = generate_sid();
let sid3 = generate_sid();
let _rx1 = store.create(sid1.clone(), TransportType::Polling);
let _rx2 = store.create(sid2.clone(), TransportType::WebSocket);
let _rx3 = store.create(sid3.clone(), TransportType::WebTransport);
assert_eq!(store.len(), 3);
assert!(store.exists(&sid1));
assert!(store.exists(&sid2));
assert!(store.exists(&sid3));
}
#[test]
fn test_generate_sid_uniqueness() {
let sids: Vec<String> = (0..100).map(|_| generate_sid()).collect();
let unique_sids: std::collections::HashSet<String> = sids.into_iter().collect();
assert_eq!(unique_sids.len(), 100);
}
#[test]
fn test_generate_sid_format() {
let sid = generate_sid();
assert_eq!(sid.len(), 20);
assert!(sid.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'));
}
#[tokio::test]
async fn test_session_state_transitions() {
let store = SessionStore::new();
let sid = generate_sid();
let _rx = store.create(sid.clone(), TransportType::Polling);
if let Some(session) = store.get(&sid) {
let mut session = session.write().await;
assert_eq!(session.state, SessionState::Connecting);
session.set_state(SessionState::Open);
assert_eq!(session.state, SessionState::Open);
session.set_state(SessionState::Upgrading);
assert_eq!(session.state, SessionState::Upgrading);
session.set_state(SessionState::Open);
assert_eq!(session.state, SessionState::Open);
session.set_state(SessionState::Closing);
assert_eq!(session.state, SessionState::Closing);
session.set_state(SessionState::Closed);
assert_eq!(session.state, SessionState::Closed);
}
}
#[tokio::test]
async fn test_session_transport_change() {
let store = SessionStore::new();
let sid = generate_sid();
let _rx = store.create(sid.clone(), TransportType::Polling);
if let Some(session) = store.get(&sid) {
let mut session = session.write().await;
assert_eq!(session.transport, TransportType::Polling);
session.set_transport(TransportType::WebSocket);
assert_eq!(session.transport, TransportType::WebSocket);
}
}
#[tokio::test]
async fn test_session_ping_update() {
let store = SessionStore::new();
let sid = generate_sid();
let _rx = store.create(sid.clone(), TransportType::Polling);
if let Some(session) = store.get(&sid) {
let mut session = session.write().await;
let initial_ping = session.last_ping;
std::thread::sleep(std::time::Duration::from_millis(10));
session.update_ping();
assert!(session.last_ping > initial_ping);
}
}
#[test]
fn test_transport_type_as_str() {
assert_eq!(TransportType::Polling.as_str(), "polling");
assert_eq!(TransportType::WebSocket.as_str(), "websocket");
assert_eq!(TransportType::WebTransport.as_str(), "webtransport");
}
+203
View File
@@ -0,0 +1,203 @@
use imks::socket::packet::{Packet, PacketType};
use imks::socket::parser;
use serde_json::json;
#[test]
fn test_socket_io_connect_encoding() {
let packet = Packet::connect("/", None);
let encoded = parser::encode(&packet);
assert_eq!(encoded, "0");
let packet = Packet::connect("/admin", Some(json!({"sid": "abc123"})));
let encoded = parser::encode(&packet);
assert_eq!(encoded, "0/admin,{\"sid\":\"abc123\"}");
}
#[test]
fn test_socket_io_disconnect_encoding() {
let packet = Packet::disconnect("/");
let encoded = parser::encode(&packet);
assert_eq!(encoded, "1");
let packet = Packet::disconnect("/admin");
let encoded = parser::encode(&packet);
assert_eq!(encoded, "1/admin,");
}
#[test]
fn test_socket_io_event_encoding() {
let packet = Packet::event("/", json!(["foo"]), None);
let encoded = parser::encode(&packet);
assert_eq!(encoded, "2[\"foo\"]");
let packet = Packet::event("/admin", json!(["bar"]), None);
let encoded = parser::encode(&packet);
assert_eq!(encoded, "2/admin,[\"bar\"]");
}
#[test]
fn test_socket_io_event_with_ack_encoding() {
let packet = Packet::event("/", json!(["foo"]), Some(12));
let encoded = parser::encode(&packet);
assert_eq!(encoded, "212[\"foo\"]");
let packet = Packet::event("/admin", json!(["bar"]), Some(13));
let encoded = parser::encode(&packet);
assert_eq!(encoded, "2/admin,13[\"bar\"]");
}
#[test]
fn test_socket_io_ack_encoding() {
let packet = Packet::ack("/", json!([]), 12);
let encoded = parser::encode(&packet);
assert_eq!(encoded, "312[]");
let packet = Packet::ack("/admin", json!(["bar"]), 13);
let encoded = parser::encode(&packet);
assert_eq!(encoded, "3/admin,13[\"bar\"]");
}
#[test]
fn test_socket_io_connect_error_encoding() {
let packet = Packet::connect_error("/", "Not authorized");
let encoded = parser::encode(&packet);
assert_eq!(encoded, "4{\"message\":\"Not authorized\"}");
}
#[test]
fn test_socket_io_connect_decoding() {
let packet = parser::decode("0").unwrap();
assert_eq!(packet.packet_type, PacketType::Connect);
assert_eq!(packet.namespace, "/");
assert!(packet.data.is_none());
}
#[test]
fn test_socket_io_connect_with_namespace_decoding() {
let packet = parser::decode("0/admin,{\"sid\":\"abc\"}").unwrap();
assert_eq!(packet.packet_type, PacketType::Connect);
assert_eq!(packet.namespace, "/admin");
assert!(packet.data.is_some());
}
#[test]
fn test_socket_io_disconnect_decoding() {
let packet = parser::decode("1").unwrap();
assert_eq!(packet.packet_type, PacketType::Disconnect);
assert_eq!(packet.namespace, "/");
}
#[test]
fn test_socket_io_disconnect_with_namespace_decoding() {
let packet = parser::decode("1/admin,").unwrap();
assert_eq!(packet.packet_type, PacketType::Disconnect);
assert_eq!(packet.namespace, "/admin");
}
#[test]
fn test_socket_io_event_decoding() {
let packet = parser::decode("2[\"foo\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.namespace, "/");
assert_eq!(packet.data, Some(json!(["foo"])));
}
#[test]
fn test_socket_io_event_with_namespace_decoding() {
let packet = parser::decode("2/admin,[\"bar\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.namespace, "/admin");
assert_eq!(packet.data, Some(json!(["bar"])));
}
#[test]
fn test_socket_io_event_with_ack_decoding() {
let packet = parser::decode("212[\"foo\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.id, Some(12));
assert_eq!(packet.data, Some(json!(["foo"])));
}
#[test]
fn test_socket_io_ack_decoding() {
let packet = parser::decode("312[]").unwrap();
assert_eq!(packet.packet_type, PacketType::Ack);
assert_eq!(packet.id, Some(12));
}
#[test]
fn test_socket_io_connect_error_decoding() {
let packet = parser::decode("4{\"message\":\"Not authorized\"}").unwrap();
assert_eq!(packet.packet_type, PacketType::ConnectError);
assert_eq!(packet.data, Some(json!({"message": "Not authorized"})));
}
#[test]
fn test_socket_io_binary_event_encoding() {
let packet = Packet::binary_event(
"/",
json!(["baz", {"_placeholder": true, "num": 0}]),
None,
vec![vec![0x01, 0x02, 0x03, 0x04]],
);
let encoded = parser::encode(&packet);
assert!(encoded.starts_with("51-"));
}
#[test]
fn test_socket_io_binary_ack_encoding() {
let packet = Packet::binary_ack(
"/",
json!(["bar", {"_placeholder": true, "num": 0}]),
15,
vec![vec![0x01, 0x02, 0x03, 0x04]],
);
let encoded = parser::encode(&packet);
assert!(encoded.starts_with("61-"));
assert!(encoded.contains("15"));
}
#[test]
fn test_socket_io_roundtrip() {
let original = Packet::event("/admin", json!(["hello", "world"]), Some(42));
let encoded = parser::encode(&original);
let decoded = parser::decode(&encoded).unwrap();
assert_eq!(decoded.packet_type, original.packet_type);
assert_eq!(decoded.namespace, original.namespace);
assert_eq!(decoded.id, original.id);
assert_eq!(decoded.data, original.data);
}
#[test]
fn test_socket_io_namespace_handling() {
let namespaces = vec!["/", "/admin", "/chat", "/api/v1"];
for ns in namespaces {
let packet = Packet::event(ns, json!(["test"]), None);
let encoded = parser::encode(&packet);
let decoded = parser::decode(&encoded).unwrap();
assert_eq!(decoded.namespace, ns);
}
}
#[test]
fn test_socket_io_complex_data() {
let complex_data = json!([
"event_name",
{
"user": {
"id": 123,
"name": "test",
"roles": ["admin", "user"]
},
"timestamp": 1234567890
}
]);
let packet = Packet::event("/", complex_data.clone(), None);
let encoded = parser::encode(&packet);
let decoded = parser::decode(&encoded).unwrap();
assert_eq!(decoded.data, Some(complex_data));
}