feat(auth): add authentication protocol definitions and build configuration

- Add TokenClaims message for JWT payload structure with user id, issuer, timestamps, and scopes
- Implement IssueTokenRequest/Response for creating access and refresh tokens with TTL support
- Create RefreshTokenRequest/Response for token rotation functionality
- Define RevokeTokenRequest/Response with support for single token or user-wide revocation
- Add VerifyTokenRequest/Response for validating JWT tokens with detailed claims information
- Implement signing key distribution system with GetSigningKeysRequest/Response
- Create TokenService gRPC service with IssueToken, RefreshToken, RevokeToken, VerifyToken, and GetSigningKeys methods
- Add build.rs configuration to compile proto files using tonic_prost_build
- Include channel, channel_settings, member, and permission protocol definitions for IM services
- Generate Rust code bindings through pb/core.rs and pb/im.rs modules
This commit is contained in:
zhenyi
2026-06-10 23:45:40 +08:00
commit 06e8ee96a5
43 changed files with 9671 additions and 0 deletions
+199
View File
@@ -0,0 +1,199 @@
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use uuid::Uuid;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, SocketInfo};
use crate::socket::packet::Packet;
pub struct LocalAdapter {
server_id: String,
rooms: Arc<DashMap<String, HashSet<String>>>,
socket_rooms: Arc<DashMap<String, HashSet<String>>>,
/// socket_sid → engine_sid
pub socket_sids: Arc<DashMap<String, String>>,
/// socket_sid → namespace path
socket_namespace: Arc<DashMap<String, String>>,
send_fn: Arc<dyn Fn(&str, &Packet) -> Result<(), String> + Send + Sync>,
}
impl LocalAdapter {
pub fn new(
send_fn: impl Fn(&str, &Packet) -> Result<(), String> + Send + Sync + 'static,
) -> Self {
Self {
server_id: Uuid::new_v4().to_string(),
rooms: Arc::new(DashMap::new()),
socket_rooms: Arc::new(DashMap::new()),
socket_sids: Arc::new(DashMap::new()),
socket_namespace: Arc::new(DashMap::new()),
send_fn: Arc::new(send_fn),
}
}
fn room_key(ns: &str, room: &str) -> String {
format!("{}:{}", ns, room)
}
/// Collect socket SIDs matching the broadcast options, scoped to the given namespace.
fn collect_matching_sids(&self, opts: &BroadcastOptions, namespace: &str) -> Vec<String> {
if opts.rooms.is_empty() {
// Broadcast to all sockets in this namespace only
self.socket_sids
.iter()
.filter(|e| {
self.socket_namespace
.get(e.key())
.map(|ns| ns.value() == namespace)
.unwrap_or(false)
})
.map(|e| e.key().clone())
.collect()
} else {
let mut sids = HashSet::new();
for room in &opts.rooms {
let key = Self::room_key(namespace, room);
if let Some(entry) = self.rooms.get(&key) {
for sid in entry.value() {
sids.insert(sid.clone());
}
}
}
sids.into_iter().collect()
}
}
}
#[async_trait]
impl Adapter for LocalAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
let namespace = &packet.namespace;
let sids = self.collect_matching_sids(opts, namespace);
for sid in &sids {
if opts.except.contains(sid) {
continue;
}
// socket_sids maps socket SID -> engine SID
if let Some(entry) = self.socket_sids.get(sid) {
let engine_sid = entry.value();
let result = (self.send_fn)(engine_sid, packet);
if let Err(e) = result {
tracing::warn!("Failed to broadcast to {}: {}", sid, e);
}
}
}
Ok(())
}
async fn register(&self, socket_sid: &str, engine_sid: &str, ns: &str) -> Result<(), AdapterError> {
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
self.socket_namespace.insert(socket_sid.to_string(), ns.to_string());
Ok(())
}
async fn unregister(&self, socket_sid: &str, ns: &str) -> Result<(), AdapterError> {
self.del_all(socket_sid, ns).await
}
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let key = Self::room_key(ns, room);
self.rooms.entry(key).or_insert_with(HashSet::new).value_mut().insert(sid.to_string());
self.socket_rooms.entry(sid.to_string()).or_insert_with(HashSet::new).value_mut().insert(room.to_string());
Ok(())
}
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let key = Self::room_key(ns, room);
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
room_sids.value_mut().remove(sid);
if room_sids.value_mut().is_empty() {
drop(room_sids);
self.rooms.remove(&key);
}
}
if let Some(mut rooms) = self.socket_rooms.get_mut(sid) {
rooms.value_mut().remove(room);
if rooms.value_mut().is_empty() {
drop(rooms);
self.socket_rooms.remove(sid);
}
}
Ok(())
}
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
for room in &rooms {
let key = Self::room_key(ns, room);
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
room_sids.value_mut().remove(sid);
if room_sids.value_mut().is_empty() {
drop(room_sids);
self.rooms.remove(&key);
}
}
}
}
self.socket_sids.remove(sid);
Ok(())
}
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
// fetch_sockets needs namespace context; use an empty namespace to match all
// (this method is typically called for inspection, not delivery)
let sids: Vec<String> = if opts.rooms.is_empty() {
self.socket_sids.iter().map(|e| e.key().clone()).collect()
} else {
let mut sids_set = HashSet::new();
for room in &opts.rooms {
for entry in self.rooms.iter() {
if entry.key().ends_with(&format!(":{}", room)) {
for sid in entry.value() {
sids_set.insert(sid.clone());
}
}
}
}
sids_set.into_iter().collect()
};
let mut result = Vec::new();
for sid in &sids {
if opts.except.contains(sid) {
continue;
}
if self.socket_sids.contains_key(sid) {
let namespace = self.socket_namespace
.get(sid)
.map(|r| r.value().clone())
.unwrap_or_default();
let rooms = self.socket_rooms
.get(sid)
.map(|r| r.value().clone())
.unwrap_or_default();
result.push(SocketInfo {
sid: sid.clone(),
namespace,
rooms,
});
}
}
Ok(result)
}
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms
.get(sid)
.map(|r| r.value().clone())
.unwrap_or_default())
}
fn server_id(&self) -> &str {
&self.server_id
}
async fn close(&self) -> Result<(), AdapterError> {
Ok(())
}
}
+99
View File
@@ -0,0 +1,99 @@
pub mod local;
pub mod redis;
pub mod nats;
use std::collections::HashSet;
use async_trait::async_trait;
use thiserror::Error;
use crate::socket::packet::Packet;
#[derive(Error, Debug)]
pub enum AdapterError {
#[error("Redis error: {0}")]
Redis(String),
#[error("NATS error: {0}")]
Nats(String),
#[error("Message bus error: {0}")]
MessageBus(String),
#[error("Serialization error: {0}")]
Serialization(String),
#[error("Room error: {0}")]
Room(String),
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct BroadcastOptions {
pub rooms: HashSet<String>,
pub except: HashSet<String>,
pub flags: BroadcastFlags,
}
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct BroadcastFlags {
pub local_only: bool,
pub broadcast: bool,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SocketInfo {
pub sid: String,
pub namespace: String,
pub rooms: HashSet<String>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
pub enum BusMessage {
Broadcast {
namespace: String,
packet: String,
opts: BroadcastOptions,
server_id: String,
},
SocketJoin {
namespace: String,
sid: String,
room: String,
server_id: String,
},
SocketLeave {
namespace: String,
sid: String,
room: String,
server_id: String,
},
SocketDisconnect {
namespace: String,
sid: String,
server_id: String,
},
}
#[async_trait]
pub trait Adapter: Send + Sync + 'static {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError>;
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError>;
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError>;
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError>;
fn server_id(&self) -> &str;
async fn close(&self) -> Result<(), AdapterError>;
/// Register a socket SID → engine SID mapping in the adapter.
/// Must be called when a socket first connects, before any room operations.
/// The `ns` parameter is the namespace path this socket belongs to.
async fn register(&self, _socket_sid: &str, _engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
Ok(())
}
/// Unregister a socket from the adapter, removing all local mappings.
async fn unregister(&self, _socket_sid: &str, _ns: &str) -> Result<(), AdapterError> {
Ok(())
}
}
pub use local::LocalAdapter;
pub use redis::RedisAdapter;
pub use nats::NatsAdapter;
+302
View File
@@ -0,0 +1,302 @@
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::mpsc;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
use crate::socket::message_bus::MessageBus;
use crate::socket::packet::Packet;
use crate::socket::parser;
use crate::socket::socket::Socket;
/// Handle incoming bus messages from other servers.
/// Only performs local dispatch — no remote state writes needed.
async fn handle_bus_message(
msg: BusMessage,
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
server_id: &str,
) {
match msg {
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
if sender_id == server_id {
return;
}
if let Ok(decoded_packet) = parser::decode(&packet) {
on_local_broadcast(&decoded_packet, &opts);
}
}
// NATS adapter manages room state locally; cross-server join/leave/disconnect
// are informational only and don't require duplicate state writes.
BusMessage::SocketJoin { server_id: sender_id, .. }
| BusMessage::SocketLeave { server_id: sender_id, .. }
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
if sender_id == server_id {
return;
}
}
}
}
/// NATS-based adapter that manages room state locally and uses NATS
/// for cross-server broadcast only. Does NOT depend on Redis.
pub struct NatsAdapter {
message_bus: Arc<dyn MessageBus>,
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
socket_rooms: DashMap<String, HashSet<String>>,
rooms: DashMap<String, HashSet<String>>,
/// socket_sid → engine_sid mapping for local delivery
socket_sids: DashMap<String, String>,
sockets: DashMap<String, Arc<Socket>>,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
}
impl NatsAdapter {
pub fn new(
message_bus: Arc<dyn MessageBus>,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
) -> Self {
Self {
message_bus,
server_id,
namespace,
on_local_broadcast,
room_subscribers: DashMap::new(),
socket_rooms: DashMap::new(),
rooms: DashMap::new(),
socket_sids: DashMap::new(),
sockets: DashMap::new(),
}
}
pub async fn init(&self) -> Result<(), AdapterError> {
let channels = ["broadcast", "join", "leave", "disconnect"];
let prefix = format!("socket.io:{}:", self.namespace);
for channel_type in channels {
let subject = format!("{}{}", prefix, channel_type);
match self.message_bus.subscribe(&subject).await {
Ok(rx) => {
self.room_subscribers.insert(channel_type.to_string(), rx);
}
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
}
}
self.spawn_listener();
Ok(())
}
fn spawn_listener(&self) {
let server_id = self.server_id.clone();
let on_local_broadcast = self.on_local_broadcast.clone();
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
tokio::spawn(async move {
loop {
tokio::select! {
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { join_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
else => break,
}
}
});
}
}
#[async_trait]
impl Adapter for NatsAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
if opts.flags.local_only {
(self.on_local_broadcast)(packet, opts);
return Ok(());
}
let msg = BusMessage::Broadcast {
namespace: self.namespace.clone(),
packet: parser::encode(packet),
opts: opts.clone(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:broadcast", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
(self.on_local_broadcast)(packet, opts);
Ok(())
}
async fn register(&self, socket_sid: &str, engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
Ok(())
}
async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
self.socket_rooms
.entry(sid.to_string())
.and_modify(|set| { set.insert(room.to_string()); })
.or_insert_with(|| HashSet::from([room.to_string()]));
self.rooms
.entry(room.to_string())
.and_modify(|set| { set.insert(sid.to_string()); })
.or_insert_with(|| HashSet::from([sid.to_string()]));
let msg = BusMessage::SocketJoin {
namespace: self.namespace.clone(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:join", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
entry.value_mut().remove(room);
}
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
self.socket_rooms.remove(sid);
}
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
let msg = BusMessage::SocketLeave {
namespace: self.namespace.clone(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:leave", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del_all(&self, sid: &str, _ns: &str) -> Result<(), AdapterError> {
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
for room in &rooms {
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
}
}
self.socket_sids.remove(sid);
self.sockets.remove(sid);
let msg = BusMessage::SocketDisconnect {
namespace: self.namespace.clone(),
sid: sid.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:disconnect", self.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
let mut result = Vec::new();
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
self.socket_sids.iter().map(|e| e.key().clone()).collect()
} else {
let mut sids = HashSet::new();
for room in &opts.rooms {
if let Some(entry) = self.rooms.get(room) {
sids.extend(entry.value().iter().cloned());
}
}
sids
};
for sid in target_sids {
if opts.except.contains(&sid) {
continue;
}
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
result.push(SocketInfo {
sid: sid.clone(),
namespace: self.namespace.clone(),
rooms,
});
}
Ok(result)
}
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
}
fn server_id(&self) -> &str {
&self.server_id
}
async fn close(&self) -> Result<(), AdapterError> {
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
}
+344
View File
@@ -0,0 +1,344 @@
use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use fred::clients::Client;
use fred::interfaces::{KeysInterface, SetsInterface};
use tokio::sync::mpsc;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
use crate::socket::message_bus::MessageBus;
use crate::socket::packet::Packet;
use crate::socket::parser;
use crate::socket::socket::Socket;
const KEY_PREFIX_ROOMS: &str = "socket.io:rooms";
const KEY_PREFIX_SOCKET_ROOMS: &str = "socket.io:socket_rooms";
fn room_key(ns: &str, room: &str) -> String {
format!("{}:{}:{}", KEY_PREFIX_ROOMS, ns, room)
}
fn socket_rooms_key(ns: &str, sid: &str) -> String {
format!("{}:{}:{}", KEY_PREFIX_SOCKET_ROOMS, ns, sid)
}
/// Handle incoming bus messages from other servers.
/// Only performs local state updates — the remote server already wrote to Redis.
async fn handle_bus_message(
msg: BusMessage,
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
server_id: &str,
) {
match msg {
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
if sender_id == server_id {
return;
}
if let Ok(decoded_packet) = parser::decode(&packet) {
on_local_broadcast(&decoded_packet, &opts);
}
}
BusMessage::SocketJoin { server_id: sender_id, .. }
| BusMessage::SocketLeave { server_id: sender_id, .. }
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
// Skip messages from this server; remote server already updated Redis
if sender_id == server_id {
return;
}
// No duplicate Redis writes — the sender already persisted the state change
}
}
}
pub struct RedisAdapter {
message_bus: Arc<dyn MessageBus>,
redis_client: Client,
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
socket_rooms: DashMap<String, HashSet<String>>,
rooms: DashMap<String, HashSet<String>>,
sockets: DashMap<String, Arc<Socket>>,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
}
impl RedisAdapter {
pub fn new(
message_bus: Arc<dyn MessageBus>,
redis_client: Client,
server_id: String,
namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
) -> Self {
Self {
message_bus,
redis_client,
server_id,
namespace,
on_local_broadcast,
room_subscribers: DashMap::new(),
socket_rooms: DashMap::new(),
rooms: DashMap::new(),
sockets: DashMap::new(),
}
}
pub async fn init(&self) -> Result<(), AdapterError> {
let channels = ["broadcast", "join", "leave", "disconnect"];
let prefix = format!("socket.io:{}:", self.namespace);
for channel_type in channels {
let channel = format!("{}{}", prefix, channel_type);
match self.message_bus.subscribe(&channel).await {
Ok(rx) => {
self.room_subscribers.insert(channel_type.to_string(), rx);
}
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
}
}
self.spawn_listener();
Ok(())
}
fn spawn_listener(&self) {
let server_id = self.server_id.clone();
let on_local_broadcast = self.on_local_broadcast.clone();
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
tokio::spawn(async move {
loop {
tokio::select! {
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { join_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
}
}
else => break,
}
}
});
}
}
#[async_trait]
impl Adapter for RedisAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
if opts.flags.local_only {
(self.on_local_broadcast)(packet, opts);
return Ok(());
}
let msg = BusMessage::Broadcast {
namespace: packet.namespace.clone(),
packet: parser::encode(packet),
opts: opts.clone(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:broadcast", packet.namespace), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
(self.on_local_broadcast)(packet, opts);
Ok(())
}
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let rk = room_key(ns, room);
let srk = socket_rooms_key(ns, sid);
self.redis_client
.sadd::<(), _, _>(&rk, sid)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.redis_client
.sadd::<(), _, _>(&srk, room)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.socket_rooms
.entry(sid.to_string())
.and_modify(|set| { set.insert(room.to_string()); })
.or_insert_with(|| HashSet::from([room.to_string()]));
self.rooms
.entry(room.to_string())
.and_modify(|set| { set.insert(sid.to_string()); })
.or_insert_with(|| HashSet::from([sid.to_string()]));
let msg = BusMessage::SocketJoin {
namespace: ns.to_string(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:join", ns), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let rk = room_key(ns, room);
let srk = socket_rooms_key(ns, sid);
self.redis_client
.srem::<(), _, _>(&rk, sid)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.redis_client
.srem::<(), _, _>(&srk, room)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
entry.value_mut().remove(room);
}
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
self.socket_rooms.remove(sid);
}
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
let msg = BusMessage::SocketLeave {
namespace: ns.to_string(),
sid: sid.to_string(),
room: room.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:leave", ns), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
for room in &rooms {
if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid);
}
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
self.rooms.remove(room);
}
let rk = room_key(ns, room);
if let Err(e) = self.redis_client.srem::<(), _, _>(&rk, sid).await {
tracing::warn!("Redis SREM room error: {}", e);
}
}
}
let srk = socket_rooms_key(ns, sid);
self.redis_client
.del::<(), _>(&srk)
.await
.map_err(|e| AdapterError::Redis(e.to_string()))?;
self.sockets.remove(sid);
let msg = BusMessage::SocketDisconnect {
namespace: ns.to_string(),
sid: sid.to_string(),
server_id: self.server_id.clone(),
};
let payload = serde_json::to_vec(&msg)
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus
.publish(&format!("socket.io:{}:disconnect", ns), &payload)
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
let mut result = Vec::new();
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
self.sockets.iter().map(|e| e.key().clone()).collect()
} else {
let mut sids = HashSet::new();
for room in &opts.rooms {
if let Some(entry) = self.rooms.get(room) {
sids.extend(entry.value().iter().cloned());
}
}
sids
};
for sid in target_sids {
if opts.except.contains(&sid) {
continue;
}
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
result.push(SocketInfo {
sid: sid.clone(),
namespace: self.namespace.clone(),
rooms,
});
}
Ok(result)
}
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
}
fn server_id(&self) -> &str {
&self.server_id
}
async fn close(&self) -> Result<(), AdapterError> {
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(())
}
}
+31
View File
@@ -0,0 +1,31 @@
pub mod redis;
pub mod nats;
use async_trait::async_trait;
use thiserror::Error;
use tokio::sync::mpsc;
#[derive(Error, Debug)]
pub enum MessageBusError {
#[error("Redis error: {0}")]
Redis(String),
#[error("NATS error: {0}")]
Nats(String),
#[error("Connection closed")]
ConnectionClosed,
#[error("Channel not found: {0}")]
ChannelNotFound(String),
#[error("Serialization error: {0}")]
Serialization(String),
}
#[async_trait]
pub trait MessageBus: Send + Sync + 'static {
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError>;
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError>;
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError>;
async fn close(&self) -> Result<(), MessageBusError>;
}
pub use redis::RedisMessageBus;
pub use nats::NatsMessageBus;
+88
View File
@@ -0,0 +1,88 @@
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::{mpsc, watch};
use crate::socket::message_bus::{MessageBus, MessageBusError};
pub struct NatsMessageBus {
client: async_nats::Client,
shutdowns: DashMap<String, watch::Sender<bool>>,
}
impl NatsMessageBus {
pub async fn new(nats_url: &str) -> Result<Self, MessageBusError> {
let client = async_nats::connect(nats_url)
.await
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
Ok(Self {
client,
shutdowns: DashMap::new(),
})
}
}
#[async_trait]
impl MessageBus for NatsMessageBus {
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
self.client
.publish(channel.to_string(), message.to_vec().into())
.await
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
Ok(())
}
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
let mut subscriber = self.client
.subscribe(channel.to_string())
.await
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
self.shutdowns.insert(channel.to_string(), shutdown_tx);
tokio::spawn(async move {
use futures_util::StreamExt;
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
break;
}
message = subscriber.next() => {
match message {
Some(msg) => {
let data = msg.payload.to_vec();
if tx.send(data).await.is_err() {
break;
}
}
None => break,
}
}
}
}
if let Err(e) = subscriber.unsubscribe().await {
tracing::warn!("NATS unsubscribe error: {}", e);
}
});
Ok(rx)
}
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
if let Some((_, tx)) = self.shutdowns.remove(channel) {
let _ = tx.send(true);
}
Ok(())
}
async fn close(&self) -> Result<(), MessageBusError> {
// Signal all subscribers to shutdown
self.shutdowns.iter().for_each(|entry| {
let _ = entry.value().send(true);
});
self.shutdowns.clear();
Ok(())
}
}
+99
View File
@@ -0,0 +1,99 @@
use async_trait::async_trait;
use fred::clients::{Client, SubscriberClient};
use fred::interfaces::{ClientLike, EventInterface, PubsubInterface};
use fred::prelude::*;
use tokio::sync::mpsc;
use crate::socket::message_bus::{MessageBus, MessageBusError};
pub struct RedisMessageBus {
client: Client,
subscriber: SubscriberClient,
}
impl RedisMessageBus {
pub async fn new(redis_url: &str) -> Result<Self, MessageBusError> {
let config = Config::from_url(redis_url)
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
let client = Client::new(config.clone(), None, None, None);
let subscriber = SubscriberClient::new(config, None, None, None);
// connect() starts the connection task; result is checked by wait_for_connect()
let _ = client.connect().await;
let _ = subscriber.connect().await;
client
.wait_for_connect()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
subscriber
.wait_for_connect()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(Self { client, subscriber })
}
pub fn client(&self) -> &Client {
&self.client
}
}
#[async_trait]
impl MessageBus for RedisMessageBus {
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
self.client
.publish::<(), _, Vec<u8>>(channel, message.to_vec())
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(())
}
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
self.subscriber
.subscribe(channel.to_string())
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
let subscriber = self.subscriber.clone();
let channel_owned = channel.to_string();
let mut message_rx = subscriber.message_rx();
tokio::spawn(async move {
while let Ok(message) = message_rx.recv().await {
if &message.channel == &channel_owned {
let data: Vec<u8> = FromValue::from_value(message.value)
.unwrap_or_default();
if tx.send(data).await.is_err() {
break;
}
}
}
});
Ok(rx)
}
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
self.subscriber
.unsubscribe(channel.to_string())
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(())
}
async fn close(&self) -> Result<(), MessageBusError> {
self.client
.quit()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
self.subscriber
.quit()
.await
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(())
}
}
+16
View File
@@ -0,0 +1,16 @@
pub mod adapter;
pub mod message_bus;
pub mod namespace;
pub mod packet;
pub mod parser;
pub mod server;
pub mod session_store;
pub mod socket;
pub use adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, RedisAdapter, NatsAdapter, SocketInfo};
pub use message_bus::{MessageBus, MessageBusError, RedisMessageBus, NatsMessageBus};
pub use namespace::{is_valid_namespace, Namespace, NamespaceManager};
pub use packet::{Packet, PacketType};
pub use server::{SocketServer, SocketServerBuilder};
pub use session_store::{InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait};
pub use socket::Socket;
+239
View File
@@ -0,0 +1,239 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::RwLock;
use crate::socket::adapter::{Adapter, BroadcastOptions, BroadcastFlags};
use crate::socket::packet::Packet;
use crate::socket::socket::Socket;
pub type EventHandler = Arc<dyn Fn(&Socket, &serde_json::Value) + Send + Sync>;
type ConnectHandler = Arc<dyn Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync>;
pub struct Namespace {
pub path: String,
/// Primary storage: socket_sid → Socket
sockets: DashMap<String, Arc<Socket>>,
/// Reverse index: engine_sid → socket_sid (for engine-level lookups)
engine_to_socket: DashMap<String, String>,
handlers: RwLock<HashMap<String, Vec<EventHandler>>>,
connect_handler: RwLock<Option<ConnectHandler>>,
pub(crate) adapter: RwLock<Option<Arc<dyn Adapter>>>,
}
impl Namespace {
pub fn new(path: impl Into<String>) -> Self {
Self {
path: path.into(),
sockets: DashMap::new(),
engine_to_socket: DashMap::new(),
handlers: RwLock::new(HashMap::new()),
connect_handler: RwLock::new(None),
adapter: RwLock::new(None),
}
}
pub async fn set_adapter(&self, adapter: Arc<dyn Adapter>) {
let mut guard = self.adapter.write().await;
*guard = Some(adapter);
}
/// Add a socket to this namespace. Returns Err if the connect handler rejects.
pub async fn add_socket(&self, socket: Arc<Socket>) -> Result<(), String> {
// Run connect handler before adding to storage
let handler = self.connect_handler.read().await;
if let Some(ref h) = *handler {
h(&socket, None)?;
}
drop(handler);
let socket_sid = socket.sid.clone();
let engine_sid = socket.engine_sid.clone();
// Register with adapter (socket_sid → engine_sid mapping)
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
if let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await {
tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e);
}
}
// Store socket by socket_sid, plus reverse index
self.sockets.insert(socket_sid.clone(), socket);
self.engine_to_socket.insert(engine_sid, socket_sid);
Ok(())
}
/// Remove a socket by its socket SID.
pub async fn remove_socket_by_sid(&self, socket_sid: &str) {
if let Some((_, socket)) = self.sockets.remove(socket_sid) {
self.engine_to_socket.remove(&socket.engine_sid);
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
if let Err(e) = adapter.del_all(socket_sid, &self.path).await {
tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e);
}
}
}
}
/// Remove a socket by its engine SID (for engine-level disconnections).
pub async fn remove_socket(&self, engine_sid: &str) {
if let Some((_, socket_sid)) = self.engine_to_socket.remove(engine_sid) {
self.remove_socket_by_sid(&socket_sid).await;
}
}
/// Look up a socket by its socket SID.
pub fn get_socket(&self, socket_sid: &str) -> Option<Arc<Socket>> {
self.sockets.get(socket_sid).map(|r| r.value().clone())
}
/// Look up a socket by its engine SID (reverse lookup).
pub fn get_socket_by_engine_sid(&self, engine_sid: &str) -> Option<Arc<Socket>> {
self.engine_to_socket
.get(engine_sid)
.and_then(|entry| self.sockets.get(entry.value()).map(|r| r.value().clone()))
}
pub fn socket_count(&self) -> usize {
self.sockets.len()
}
pub async fn on_event(&self, event: impl Into<String>, handler: EventHandler) {
let mut handlers = self.handlers.write().await;
handlers.entry(event.into()).or_default().push(handler);
}
pub async fn on_connect<F>(&self, handler: F)
where
F: Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync + 'static,
{
let mut connect_handler = self.connect_handler.write().await;
*connect_handler = Some(Arc::new(handler));
}
pub async fn emit(&self, event: impl Into<String>, data: serde_json::Value) {
let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
let opts = BroadcastOptions::default();
if let Err(e) = adapter.broadcast(&packet, &opts).await {
tracing::warn!("Adapter broadcast error: {}", e);
}
} else {
self.emit_local(&packet);
}
}
pub async fn emit_to_room(&self, room: &str, event: impl Into<String>, data: serde_json::Value) {
let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
let opts = BroadcastOptions {
rooms: HashSet::from([room.to_string()]),
except: HashSet::new(),
flags: BroadcastFlags::default(),
};
if let Err(e) = adapter.broadcast(&packet, &opts).await {
tracing::warn!("Adapter broadcast to room error: {}", e);
}
} else {
self.emit_local(&packet);
}
}
pub fn emit_local(&self, packet: &Packet) {
for entry in self.sockets.iter() {
let socket = entry.value();
if socket.send_packet(packet).is_err() {
tracing::warn!("Failed to send event to socket {}", socket.sid);
}
}
}
pub async fn emit_to(&self, socket_sid: &str, event: impl Into<String>, data: serde_json::Value) {
if let Some(socket) = self.get_socket(socket_sid) {
let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
if socket.send_packet(&packet).is_err() {
tracing::warn!("Failed to send event to socket {}", socket.sid);
}
}
}
pub async fn handle_event(&self, socket: &Socket, event: &str, data: &serde_json::Value) {
let handlers = self.handlers.read().await;
if let Some(event_handlers) = handlers.get(event) {
for handler in event_handlers {
handler(socket, data);
}
}
}
}
pub struct NamespaceManager {
namespaces: DashMap<String, Arc<Namespace>>,
}
impl NamespaceManager {
pub fn new() -> Self {
let manager = Self {
namespaces: DashMap::new(),
};
manager.create_namespace("/");
manager
}
pub fn create_namespace(&self, path: impl Into<String>) -> Arc<Namespace> {
let path = path.into();
let namespace = Arc::new(Namespace::new(&path));
self.namespaces.insert(path.clone(), namespace.clone());
namespace
}
pub fn get_namespace(&self, path: &str) -> Option<Arc<Namespace>> {
self.namespaces.get(path).map(|r| r.value().clone())
}
pub fn get_or_create_namespace(&self, path: &str) -> Arc<Namespace> {
if let Some(ns) = self.get_namespace(path) {
ns
} else {
self.create_namespace(path)
}
}
pub fn remove_namespace(&self, path: &str) {
self.namespaces.remove(path);
}
pub fn namespace_count(&self) -> usize {
self.namespaces.len()
}
pub fn all_namespaces(&self) -> Vec<Arc<Namespace>> {
self.namespaces.iter().map(|e| e.value().clone()).collect()
}
}
impl Default for NamespaceManager {
fn default() -> Self {
Self::new()
}
}
/// Validate a namespace path. Returns true if the path is valid.
/// Rules: must start with '/', max 256 chars, no control characters.
pub fn is_valid_namespace(path: &str) -> bool {
!path.is_empty()
&& path.starts_with('/')
&& path.len() <= 256
&& !path.chars().any(|c| c.is_control())
}
+174
View File
@@ -0,0 +1,174 @@
use serde_json::Value;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PacketType {
Connect = 0,
Disconnect = 1,
Event = 2,
Ack = 3,
ConnectError = 4,
BinaryEvent = 5,
BinaryAck = 6,
}
impl TryFrom<u8> for PacketType {
type Error = PacketError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::Connect),
1 => Ok(Self::Disconnect),
2 => Ok(Self::Event),
3 => Ok(Self::Ack),
4 => Ok(Self::ConnectError),
5 => Ok(Self::BinaryEvent),
6 => Ok(Self::BinaryAck),
_ => Err(PacketError::InvalidType(value)),
}
}
}
impl TryFrom<char> for PacketType {
type Error = PacketError;
fn try_from(value: char) -> Result<Self, Self::Error> {
match value {
'0' => Ok(Self::Connect),
'1' => Ok(Self::Disconnect),
'2' => Ok(Self::Event),
'3' => Ok(Self::Ack),
'4' => Ok(Self::ConnectError),
'5' => Ok(Self::BinaryEvent),
'6' => Ok(Self::BinaryAck),
_ => Err(PacketError::InvalidTypeChar(value)),
}
}
}
#[derive(Debug, Clone)]
pub struct Packet {
pub packet_type: PacketType,
pub namespace: String,
pub data: Option<Value>,
pub id: Option<u64>,
pub attachments: Vec<Vec<u8>>,
/// Expected number of binary attachments (set during decode for binary packets).
/// Used to validate attachment count before assembling the full packet.
pub expected_attachments: Option<usize>,
}
impl Packet {
pub fn connect(namespace: impl Into<String>, data: Option<Value>) -> Self {
Self {
packet_type: PacketType::Connect,
namespace: namespace.into(),
data,
id: None,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn disconnect(namespace: impl Into<String>) -> Self {
Self {
packet_type: PacketType::Disconnect,
namespace: namespace.into(),
data: None,
id: None,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn event(namespace: impl Into<String>, data: Value, id: Option<u64>) -> Self {
Self {
packet_type: PacketType::Event,
namespace: namespace.into(),
data: Some(data),
id,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn ack(namespace: impl Into<String>, data: Value, id: u64) -> Self {
Self {
packet_type: PacketType::Ack,
namespace: namespace.into(),
data: Some(data),
id: Some(id),
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn connect_error(namespace: impl Into<String>, message: impl Into<String>) -> Self {
Self {
packet_type: PacketType::ConnectError,
namespace: namespace.into(),
data: Some(serde_json::json!({ "message": message.into() })),
id: None,
attachments: Vec::new(),
expected_attachments: None,
}
}
pub fn binary_event(
namespace: impl Into<String>,
data: Value,
id: Option<u64>,
attachments: Vec<Vec<u8>>,
) -> Self {
Self {
packet_type: PacketType::BinaryEvent,
namespace: namespace.into(),
data: Some(data),
id,
attachments,
expected_attachments: None,
}
}
pub fn binary_ack(
namespace: impl Into<String>,
data: Value,
id: u64,
attachments: Vec<Vec<u8>>,
) -> Self {
Self {
packet_type: PacketType::BinaryAck,
namespace: namespace.into(),
data: Some(data),
id: Some(id),
attachments,
expected_attachments: None,
}
}
pub fn has_binary(&self) -> bool {
!self.attachments.is_empty()
}
pub fn attachment_count(&self) -> usize {
self.attachments.len()
}
}
#[derive(Debug, thiserror::Error)]
pub enum PacketError {
#[error("invalid packet type: {0}")]
InvalidType(u8),
#[error("invalid packet type char: {0}")]
InvalidTypeChar(char),
#[error("empty packet")]
Empty,
#[error("invalid format: {0}")]
InvalidFormat(String),
#[error("json error: {0}")]
Json(#[from] serde_json::Error),
#[error("missing namespace")]
MissingNamespace,
#[error("invalid attachment count")]
InvalidAttachmentCount,
}
+392
View File
@@ -0,0 +1,392 @@
use serde_json::Value;
use crate::socket::packet::{Packet, PacketError, PacketType};
pub fn encode(packet: &Packet) -> String {
let type_char = packet.packet_type as u8 + b'0';
let mut result = String::new();
result.push(type_char as char);
if packet.has_binary() {
result.push_str(&packet.attachment_count().to_string());
result.push('-');
}
if packet.namespace != "/" {
result.push_str(&packet.namespace);
result.push(',');
}
if let Some(id) = packet.id {
result.push_str(&id.to_string());
}
if let Some(ref data) = packet.data {
if packet.has_binary() {
let data_with_placeholders = replace_binary_with_placeholders(data, packet.attachment_count());
let encoded_data = serde_json::to_string(&data_with_placeholders)
.unwrap_or_else(|e| {
tracing::error!("Failed to serialize socket packet data: {}", e);
"null".to_string()
});
result.push_str(&encoded_data);
} else {
let encoded_data = serde_json::to_string(data)
.unwrap_or_else(|e| {
tracing::error!("Failed to serialize socket packet data: {}", e);
"null".to_string()
});
result.push_str(&encoded_data);
}
}
result
}
pub fn encode_with_attachments(packet: &Packet) -> Vec<Vec<u8>> {
let mut result = Vec::new();
let encoded = encode(packet);
result.push(encoded.into_bytes());
for attachment in &packet.attachments {
result.push(attachment.clone());
}
result
}
pub fn decode(input: &str) -> Result<Packet, PacketError> {
if input.is_empty() {
return Err(PacketError::Empty);
}
let mut chars = input.chars().peekable();
let type_char = chars.next().ok_or(PacketError::Empty)?;
let packet_type = PacketType::try_from(type_char)?;
let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck) {
let mut count_str = String::new();
while let Some(&c) = chars.peek() {
if c == '-' {
chars.next();
break;
}
if c.is_ascii_digit() {
count_str.push(c);
chars.next();
} else {
break;
}
}
count_str.parse::<usize>().unwrap_or(0)
} else {
0
};
let remaining: String = chars.collect();
let (namespace, rest) = if let Some(after_slash) = remaining.strip_prefix('/') {
// Check if this is a custom namespace (has a comma separating namespace from data/id)
// or if '/' is just the root namespace prefix followed immediately by data
if let Some(comma_pos) = after_slash.find(',') {
let ns = format!("/{}", &after_slash[..comma_pos]);
let rest = after_slash[comma_pos + 1..].to_string();
(ns, rest)
} else if after_slash.starts_with('[')
|| after_slash.starts_with(|c: char| c.is_ascii_digit())
|| after_slash.is_empty()
{
// '/[' means '/' is the root namespace and '[' starts the data
// '/<digits>' means root namespace followed by ack id
// '/' alone means disconnect on root namespace
("/".to_string(), after_slash.to_string())
} else {
// Non-root namespace without data (e.g., disconnect on custom namespace)
(remaining, String::new())
}
} else {
("/".to_string(), remaining)
};
let (id, data_str) = parse_id_and_data(&rest);
let data = if data_str.is_empty() {
None
} else {
Some(serde_json::from_str(&data_str)?)
};
Ok(Packet {
packet_type,
namespace,
data,
id,
attachments: Vec::new(),
// Store attachment_count for binary packets; actual attachments come via decode_with_attachments
expected_attachments: if attachment_count > 0 { Some(attachment_count) } else { None },
})
}
pub fn decode_with_attachments(
main_packet: &str,
attachments: Vec<Vec<u8>>,
) -> Result<Packet, PacketError> {
let mut packet = decode(main_packet)?;
let expected = packet.expected_attachments.unwrap_or(0);
if expected != attachments.len() {
return Err(PacketError::InvalidAttachmentCount);
}
packet.attachments = attachments;
packet.expected_attachments = None;
if packet.has_binary() {
if let Some(ref data) = packet.data {
packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments));
}
}
Ok(packet)
}
fn parse_id_and_data(input: &str) -> (Option<u64>, String) {
let mut id_str = String::new();
let mut chars = input.chars().peekable();
while let Some(&c) = chars.peek() {
if c.is_ascii_digit() {
id_str.push(c);
chars.next();
} else {
break;
}
}
let id = if id_str.is_empty() {
None
} else {
id_str.parse::<u64>().ok()
};
let data: String = chars.collect();
(id, data)
}
/// Replace binary values in the data with { "_placeholder": true, "num": N } placeholders.
/// This is used when encoding binary events/acks for transmission over text-based transports.
fn replace_binary_with_placeholders(value: &Value, total_attachments: usize) -> Value {
match value {
Value::Array(arr) => {
let mut placeholder_idx = total_attachments; // Start from known count
let new_arr: Vec<Value> = arr
.iter()
.map(|v| replace_binary_with_placeholders_inner(v, &mut placeholder_idx))
.collect();
Value::Array(new_arr)
}
Value::Object(map) => {
let mut placeholder_idx = total_attachments;
let mut new_map = serde_json::Map::new();
for (k, v) in map {
new_map.insert(
k.clone(),
replace_binary_with_placeholders_inner(v, &mut placeholder_idx),
);
}
Value::Object(new_map)
}
_ => value.clone(),
}
}
fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut usize) -> Value {
match value {
Value::Array(arr) => {
let new_arr: Vec<Value> = arr
.iter()
.map(|v| replace_binary_with_placeholders_inner(v, placeholder_idx))
.collect();
Value::Array(new_arr)
}
Value::Object(map) => {
let mut new_map = serde_json::Map::new();
for (k, v) in map {
new_map.insert(
k.clone(),
replace_binary_with_placeholders_inner(v, placeholder_idx),
);
}
Value::Object(new_map)
}
// Binary data would be represented as base64 strings in the initial data;
// in the Socket.IO protocol, binary attachments are separate and referenced by placeholder.
// This function handles the case where the data structure itself contains placeholder markers.
_ => value.clone(),
}
}
fn replace_placeholders_with_binary(value: &Value, attachments: &[Vec<u8>]) -> Value {
match value {
Value::Object(map) => {
// Check if this is a placeholder object: { "_placeholder": true, "num": N }
if let (Some(Value::Bool(true)), Some(Value::Number(num))) =
(map.get("_placeholder"), map.get("num"))
{
if let Some(idx) = num.as_u64() {
if let Some(attachment) = attachments.get(idx as usize) {
return Value::String(base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
attachment,
));
}
}
}
let mut new_map = serde_json::Map::new();
for (k, v) in map {
new_map.insert(k.clone(), replace_placeholders_with_binary(v, attachments));
}
Value::Object(new_map)
}
Value::Array(arr) => Value::Array(
arr.iter()
.map(|v| replace_placeholders_with_binary(v, attachments))
.collect(),
),
_ => value.clone(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_encode_connect() {
let packet = Packet::connect("/", None);
let encoded = encode(&packet);
assert_eq!(encoded, "0");
let packet = Packet::connect("/admin", Some(json!({"sid": "abc"})));
let encoded = encode(&packet);
assert_eq!(encoded, "0/admin,{\"sid\":\"abc\"}");
}
#[test]
fn test_encode_event() {
let packet = Packet::event("/", json!(["foo"]), None);
let encoded = encode(&packet);
assert_eq!(encoded, "2[\"foo\"]");
let packet = Packet::event("/admin", json!(["bar"]), None);
let encoded = encode(&packet);
assert_eq!(encoded, "2/admin,[\"bar\"]");
}
#[test]
fn test_encode_event_with_ack() {
let packet = Packet::event("/", json!(["foo"]), Some(12));
let encoded = encode(&packet);
assert_eq!(encoded, "212[\"foo\"]");
}
#[test]
fn test_encode_ack() {
let packet = Packet::ack("/", json!([]), 12);
let encoded = encode(&packet);
assert_eq!(encoded, "312[]");
}
#[test]
fn test_encode_disconnect() {
let packet = Packet::disconnect("/");
let encoded = encode(&packet);
assert_eq!(encoded, "1");
let packet = Packet::disconnect("/admin");
let encoded = encode(&packet);
assert_eq!(encoded, "1/admin,");
}
#[test]
fn test_encode_connect_error() {
let packet = Packet::connect_error("/", "Not authorized");
let encoded = encode(&packet);
assert_eq!(encoded, "4{\"message\":\"Not authorized\"}");
}
#[test]
fn test_decode_connect() {
let packet = decode("0").unwrap();
assert_eq!(packet.packet_type, PacketType::Connect);
assert_eq!(packet.namespace, "/");
assert!(packet.data.is_none());
}
#[test]
fn test_decode_connect_with_namespace() {
let packet = decode("0/admin,{\"sid\":\"abc\"}").unwrap();
assert_eq!(packet.packet_type, PacketType::Connect);
assert_eq!(packet.namespace, "/admin");
assert!(packet.data.is_some());
}
#[test]
fn test_decode_event() {
let packet = decode("2[\"foo\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.namespace, "/");
assert_eq!(packet.data, Some(json!(["foo"])));
}
#[test]
fn test_decode_event_with_namespace() {
let packet = decode("2/admin,[\"bar\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.namespace, "/admin");
assert_eq!(packet.data, Some(json!(["bar"])));
}
#[test]
fn test_decode_event_with_ack() {
let packet = decode("212[\"foo\"]").unwrap();
assert_eq!(packet.packet_type, PacketType::Event);
assert_eq!(packet.id, Some(12));
assert_eq!(packet.data, Some(json!(["foo"])));
}
#[test]
fn test_decode_ack() {
let packet = decode("312[]").unwrap();
assert_eq!(packet.packet_type, PacketType::Ack);
assert_eq!(packet.id, Some(12));
}
#[test]
fn test_decode_disconnect() {
let packet = decode("1").unwrap();
assert_eq!(packet.packet_type, PacketType::Disconnect);
assert_eq!(packet.namespace, "/");
}
#[test]
fn test_decode_disconnect_with_namespace() {
let packet = decode("1/admin,").unwrap();
assert_eq!(packet.packet_type, PacketType::Disconnect);
assert_eq!(packet.namespace, "/admin");
}
#[test]
fn test_decode_binary_event_attachment_count() {
let packet = decode("51-[\"baz\",{\"_placeholder\":true,\"num\":0}]").unwrap();
assert_eq!(packet.packet_type, PacketType::BinaryEvent);
assert_eq!(packet.expected_attachments, Some(1));
assert_eq!(packet.namespace, "/");
}
}
+301
View File
@@ -0,0 +1,301 @@
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::mpsc;
use crate::engine::packet::Packet as EnginePacket;
use crate::engine::packet::PacketData as EnginePacketData;
use crate::engine::server::{EngineConfig, EngineServer};
use crate::engine::session::SessionStore;
use crate::socket::adapter::{Adapter, LocalAdapter};
use crate::socket::namespace::NamespaceManager;
use crate::socket::packet::{Packet, PacketType};
use crate::socket::parser;
use crate::socket::socket::Socket;
pub struct SocketServer {
pub engine: Arc<EngineServer>,
pub namespaces: Arc<NamespaceManager>,
pub adapter: Arc<dyn Adapter>,
socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>>,
}
impl SocketServer {
pub fn new(config: EngineConfig) -> Self {
SocketServerBuilder::new(config).build()
}
pub fn builder(config: EngineConfig) -> SocketServerBuilder {
SocketServerBuilder::new(config)
}
pub fn of(&self, path: impl Into<String>) -> Arc<crate::socket::namespace::Namespace> {
self.namespaces.get_or_create_namespace(&path.into())
}
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
self.engine.clone().run_http(addr).await
}
pub fn register_socket(&self, sid: String, tx: mpsc::Sender<Packet>) {
self.socket_txs.insert(sid, tx);
}
pub fn unregister_socket(&self, sid: &str) {
self.socket_txs.remove(sid);
}
}
pub struct SocketServerBuilder {
config: EngineConfig,
adapter: Option<Arc<dyn Adapter>>,
}
impl SocketServerBuilder {
pub fn new(config: EngineConfig) -> Self {
Self {
config,
adapter: None,
}
}
pub fn adapter(mut self, adapter: Arc<dyn Adapter>) -> Self {
self.adapter = Some(adapter);
self
}
pub fn build(self) -> SocketServer {
let namespaces = Arc::new(NamespaceManager::new());
let socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>> = Arc::new(DashMap::new());
let engine_store = SessionStore::new();
let namespaces_clone = namespaces.clone();
let socket_txs_clone = socket_txs.clone();
let engine_store_clone = engine_store.clone();
let adapter: Arc<dyn Adapter> = self.adapter.unwrap_or_else(|| {
let ns_clone = namespaces.clone();
let send_fn = move |engine_sid: &str, packet: &Packet| {
if let Some(ns) = ns_clone.get_namespace(&packet.namespace) {
if let Some(socket) = ns.get_socket_by_engine_sid(engine_sid) {
socket.send_packet(packet).map_err(|e| e.to_string())
} else {
Err(format!(
"Socket with engine_sid {} not found in namespace {}",
engine_sid, packet.namespace
))
}
} else {
Err(format!("Namespace {} not found", packet.namespace))
}
};
Arc::new(LocalAdapter::new(send_fn))
});
let adapter_clone = adapter.clone();
let engine = Arc::new(EngineServer::with_store(
self.config,
engine_store,
move |sid, engine_packet| {
let namespaces = namespaces_clone.clone();
let socket_txs = socket_txs_clone.clone();
let engine_store = engine_store_clone.clone();
let adapter = adapter_clone.clone();
tokio::spawn(async move {
handle_engine_message(
sid, engine_packet, &namespaces, &socket_txs, &engine_store, &adapter,
).await;
});
},
));
let server = SocketServer {
engine,
namespaces,
adapter,
socket_txs,
};
for ns in server.namespaces.all_namespaces() {
let adapter_ref = server.adapter.clone();
tokio::spawn(async move {
ns.set_adapter(adapter_ref).await;
});
}
server
}
}
async fn handle_engine_message(
engine_sid: String,
engine_packet: EnginePacket,
namespaces: &Arc<NamespaceManager>,
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
engine_store: &SessionStore,
adapter: &Arc<dyn Adapter>,
) {
if let EnginePacketData::Text(ref text) = engine_packet.data {
if let Ok(socket_packet) = parser::decode(text) {
match socket_packet.packet_type {
PacketType::Connect => {
handle_connect(&engine_sid, &socket_packet, namespaces, socket_txs, engine_store, adapter).await;
}
PacketType::Disconnect => {
handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs);
}
PacketType::Event => {
handle_event(&engine_sid, &socket_packet, namespaces);
}
PacketType::Ack => {
handle_ack(&engine_sid, &socket_packet);
}
_ => {}
}
}
}
}
async fn handle_connect(
engine_sid: &str,
packet: &Packet,
namespaces: &Arc<NamespaceManager>,
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
engine_store: &SessionStore,
adapter: &Arc<dyn Adapter>,
) {
// Validate namespace path to prevent DoS via arbitrary namespace creation
if !crate::socket::namespace::is_valid_namespace(&packet.namespace) {
tracing::warn!("Rejected connect with invalid namespace: {}", packet.namespace);
return;
}
let namespace = namespaces.get_or_create_namespace(&packet.namespace);
// Ensure newly created namespaces get the shared adapter
{
let ns_adapter = namespace.adapter.read().await;
if ns_adapter.is_none() {
drop(ns_adapter);
let adapter_ref = adapter.clone();
let ns_clone = namespace.clone();
tokio::spawn(async move {
ns_clone.set_adapter(adapter_ref).await;
});
}
}
let socket_sid = crate::engine::session::generate_sid();
let (tx, mut rx) = mpsc::channel::<Packet>(256);
socket_txs.insert(socket_sid.clone(), tx.clone());
let socket = Arc::new(Socket::new(
socket_sid.clone(),
packet.namespace.clone(),
engine_sid.to_string(),
tx,
));
// Run connect handler and add to namespace.
// If the handler rejects, clean up and do NOT send a Connect response.
if let Err(msg) = namespace.add_socket(socket.clone()).await {
tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg);
socket_txs.remove(&socket_sid);
return;
}
// Connect handler passed — spawn forwarding task
let engine_store_clone = engine_store.clone();
let engine_sid_clone = engine_sid.to_string();
let socket_sid_clone = socket_sid.clone();
let socket_txs_clone = socket_txs.clone();
let namespace_clone = namespace.clone();
tokio::spawn(async move {
while let Some(socket_packet) = rx.recv().await {
let encoded = parser::encode(&socket_packet);
let engine_packet = EnginePacket::message_text(encoded);
if let Some(session) = engine_store_clone.get(&engine_sid_clone) {
let mut s = session.write().await;
if s.state == crate::engine::session::SessionState::Closed {
break;
}
s.push_packet(engine_packet);
} else {
break;
}
}
// Forwarding task ended — ensure socket is cleaned up from namespace
socket_txs_clone.remove(&socket_sid_clone);
namespace_clone.remove_socket_by_sid(&socket_sid_clone).await;
});
// Send Connect response (only after handler passed)
let response = Packet::connect(
&socket.namespace,
Some(serde_json::json!({ "sid": &socket.sid })),
);
if socket.send_packet(&response).is_err() {
tracing::warn!("Failed to send connect response to socket {}", socket.sid);
}
}
fn handle_disconnect(
engine_sid: &str,
packet: &Packet,
namespaces: &Arc<NamespaceManager>,
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
) {
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
// Look up socket by engine_sid, then remove by socket_sid
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
socket_txs.remove(&socket.sid);
let socket_sid = socket.sid.clone();
let ns_clone = namespace.clone();
tokio::spawn(async move {
ns_clone.remove_socket_by_sid(&socket_sid).await;
});
}
}
}
fn handle_event(
engine_sid: &str,
packet: &Packet,
namespaces: &Arc<NamespaceManager>,
) {
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
if let Some(ref data) = packet.data {
if let Some(arr) = data.as_array() {
if let Some(event) = arr.first().and_then(|v| v.as_str()) {
let event_data = if arr.len() > 1 {
serde_json::Value::Array(arr[1..].to_vec())
} else {
serde_json::Value::Null
};
let namespace_clone = namespace.clone();
let event = event.to_string();
let socket_clone = socket.clone();
tokio::spawn(async move {
namespace_clone
.handle_event(&socket_clone, &event, &event_data)
.await;
});
}
}
}
}
}
}
fn handle_ack(engine_sid: &str, packet: &Packet) {
tracing::debug!(
"Received ACK from {} for namespace {} with id {:?}",
engine_sid,
packet.namespace,
packet.id
);
}
+88
View File
@@ -0,0 +1,88 @@
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use dashmap::DashMap;
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
pub struct InMemorySessionStore {
sessions: Arc<DashMap<String, SessionInfo>>,
}
impl InMemorySessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(DashMap::new()),
}
}
}
impl Default for InMemorySessionStore {
fn default() -> Self {
Self::new()
}
}
fn now_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
#[async_trait]
impl SessionStoreTrait for InMemorySessionStore {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
let info = SessionInfo {
sid: sid.to_string(),
transport: transport.to_string(),
state: "connecting".to_string(),
server_id: server_id.to_string(),
created_at: now_millis(),
last_ping: now_millis(),
};
self.sessions.insert(sid.to_string(), info);
Ok(())
}
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
Ok(self.sessions.get(sid).map(|r| r.value().clone()))
}
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
if let Some(mut entry) = self.sessions.get_mut(sid) {
entry.value_mut().state = state.to_string();
Ok(())
} else {
Err(SessionError::NotFound(sid.to_string()))
}
}
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
if let Some(mut entry) = self.sessions.get_mut(sid) {
entry.value_mut().transport = transport.to_string();
Ok(())
} else {
Err(SessionError::NotFound(sid.to_string()))
}
}
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
if let Some(mut entry) = self.sessions.get_mut(sid) {
entry.value_mut().last_ping = now_millis();
Ok(())
} else {
Err(SessionError::NotFound(sid.to_string()))
}
}
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
self.sessions.remove(sid);
Ok(())
}
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
Ok(self.sessions.contains_key(sid))
}
}
+41
View File
@@ -0,0 +1,41 @@
pub mod memory;
pub mod redis;
use async_trait::async_trait;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum SessionError {
#[error("Redis error: {0}")]
Redis(String),
#[error("Session not found: {0}")]
NotFound(String),
#[error("Serialization error: {0}")]
Serialization(String),
#[error("Session expired: {0}")]
Expired(String),
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SessionInfo {
pub sid: String,
pub transport: String,
pub state: String,
pub server_id: String,
pub created_at: u64,
pub last_ping: u64,
}
#[async_trait]
pub trait SessionStoreTrait: Send + Sync + 'static {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError>;
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError>;
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError>;
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError>;
async fn update_ping(&self, sid: &str) -> Result<(), SessionError>;
async fn remove(&self, sid: &str) -> Result<(), SessionError>;
async fn exists(&self, sid: &str) -> Result<bool, SessionError>;
}
pub use memory::InMemorySessionStore;
pub use redis::RedisSessionStore;
+164
View File
@@ -0,0 +1,164 @@
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use fred::prelude::*;
use crate::socket::message_bus::redis::RedisMessageBus;
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
fn now_millis() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
const DEFAULT_TTL_SECS: u64 = 60;
const KEY_PREFIX: &str = "socket.io:session";
pub struct RedisSessionStore {
client: Client,
ttl_secs: u64,
}
impl RedisSessionStore {
pub fn new(bus: &RedisMessageBus, ttl_secs: Option<u64>) -> Self {
Self {
client: bus.client().clone(),
ttl_secs: ttl_secs.unwrap_or(DEFAULT_TTL_SECS),
}
}
fn key(&self, sid: &str) -> String {
format!("{}:{}", KEY_PREFIX, sid)
}
}
#[async_trait]
impl SessionStoreTrait for RedisSessionStore {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
let key = self.key(sid);
let now = now_millis();
// Batch all fields in a single HSET call for efficiency
let fields: Vec<(&str, String)> = vec![
("sid", sid.to_string()),
("transport", transport.to_string()),
("state", "connecting".to_string()),
("server_id", server_id.to_string()),
("created_at", now.to_string()),
("last_ping", now.to_string()),
];
self.client
.hset::<(), _, _>(&key, fields)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
let key = self.key(sid);
// Use hgetall directly — if the key doesn't exist Redis returns an empty map.
// This avoids the TOCTOU race between EXISTS and HGETALL.
let values: std::collections::HashMap<String, String> = self.client
.hgetall::<std::collections::HashMap<String, String>, _>(&key)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
if values.is_empty() {
return Ok(None);
}
let info = SessionInfo {
sid: values.get("sid").cloned().unwrap_or_default(),
transport: values.get("transport").cloned().unwrap_or_default(),
state: values.get("state").cloned().unwrap_or_default(),
server_id: values.get("server_id").cloned().unwrap_or_default(),
created_at: values.get("created_at").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
last_ping: values.get("last_ping").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
};
Ok(Some(info))
}
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
let key = self.key(sid);
// Use HSET (not HSETNX) to overwrite existing fields
self.client
.hset::<(), _, _>(&key, ("state", state))
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
let key = self.key(sid);
// Use HSET (not HSETNX) to overwrite existing fields
self.client
.hset::<(), _, _>(&key, ("transport", transport))
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
let key = self.key(sid);
let now = now_millis();
// Use HSET (not HSETNX) to overwrite existing fields
self.client
.hset::<(), _, _>(&key, ("last_ping", now.to_string()))
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
self.client
.expire::<(), _>(&key, self.ttl_secs as i64, None)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
let key = self.key(sid);
self.client
.del::<(), _>(&key)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(())
}
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
let key = self.key(sid);
let exists: bool = self.client
.exists::<bool, _>(&key)
.await
.map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(exists)
}
}
+72
View File
@@ -0,0 +1,72 @@
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc;
use crate::socket::packet::Packet;
pub struct Socket {
pub sid: String,
pub namespace: String,
pub engine_sid: String,
ack_id: AtomicU64,
tx: mpsc::Sender<Packet>,
}
impl Socket {
pub fn new(
sid: String,
namespace: String,
engine_sid: String,
tx: mpsc::Sender<Packet>,
) -> Self {
Self {
sid,
namespace,
engine_sid,
ack_id: AtomicU64::new(0),
tx,
}
}
pub fn next_ack_id(&self) -> u64 {
self.ack_id.fetch_add(1, Ordering::SeqCst)
}
pub fn send_packet(&self, packet: &Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
self.tx.try_send(packet.clone())
}
pub fn emit(&self, event: impl Into<String>, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::event(
&self.namespace,
serde_json::json!([event.into(), data]),
None,
);
self.send_packet(&packet)
}
pub fn emit_with_ack(
&self,
event: impl Into<String>,
data: serde_json::Value,
) -> Result<u64, mpsc::error::TrySendError<Packet>> {
let ack_id = self.next_ack_id();
let packet = Packet::event(
&self.namespace,
serde_json::json!([event.into(), data]),
Some(ack_id),
);
self.send_packet(&packet)?;
Ok(ack_id)
}
pub fn disconnect(&self) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::disconnect(&self.namespace);
self.send_packet(&packet)
}
pub fn send_ack(&self, id: u64, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::ack(&self.namespace, data, id);
self.send_packet(&packet)
}
}