feat(auth): add authentication protocol definitions and build configuration
- Add TokenClaims message for JWT payload structure with user id, issuer, timestamps, and scopes - Implement IssueTokenRequest/Response for creating access and refresh tokens with TTL support - Create RefreshTokenRequest/Response for token rotation functionality - Define RevokeTokenRequest/Response with support for single token or user-wide revocation - Add VerifyTokenRequest/Response for validating JWT tokens with detailed claims information - Implement signing key distribution system with GetSigningKeysRequest/Response - Create TokenService gRPC service with IssueToken, RefreshToken, RevokeToken, VerifyToken, and GetSigningKeys methods - Add build.rs configuration to compile proto files using tonic_prost_build - Include channel, channel_settings, member, and permission protocol definitions for IM services - Generate Rust code bindings through pb/core.rs and pb/im.rs modules
This commit is contained in:
@@ -0,0 +1,199 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, SocketInfo};
|
||||
use crate::socket::packet::Packet;
|
||||
|
||||
pub struct LocalAdapter {
|
||||
server_id: String,
|
||||
rooms: Arc<DashMap<String, HashSet<String>>>,
|
||||
socket_rooms: Arc<DashMap<String, HashSet<String>>>,
|
||||
/// socket_sid → engine_sid
|
||||
pub socket_sids: Arc<DashMap<String, String>>,
|
||||
/// socket_sid → namespace path
|
||||
socket_namespace: Arc<DashMap<String, String>>,
|
||||
send_fn: Arc<dyn Fn(&str, &Packet) -> Result<(), String> + Send + Sync>,
|
||||
}
|
||||
|
||||
impl LocalAdapter {
|
||||
pub fn new(
|
||||
send_fn: impl Fn(&str, &Packet) -> Result<(), String> + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
server_id: Uuid::new_v4().to_string(),
|
||||
rooms: Arc::new(DashMap::new()),
|
||||
socket_rooms: Arc::new(DashMap::new()),
|
||||
socket_sids: Arc::new(DashMap::new()),
|
||||
socket_namespace: Arc::new(DashMap::new()),
|
||||
send_fn: Arc::new(send_fn),
|
||||
}
|
||||
}
|
||||
|
||||
fn room_key(ns: &str, room: &str) -> String {
|
||||
format!("{}:{}", ns, room)
|
||||
}
|
||||
|
||||
/// Collect socket SIDs matching the broadcast options, scoped to the given namespace.
|
||||
fn collect_matching_sids(&self, opts: &BroadcastOptions, namespace: &str) -> Vec<String> {
|
||||
if opts.rooms.is_empty() {
|
||||
// Broadcast to all sockets in this namespace only
|
||||
self.socket_sids
|
||||
.iter()
|
||||
.filter(|e| {
|
||||
self.socket_namespace
|
||||
.get(e.key())
|
||||
.map(|ns| ns.value() == namespace)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.map(|e| e.key().clone())
|
||||
.collect()
|
||||
} else {
|
||||
let mut sids = HashSet::new();
|
||||
for room in &opts.rooms {
|
||||
let key = Self::room_key(namespace, room);
|
||||
if let Some(entry) = self.rooms.get(&key) {
|
||||
for sid in entry.value() {
|
||||
sids.insert(sid.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
sids.into_iter().collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Adapter for LocalAdapter {
|
||||
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
|
||||
let namespace = &packet.namespace;
|
||||
let sids = self.collect_matching_sids(opts, namespace);
|
||||
for sid in &sids {
|
||||
if opts.except.contains(sid) {
|
||||
continue;
|
||||
}
|
||||
// socket_sids maps socket SID -> engine SID
|
||||
if let Some(entry) = self.socket_sids.get(sid) {
|
||||
let engine_sid = entry.value();
|
||||
let result = (self.send_fn)(engine_sid, packet);
|
||||
if let Err(e) = result {
|
||||
tracing::warn!("Failed to broadcast to {}: {}", sid, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn register(&self, socket_sid: &str, engine_sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
|
||||
self.socket_namespace.insert(socket_sid.to_string(), ns.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unregister(&self, socket_sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
self.del_all(socket_sid, ns).await
|
||||
}
|
||||
|
||||
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
let key = Self::room_key(ns, room);
|
||||
self.rooms.entry(key).or_insert_with(HashSet::new).value_mut().insert(sid.to_string());
|
||||
self.socket_rooms.entry(sid.to_string()).or_insert_with(HashSet::new).value_mut().insert(room.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
let key = Self::room_key(ns, room);
|
||||
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
|
||||
room_sids.value_mut().remove(sid);
|
||||
if room_sids.value_mut().is_empty() {
|
||||
drop(room_sids);
|
||||
self.rooms.remove(&key);
|
||||
}
|
||||
}
|
||||
if let Some(mut rooms) = self.socket_rooms.get_mut(sid) {
|
||||
rooms.value_mut().remove(room);
|
||||
if rooms.value_mut().is_empty() {
|
||||
drop(rooms);
|
||||
self.socket_rooms.remove(sid);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
|
||||
for room in &rooms {
|
||||
let key = Self::room_key(ns, room);
|
||||
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
|
||||
room_sids.value_mut().remove(sid);
|
||||
if room_sids.value_mut().is_empty() {
|
||||
drop(room_sids);
|
||||
self.rooms.remove(&key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
self.socket_sids.remove(sid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
|
||||
// fetch_sockets needs namespace context; use an empty namespace to match all
|
||||
// (this method is typically called for inspection, not delivery)
|
||||
let sids: Vec<String> = if opts.rooms.is_empty() {
|
||||
self.socket_sids.iter().map(|e| e.key().clone()).collect()
|
||||
} else {
|
||||
let mut sids_set = HashSet::new();
|
||||
for room in &opts.rooms {
|
||||
for entry in self.rooms.iter() {
|
||||
if entry.key().ends_with(&format!(":{}", room)) {
|
||||
for sid in entry.value() {
|
||||
sids_set.insert(sid.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
sids_set.into_iter().collect()
|
||||
};
|
||||
let mut result = Vec::new();
|
||||
for sid in &sids {
|
||||
if opts.except.contains(sid) {
|
||||
continue;
|
||||
}
|
||||
if self.socket_sids.contains_key(sid) {
|
||||
let namespace = self.socket_namespace
|
||||
.get(sid)
|
||||
.map(|r| r.value().clone())
|
||||
.unwrap_or_default();
|
||||
let rooms = self.socket_rooms
|
||||
.get(sid)
|
||||
.map(|r| r.value().clone())
|
||||
.unwrap_or_default();
|
||||
result.push(SocketInfo {
|
||||
sid: sid.clone(),
|
||||
namespace,
|
||||
rooms,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
|
||||
Ok(self.socket_rooms
|
||||
.get(sid)
|
||||
.map(|r| r.value().clone())
|
||||
.unwrap_or_default())
|
||||
}
|
||||
|
||||
fn server_id(&self) -> &str {
|
||||
&self.server_id
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), AdapterError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
pub mod local;
|
||||
pub mod redis;
|
||||
pub mod nats;
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::socket::packet::Packet;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AdapterError {
|
||||
#[error("Redis error: {0}")]
|
||||
Redis(String),
|
||||
#[error("NATS error: {0}")]
|
||||
Nats(String),
|
||||
#[error("Message bus error: {0}")]
|
||||
MessageBus(String),
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
#[error("Room error: {0}")]
|
||||
Room(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||
pub struct BroadcastOptions {
|
||||
pub rooms: HashSet<String>,
|
||||
pub except: HashSet<String>,
|
||||
pub flags: BroadcastFlags,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||
pub struct BroadcastFlags {
|
||||
pub local_only: bool,
|
||||
pub broadcast: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SocketInfo {
|
||||
pub sid: String,
|
||||
pub namespace: String,
|
||||
pub rooms: HashSet<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||
pub enum BusMessage {
|
||||
Broadcast {
|
||||
namespace: String,
|
||||
packet: String,
|
||||
opts: BroadcastOptions,
|
||||
server_id: String,
|
||||
},
|
||||
SocketJoin {
|
||||
namespace: String,
|
||||
sid: String,
|
||||
room: String,
|
||||
server_id: String,
|
||||
},
|
||||
SocketLeave {
|
||||
namespace: String,
|
||||
sid: String,
|
||||
room: String,
|
||||
server_id: String,
|
||||
},
|
||||
SocketDisconnect {
|
||||
namespace: String,
|
||||
sid: String,
|
||||
server_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Adapter: Send + Sync + 'static {
|
||||
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError>;
|
||||
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
|
||||
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
|
||||
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError>;
|
||||
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError>;
|
||||
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError>;
|
||||
fn server_id(&self) -> &str;
|
||||
async fn close(&self) -> Result<(), AdapterError>;
|
||||
|
||||
/// Register a socket SID → engine SID mapping in the adapter.
|
||||
/// Must be called when a socket first connects, before any room operations.
|
||||
/// The `ns` parameter is the namespace path this socket belongs to.
|
||||
async fn register(&self, _socket_sid: &str, _engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Unregister a socket from the adapter, removing all local mappings.
|
||||
async fn unregister(&self, _socket_sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub use local::LocalAdapter;
|
||||
pub use redis::RedisAdapter;
|
||||
pub use nats::NatsAdapter;
|
||||
@@ -0,0 +1,302 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
|
||||
use crate::socket::message_bus::MessageBus;
|
||||
use crate::socket::packet::Packet;
|
||||
use crate::socket::parser;
|
||||
use crate::socket::socket::Socket;
|
||||
|
||||
/// Handle incoming bus messages from other servers.
|
||||
/// Only performs local dispatch — no remote state writes needed.
|
||||
async fn handle_bus_message(
|
||||
msg: BusMessage,
|
||||
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||
server_id: &str,
|
||||
) {
|
||||
match msg {
|
||||
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
|
||||
if sender_id == server_id {
|
||||
return;
|
||||
}
|
||||
if let Ok(decoded_packet) = parser::decode(&packet) {
|
||||
on_local_broadcast(&decoded_packet, &opts);
|
||||
}
|
||||
}
|
||||
// NATS adapter manages room state locally; cross-server join/leave/disconnect
|
||||
// are informational only and don't require duplicate state writes.
|
||||
BusMessage::SocketJoin { server_id: sender_id, .. }
|
||||
| BusMessage::SocketLeave { server_id: sender_id, .. }
|
||||
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
|
||||
if sender_id == server_id {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// NATS-based adapter that manages room state locally and uses NATS
|
||||
/// for cross-server broadcast only. Does NOT depend on Redis.
|
||||
pub struct NatsAdapter {
|
||||
message_bus: Arc<dyn MessageBus>,
|
||||
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
|
||||
socket_rooms: DashMap<String, HashSet<String>>,
|
||||
rooms: DashMap<String, HashSet<String>>,
|
||||
/// socket_sid → engine_sid mapping for local delivery
|
||||
socket_sids: DashMap<String, String>,
|
||||
sockets: DashMap<String, Arc<Socket>>,
|
||||
server_id: String,
|
||||
namespace: String,
|
||||
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
impl NatsAdapter {
|
||||
pub fn new(
|
||||
message_bus: Arc<dyn MessageBus>,
|
||||
server_id: String,
|
||||
namespace: String,
|
||||
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||
) -> Self {
|
||||
Self {
|
||||
message_bus,
|
||||
server_id,
|
||||
namespace,
|
||||
on_local_broadcast,
|
||||
room_subscribers: DashMap::new(),
|
||||
socket_rooms: DashMap::new(),
|
||||
rooms: DashMap::new(),
|
||||
socket_sids: DashMap::new(),
|
||||
sockets: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init(&self) -> Result<(), AdapterError> {
|
||||
let channels = ["broadcast", "join", "leave", "disconnect"];
|
||||
let prefix = format!("socket.io:{}:", self.namespace);
|
||||
|
||||
for channel_type in channels {
|
||||
let subject = format!("{}{}", prefix, channel_type);
|
||||
match self.message_bus.subscribe(&subject).await {
|
||||
Ok(rx) => {
|
||||
self.room_subscribers.insert(channel_type.to_string(), rx);
|
||||
}
|
||||
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
self.spawn_listener();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spawn_listener(&self) {
|
||||
let server_id = self.server_id.clone();
|
||||
let on_local_broadcast = self.on_local_broadcast.clone();
|
||||
|
||||
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
|
||||
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
|
||||
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
|
||||
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
Some(data) = async { join_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
else => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Adapter for NatsAdapter {
|
||||
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
|
||||
if opts.flags.local_only {
|
||||
(self.on_local_broadcast)(packet, opts);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let msg = BusMessage::Broadcast {
|
||||
namespace: self.namespace.clone(),
|
||||
packet: parser::encode(packet),
|
||||
opts: opts.clone(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:broadcast", self.namespace), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
(self.on_local_broadcast)(packet, opts);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn register(&self, socket_sid: &str, engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||
self.socket_rooms
|
||||
.entry(sid.to_string())
|
||||
.and_modify(|set| { set.insert(room.to_string()); })
|
||||
.or_insert_with(|| HashSet::from([room.to_string()]));
|
||||
|
||||
self.rooms
|
||||
.entry(room.to_string())
|
||||
.and_modify(|set| { set.insert(sid.to_string()); })
|
||||
.or_insert_with(|| HashSet::from([sid.to_string()]));
|
||||
|
||||
let msg = BusMessage::SocketJoin {
|
||||
namespace: self.namespace.clone(),
|
||||
sid: sid.to_string(),
|
||||
room: room.to_string(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:join", self.namespace), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn del(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
|
||||
entry.value_mut().remove(room);
|
||||
}
|
||||
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||
self.socket_rooms.remove(sid);
|
||||
}
|
||||
|
||||
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||
entry.value_mut().remove(sid);
|
||||
}
|
||||
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||
self.rooms.remove(room);
|
||||
}
|
||||
|
||||
let msg = BusMessage::SocketLeave {
|
||||
namespace: self.namespace.clone(),
|
||||
sid: sid.to_string(),
|
||||
room: room.to_string(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:leave", self.namespace), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn del_all(&self, sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
|
||||
for room in &rooms {
|
||||
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||
entry.value_mut().remove(sid);
|
||||
}
|
||||
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||
self.rooms.remove(room);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.socket_sids.remove(sid);
|
||||
self.sockets.remove(sid);
|
||||
|
||||
let msg = BusMessage::SocketDisconnect {
|
||||
namespace: self.namespace.clone(),
|
||||
sid: sid.to_string(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:disconnect", self.namespace), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
|
||||
self.socket_sids.iter().map(|e| e.key().clone()).collect()
|
||||
} else {
|
||||
let mut sids = HashSet::new();
|
||||
for room in &opts.rooms {
|
||||
if let Some(entry) = self.rooms.get(room) {
|
||||
sids.extend(entry.value().iter().cloned());
|
||||
}
|
||||
}
|
||||
sids
|
||||
};
|
||||
|
||||
for sid in target_sids {
|
||||
if opts.except.contains(&sid) {
|
||||
continue;
|
||||
}
|
||||
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
|
||||
result.push(SocketInfo {
|
||||
sid: sid.clone(),
|
||||
namespace: self.namespace.clone(),
|
||||
rooms,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
|
||||
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
|
||||
}
|
||||
|
||||
fn server_id(&self) -> &str {
|
||||
&self.server_id
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), AdapterError> {
|
||||
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,344 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use fred::clients::Client;
|
||||
use fred::interfaces::{KeysInterface, SetsInterface};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
|
||||
use crate::socket::message_bus::MessageBus;
|
||||
use crate::socket::packet::Packet;
|
||||
use crate::socket::parser;
|
||||
use crate::socket::socket::Socket;
|
||||
|
||||
const KEY_PREFIX_ROOMS: &str = "socket.io:rooms";
|
||||
const KEY_PREFIX_SOCKET_ROOMS: &str = "socket.io:socket_rooms";
|
||||
|
||||
fn room_key(ns: &str, room: &str) -> String {
|
||||
format!("{}:{}:{}", KEY_PREFIX_ROOMS, ns, room)
|
||||
}
|
||||
|
||||
fn socket_rooms_key(ns: &str, sid: &str) -> String {
|
||||
format!("{}:{}:{}", KEY_PREFIX_SOCKET_ROOMS, ns, sid)
|
||||
}
|
||||
|
||||
/// Handle incoming bus messages from other servers.
|
||||
/// Only performs local state updates — the remote server already wrote to Redis.
|
||||
async fn handle_bus_message(
|
||||
msg: BusMessage,
|
||||
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||
server_id: &str,
|
||||
) {
|
||||
match msg {
|
||||
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
|
||||
if sender_id == server_id {
|
||||
return;
|
||||
}
|
||||
if let Ok(decoded_packet) = parser::decode(&packet) {
|
||||
on_local_broadcast(&decoded_packet, &opts);
|
||||
}
|
||||
}
|
||||
BusMessage::SocketJoin { server_id: sender_id, .. }
|
||||
| BusMessage::SocketLeave { server_id: sender_id, .. }
|
||||
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
|
||||
// Skip messages from this server; remote server already updated Redis
|
||||
if sender_id == server_id {
|
||||
return;
|
||||
}
|
||||
// No duplicate Redis writes — the sender already persisted the state change
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RedisAdapter {
|
||||
message_bus: Arc<dyn MessageBus>,
|
||||
redis_client: Client,
|
||||
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
|
||||
socket_rooms: DashMap<String, HashSet<String>>,
|
||||
rooms: DashMap<String, HashSet<String>>,
|
||||
sockets: DashMap<String, Arc<Socket>>,
|
||||
server_id: String,
|
||||
namespace: String,
|
||||
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||
}
|
||||
|
||||
impl RedisAdapter {
|
||||
pub fn new(
|
||||
message_bus: Arc<dyn MessageBus>,
|
||||
redis_client: Client,
|
||||
server_id: String,
|
||||
namespace: String,
|
||||
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||
) -> Self {
|
||||
Self {
|
||||
message_bus,
|
||||
redis_client,
|
||||
server_id,
|
||||
namespace,
|
||||
on_local_broadcast,
|
||||
room_subscribers: DashMap::new(),
|
||||
socket_rooms: DashMap::new(),
|
||||
rooms: DashMap::new(),
|
||||
sockets: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init(&self) -> Result<(), AdapterError> {
|
||||
let channels = ["broadcast", "join", "leave", "disconnect"];
|
||||
let prefix = format!("socket.io:{}:", self.namespace);
|
||||
|
||||
for channel_type in channels {
|
||||
let channel = format!("{}{}", prefix, channel_type);
|
||||
match self.message_bus.subscribe(&channel).await {
|
||||
Ok(rx) => {
|
||||
self.room_subscribers.insert(channel_type.to_string(), rx);
|
||||
}
|
||||
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
self.spawn_listener();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spawn_listener(&self) {
|
||||
let server_id = self.server_id.clone();
|
||||
let on_local_broadcast = self.on_local_broadcast.clone();
|
||||
|
||||
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
|
||||
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
|
||||
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
|
||||
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
Some(data) = async { join_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
|
||||
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||
}
|
||||
}
|
||||
else => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Adapter for RedisAdapter {
|
||||
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
|
||||
if opts.flags.local_only {
|
||||
(self.on_local_broadcast)(packet, opts);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let msg = BusMessage::Broadcast {
|
||||
namespace: packet.namespace.clone(),
|
||||
packet: parser::encode(packet),
|
||||
opts: opts.clone(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:broadcast", packet.namespace), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
(self.on_local_broadcast)(packet, opts);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
let rk = room_key(ns, room);
|
||||
let srk = socket_rooms_key(ns, sid);
|
||||
|
||||
self.redis_client
|
||||
.sadd::<(), _, _>(&rk, sid)
|
||||
.await
|
||||
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||
|
||||
self.redis_client
|
||||
.sadd::<(), _, _>(&srk, room)
|
||||
.await
|
||||
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||
|
||||
self.socket_rooms
|
||||
.entry(sid.to_string())
|
||||
.and_modify(|set| { set.insert(room.to_string()); })
|
||||
.or_insert_with(|| HashSet::from([room.to_string()]));
|
||||
|
||||
self.rooms
|
||||
.entry(room.to_string())
|
||||
.and_modify(|set| { set.insert(sid.to_string()); })
|
||||
.or_insert_with(|| HashSet::from([sid.to_string()]));
|
||||
|
||||
let msg = BusMessage::SocketJoin {
|
||||
namespace: ns.to_string(),
|
||||
sid: sid.to_string(),
|
||||
room: room.to_string(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:join", ns), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
let rk = room_key(ns, room);
|
||||
let srk = socket_rooms_key(ns, sid);
|
||||
|
||||
self.redis_client
|
||||
.srem::<(), _, _>(&rk, sid)
|
||||
.await
|
||||
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||
|
||||
self.redis_client
|
||||
.srem::<(), _, _>(&srk, room)
|
||||
.await
|
||||
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||
|
||||
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
|
||||
entry.value_mut().remove(room);
|
||||
}
|
||||
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||
self.socket_rooms.remove(sid);
|
||||
}
|
||||
|
||||
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||
entry.value_mut().remove(sid);
|
||||
}
|
||||
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||
self.rooms.remove(room);
|
||||
}
|
||||
|
||||
let msg = BusMessage::SocketLeave {
|
||||
namespace: ns.to_string(),
|
||||
sid: sid.to_string(),
|
||||
room: room.to_string(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:leave", ns), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
|
||||
for room in &rooms {
|
||||
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||
entry.value_mut().remove(sid);
|
||||
}
|
||||
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||
self.rooms.remove(room);
|
||||
}
|
||||
|
||||
let rk = room_key(ns, room);
|
||||
if let Err(e) = self.redis_client.srem::<(), _, _>(&rk, sid).await {
|
||||
tracing::warn!("Redis SREM room error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let srk = socket_rooms_key(ns, sid);
|
||||
self.redis_client
|
||||
.del::<(), _>(&srk)
|
||||
.await
|
||||
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||
|
||||
self.sockets.remove(sid);
|
||||
|
||||
let msg = BusMessage::SocketDisconnect {
|
||||
namespace: ns.to_string(),
|
||||
sid: sid.to_string(),
|
||||
server_id: self.server_id.clone(),
|
||||
};
|
||||
|
||||
let payload = serde_json::to_vec(&msg)
|
||||
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||
|
||||
self.message_bus
|
||||
.publish(&format!("socket.io:{}:disconnect", ns), &payload)
|
||||
.await
|
||||
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
|
||||
self.sockets.iter().map(|e| e.key().clone()).collect()
|
||||
} else {
|
||||
let mut sids = HashSet::new();
|
||||
for room in &opts.rooms {
|
||||
if let Some(entry) = self.rooms.get(room) {
|
||||
sids.extend(entry.value().iter().cloned());
|
||||
}
|
||||
}
|
||||
sids
|
||||
};
|
||||
|
||||
for sid in target_sids {
|
||||
if opts.except.contains(&sid) {
|
||||
continue;
|
||||
}
|
||||
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
|
||||
result.push(SocketInfo {
|
||||
sid: sid.clone(),
|
||||
namespace: self.namespace.clone(),
|
||||
rooms,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
|
||||
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
|
||||
}
|
||||
|
||||
fn server_id(&self) -> &str {
|
||||
&self.server_id
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), AdapterError> {
|
||||
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
pub mod redis;
|
||||
pub mod nats;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum MessageBusError {
|
||||
#[error("Redis error: {0}")]
|
||||
Redis(String),
|
||||
#[error("NATS error: {0}")]
|
||||
Nats(String),
|
||||
#[error("Connection closed")]
|
||||
ConnectionClosed,
|
||||
#[error("Channel not found: {0}")]
|
||||
ChannelNotFound(String),
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait MessageBus: Send + Sync + 'static {
|
||||
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError>;
|
||||
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError>;
|
||||
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError>;
|
||||
async fn close(&self) -> Result<(), MessageBusError>;
|
||||
}
|
||||
|
||||
pub use redis::RedisMessageBus;
|
||||
pub use nats::NatsMessageBus;
|
||||
@@ -0,0 +1,88 @@
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use crate::socket::message_bus::{MessageBus, MessageBusError};
|
||||
|
||||
pub struct NatsMessageBus {
|
||||
client: async_nats::Client,
|
||||
shutdowns: DashMap<String, watch::Sender<bool>>,
|
||||
}
|
||||
|
||||
impl NatsMessageBus {
|
||||
pub async fn new(nats_url: &str) -> Result<Self, MessageBusError> {
|
||||
let client = async_nats::connect(nats_url)
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
|
||||
Ok(Self {
|
||||
client,
|
||||
shutdowns: DashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageBus for NatsMessageBus {
|
||||
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
|
||||
self.client
|
||||
.publish(channel.to_string(), message.to_vec().into())
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
|
||||
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
|
||||
|
||||
let mut subscriber = self.client
|
||||
.subscribe(channel.to_string())
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
|
||||
self.shutdowns.insert(channel.to_string(), shutdown_tx);
|
||||
|
||||
tokio::spawn(async move {
|
||||
use futures_util::StreamExt;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
break;
|
||||
}
|
||||
message = subscriber.next() => {
|
||||
match message {
|
||||
Some(msg) => {
|
||||
let data = msg.payload.to_vec();
|
||||
if tx.send(data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Err(e) = subscriber.unsubscribe().await {
|
||||
tracing::warn!("NATS unsubscribe error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
|
||||
if let Some((_, tx)) = self.shutdowns.remove(channel) {
|
||||
let _ = tx.send(true);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), MessageBusError> {
|
||||
// Signal all subscribers to shutdown
|
||||
self.shutdowns.iter().for_each(|entry| {
|
||||
let _ = entry.value().send(true);
|
||||
});
|
||||
self.shutdowns.clear();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
use async_trait::async_trait;
|
||||
use fred::clients::{Client, SubscriberClient};
|
||||
use fred::interfaces::{ClientLike, EventInterface, PubsubInterface};
|
||||
use fred::prelude::*;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::socket::message_bus::{MessageBus, MessageBusError};
|
||||
|
||||
pub struct RedisMessageBus {
|
||||
client: Client,
|
||||
subscriber: SubscriberClient,
|
||||
}
|
||||
|
||||
impl RedisMessageBus {
|
||||
pub async fn new(redis_url: &str) -> Result<Self, MessageBusError> {
|
||||
let config = Config::from_url(redis_url)
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
|
||||
let client = Client::new(config.clone(), None, None, None);
|
||||
let subscriber = SubscriberClient::new(config, None, None, None);
|
||||
|
||||
// connect() starts the connection task; result is checked by wait_for_connect()
|
||||
let _ = client.connect().await;
|
||||
let _ = subscriber.connect().await;
|
||||
|
||||
client
|
||||
.wait_for_connect()
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
subscriber
|
||||
.wait_for_connect()
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
|
||||
Ok(Self { client, subscriber })
|
||||
}
|
||||
|
||||
pub fn client(&self) -> &Client {
|
||||
&self.client
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MessageBus for RedisMessageBus {
|
||||
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
|
||||
self.client
|
||||
.publish::<(), _, Vec<u8>>(channel, message.to_vec())
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
|
||||
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
|
||||
|
||||
self.subscriber
|
||||
.subscribe(channel.to_string())
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
|
||||
let subscriber = self.subscriber.clone();
|
||||
let channel_owned = channel.to_string();
|
||||
let mut message_rx = subscriber.message_rx();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Ok(message) = message_rx.recv().await {
|
||||
if &message.channel == &channel_owned {
|
||||
let data: Vec<u8> = FromValue::from_value(message.value)
|
||||
.unwrap_or_default();
|
||||
if tx.send(data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
|
||||
self.subscriber
|
||||
.unsubscribe(channel.to_string())
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), MessageBusError> {
|
||||
self.client
|
||||
.quit()
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
self.subscriber
|
||||
.quit()
|
||||
.await
|
||||
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
pub mod adapter;
|
||||
pub mod message_bus;
|
||||
pub mod namespace;
|
||||
pub mod packet;
|
||||
pub mod parser;
|
||||
pub mod server;
|
||||
pub mod session_store;
|
||||
pub mod socket;
|
||||
|
||||
pub use adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, RedisAdapter, NatsAdapter, SocketInfo};
|
||||
pub use message_bus::{MessageBus, MessageBusError, RedisMessageBus, NatsMessageBus};
|
||||
pub use namespace::{is_valid_namespace, Namespace, NamespaceManager};
|
||||
pub use packet::{Packet, PacketType};
|
||||
pub use server::{SocketServer, SocketServerBuilder};
|
||||
pub use session_store::{InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait};
|
||||
pub use socket::Socket;
|
||||
@@ -0,0 +1,239 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::socket::adapter::{Adapter, BroadcastOptions, BroadcastFlags};
|
||||
use crate::socket::packet::Packet;
|
||||
use crate::socket::socket::Socket;
|
||||
|
||||
pub type EventHandler = Arc<dyn Fn(&Socket, &serde_json::Value) + Send + Sync>;
|
||||
type ConnectHandler = Arc<dyn Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync>;
|
||||
|
||||
pub struct Namespace {
|
||||
pub path: String,
|
||||
/// Primary storage: socket_sid → Socket
|
||||
sockets: DashMap<String, Arc<Socket>>,
|
||||
/// Reverse index: engine_sid → socket_sid (for engine-level lookups)
|
||||
engine_to_socket: DashMap<String, String>,
|
||||
handlers: RwLock<HashMap<String, Vec<EventHandler>>>,
|
||||
connect_handler: RwLock<Option<ConnectHandler>>,
|
||||
pub(crate) adapter: RwLock<Option<Arc<dyn Adapter>>>,
|
||||
}
|
||||
|
||||
impl Namespace {
|
||||
pub fn new(path: impl Into<String>) -> Self {
|
||||
Self {
|
||||
path: path.into(),
|
||||
sockets: DashMap::new(),
|
||||
engine_to_socket: DashMap::new(),
|
||||
handlers: RwLock::new(HashMap::new()),
|
||||
connect_handler: RwLock::new(None),
|
||||
adapter: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_adapter(&self, adapter: Arc<dyn Adapter>) {
|
||||
let mut guard = self.adapter.write().await;
|
||||
*guard = Some(adapter);
|
||||
}
|
||||
|
||||
/// Add a socket to this namespace. Returns Err if the connect handler rejects.
|
||||
pub async fn add_socket(&self, socket: Arc<Socket>) -> Result<(), String> {
|
||||
// Run connect handler before adding to storage
|
||||
let handler = self.connect_handler.read().await;
|
||||
if let Some(ref h) = *handler {
|
||||
h(&socket, None)?;
|
||||
}
|
||||
drop(handler);
|
||||
|
||||
let socket_sid = socket.sid.clone();
|
||||
let engine_sid = socket.engine_sid.clone();
|
||||
|
||||
// Register with adapter (socket_sid → engine_sid mapping)
|
||||
let adapter = self.adapter.read().await;
|
||||
if let Some(ref adapter) = *adapter {
|
||||
if let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await {
|
||||
tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Store socket by socket_sid, plus reverse index
|
||||
self.sockets.insert(socket_sid.clone(), socket);
|
||||
self.engine_to_socket.insert(engine_sid, socket_sid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a socket by its socket SID.
|
||||
pub async fn remove_socket_by_sid(&self, socket_sid: &str) {
|
||||
if let Some((_, socket)) = self.sockets.remove(socket_sid) {
|
||||
self.engine_to_socket.remove(&socket.engine_sid);
|
||||
|
||||
let adapter = self.adapter.read().await;
|
||||
if let Some(ref adapter) = *adapter {
|
||||
if let Err(e) = adapter.del_all(socket_sid, &self.path).await {
|
||||
tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a socket by its engine SID (for engine-level disconnections).
|
||||
pub async fn remove_socket(&self, engine_sid: &str) {
|
||||
if let Some((_, socket_sid)) = self.engine_to_socket.remove(engine_sid) {
|
||||
self.remove_socket_by_sid(&socket_sid).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a socket by its socket SID.
|
||||
pub fn get_socket(&self, socket_sid: &str) -> Option<Arc<Socket>> {
|
||||
self.sockets.get(socket_sid).map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
/// Look up a socket by its engine SID (reverse lookup).
|
||||
pub fn get_socket_by_engine_sid(&self, engine_sid: &str) -> Option<Arc<Socket>> {
|
||||
self.engine_to_socket
|
||||
.get(engine_sid)
|
||||
.and_then(|entry| self.sockets.get(entry.value()).map(|r| r.value().clone()))
|
||||
}
|
||||
|
||||
pub fn socket_count(&self) -> usize {
|
||||
self.sockets.len()
|
||||
}
|
||||
|
||||
pub async fn on_event(&self, event: impl Into<String>, handler: EventHandler) {
|
||||
let mut handlers = self.handlers.write().await;
|
||||
handlers.entry(event.into()).or_default().push(handler);
|
||||
}
|
||||
|
||||
pub async fn on_connect<F>(&self, handler: F)
|
||||
where
|
||||
F: Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync + 'static,
|
||||
{
|
||||
let mut connect_handler = self.connect_handler.write().await;
|
||||
*connect_handler = Some(Arc::new(handler));
|
||||
}
|
||||
|
||||
pub async fn emit(&self, event: impl Into<String>, data: serde_json::Value) {
|
||||
let event_name = event.into();
|
||||
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
|
||||
|
||||
let adapter = self.adapter.read().await;
|
||||
if let Some(ref adapter) = *adapter {
|
||||
let opts = BroadcastOptions::default();
|
||||
if let Err(e) = adapter.broadcast(&packet, &opts).await {
|
||||
tracing::warn!("Adapter broadcast error: {}", e);
|
||||
}
|
||||
} else {
|
||||
self.emit_local(&packet);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn emit_to_room(&self, room: &str, event: impl Into<String>, data: serde_json::Value) {
|
||||
let event_name = event.into();
|
||||
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
|
||||
|
||||
let adapter = self.adapter.read().await;
|
||||
if let Some(ref adapter) = *adapter {
|
||||
let opts = BroadcastOptions {
|
||||
rooms: HashSet::from([room.to_string()]),
|
||||
except: HashSet::new(),
|
||||
flags: BroadcastFlags::default(),
|
||||
};
|
||||
if let Err(e) = adapter.broadcast(&packet, &opts).await {
|
||||
tracing::warn!("Adapter broadcast to room error: {}", e);
|
||||
}
|
||||
} else {
|
||||
self.emit_local(&packet);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emit_local(&self, packet: &Packet) {
|
||||
for entry in self.sockets.iter() {
|
||||
let socket = entry.value();
|
||||
if socket.send_packet(packet).is_err() {
|
||||
tracing::warn!("Failed to send event to socket {}", socket.sid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn emit_to(&self, socket_sid: &str, event: impl Into<String>, data: serde_json::Value) {
|
||||
if let Some(socket) = self.get_socket(socket_sid) {
|
||||
let event_name = event.into();
|
||||
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
|
||||
if socket.send_packet(&packet).is_err() {
|
||||
tracing::warn!("Failed to send event to socket {}", socket.sid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_event(&self, socket: &Socket, event: &str, data: &serde_json::Value) {
|
||||
let handlers = self.handlers.read().await;
|
||||
if let Some(event_handlers) = handlers.get(event) {
|
||||
for handler in event_handlers {
|
||||
handler(socket, data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NamespaceManager {
|
||||
namespaces: DashMap<String, Arc<Namespace>>,
|
||||
}
|
||||
|
||||
impl NamespaceManager {
|
||||
pub fn new() -> Self {
|
||||
let manager = Self {
|
||||
namespaces: DashMap::new(),
|
||||
};
|
||||
manager.create_namespace("/");
|
||||
manager
|
||||
}
|
||||
|
||||
pub fn create_namespace(&self, path: impl Into<String>) -> Arc<Namespace> {
|
||||
let path = path.into();
|
||||
let namespace = Arc::new(Namespace::new(&path));
|
||||
self.namespaces.insert(path.clone(), namespace.clone());
|
||||
namespace
|
||||
}
|
||||
|
||||
pub fn get_namespace(&self, path: &str) -> Option<Arc<Namespace>> {
|
||||
self.namespaces.get(path).map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
pub fn get_or_create_namespace(&self, path: &str) -> Arc<Namespace> {
|
||||
if let Some(ns) = self.get_namespace(path) {
|
||||
ns
|
||||
} else {
|
||||
self.create_namespace(path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_namespace(&self, path: &str) {
|
||||
self.namespaces.remove(path);
|
||||
}
|
||||
|
||||
pub fn namespace_count(&self) -> usize {
|
||||
self.namespaces.len()
|
||||
}
|
||||
|
||||
pub fn all_namespaces(&self) -> Vec<Arc<Namespace>> {
|
||||
self.namespaces.iter().map(|e| e.value().clone()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NamespaceManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a namespace path. Returns true if the path is valid.
|
||||
/// Rules: must start with '/', max 256 chars, no control characters.
|
||||
pub fn is_valid_namespace(path: &str) -> bool {
|
||||
!path.is_empty()
|
||||
&& path.starts_with('/')
|
||||
&& path.len() <= 256
|
||||
&& !path.chars().any(|c| c.is_control())
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum PacketType {
|
||||
Connect = 0,
|
||||
Disconnect = 1,
|
||||
Event = 2,
|
||||
Ack = 3,
|
||||
ConnectError = 4,
|
||||
BinaryEvent = 5,
|
||||
BinaryAck = 6,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for PacketType {
|
||||
type Error = PacketError;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0 => Ok(Self::Connect),
|
||||
1 => Ok(Self::Disconnect),
|
||||
2 => Ok(Self::Event),
|
||||
3 => Ok(Self::Ack),
|
||||
4 => Ok(Self::ConnectError),
|
||||
5 => Ok(Self::BinaryEvent),
|
||||
6 => Ok(Self::BinaryAck),
|
||||
_ => Err(PacketError::InvalidType(value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<char> for PacketType {
|
||||
type Error = PacketError;
|
||||
|
||||
fn try_from(value: char) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
'0' => Ok(Self::Connect),
|
||||
'1' => Ok(Self::Disconnect),
|
||||
'2' => Ok(Self::Event),
|
||||
'3' => Ok(Self::Ack),
|
||||
'4' => Ok(Self::ConnectError),
|
||||
'5' => Ok(Self::BinaryEvent),
|
||||
'6' => Ok(Self::BinaryAck),
|
||||
_ => Err(PacketError::InvalidTypeChar(value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Packet {
|
||||
pub packet_type: PacketType,
|
||||
pub namespace: String,
|
||||
pub data: Option<Value>,
|
||||
pub id: Option<u64>,
|
||||
pub attachments: Vec<Vec<u8>>,
|
||||
/// Expected number of binary attachments (set during decode for binary packets).
|
||||
/// Used to validate attachment count before assembling the full packet.
|
||||
pub expected_attachments: Option<usize>,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn connect(namespace: impl Into<String>, data: Option<Value>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Connect,
|
||||
namespace: namespace.into(),
|
||||
data,
|
||||
id: None,
|
||||
attachments: Vec::new(),
|
||||
expected_attachments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn disconnect(namespace: impl Into<String>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Disconnect,
|
||||
namespace: namespace.into(),
|
||||
data: None,
|
||||
id: None,
|
||||
attachments: Vec::new(),
|
||||
expected_attachments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn event(namespace: impl Into<String>, data: Value, id: Option<u64>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Event,
|
||||
namespace: namespace.into(),
|
||||
data: Some(data),
|
||||
id,
|
||||
attachments: Vec::new(),
|
||||
expected_attachments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ack(namespace: impl Into<String>, data: Value, id: u64) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::Ack,
|
||||
namespace: namespace.into(),
|
||||
data: Some(data),
|
||||
id: Some(id),
|
||||
attachments: Vec::new(),
|
||||
expected_attachments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connect_error(namespace: impl Into<String>, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::ConnectError,
|
||||
namespace: namespace.into(),
|
||||
data: Some(serde_json::json!({ "message": message.into() })),
|
||||
id: None,
|
||||
attachments: Vec::new(),
|
||||
expected_attachments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn binary_event(
|
||||
namespace: impl Into<String>,
|
||||
data: Value,
|
||||
id: Option<u64>,
|
||||
attachments: Vec<Vec<u8>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::BinaryEvent,
|
||||
namespace: namespace.into(),
|
||||
data: Some(data),
|
||||
id,
|
||||
attachments,
|
||||
expected_attachments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn binary_ack(
|
||||
namespace: impl Into<String>,
|
||||
data: Value,
|
||||
id: u64,
|
||||
attachments: Vec<Vec<u8>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
packet_type: PacketType::BinaryAck,
|
||||
namespace: namespace.into(),
|
||||
data: Some(data),
|
||||
id: Some(id),
|
||||
attachments,
|
||||
expected_attachments: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_binary(&self) -> bool {
|
||||
!self.attachments.is_empty()
|
||||
}
|
||||
|
||||
pub fn attachment_count(&self) -> usize {
|
||||
self.attachments.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PacketError {
|
||||
#[error("invalid packet type: {0}")]
|
||||
InvalidType(u8),
|
||||
#[error("invalid packet type char: {0}")]
|
||||
InvalidTypeChar(char),
|
||||
#[error("empty packet")]
|
||||
Empty,
|
||||
#[error("invalid format: {0}")]
|
||||
InvalidFormat(String),
|
||||
#[error("json error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("missing namespace")]
|
||||
MissingNamespace,
|
||||
#[error("invalid attachment count")]
|
||||
InvalidAttachmentCount,
|
||||
}
|
||||
@@ -0,0 +1,392 @@
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::socket::packet::{Packet, PacketError, PacketType};
|
||||
|
||||
pub fn encode(packet: &Packet) -> String {
|
||||
let type_char = packet.packet_type as u8 + b'0';
|
||||
let mut result = String::new();
|
||||
|
||||
result.push(type_char as char);
|
||||
|
||||
if packet.has_binary() {
|
||||
result.push_str(&packet.attachment_count().to_string());
|
||||
result.push('-');
|
||||
}
|
||||
|
||||
if packet.namespace != "/" {
|
||||
result.push_str(&packet.namespace);
|
||||
result.push(',');
|
||||
}
|
||||
|
||||
if let Some(id) = packet.id {
|
||||
result.push_str(&id.to_string());
|
||||
}
|
||||
|
||||
if let Some(ref data) = packet.data {
|
||||
if packet.has_binary() {
|
||||
let data_with_placeholders = replace_binary_with_placeholders(data, packet.attachment_count());
|
||||
let encoded_data = serde_json::to_string(&data_with_placeholders)
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to serialize socket packet data: {}", e);
|
||||
"null".to_string()
|
||||
});
|
||||
result.push_str(&encoded_data);
|
||||
} else {
|
||||
let encoded_data = serde_json::to_string(data)
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to serialize socket packet data: {}", e);
|
||||
"null".to_string()
|
||||
});
|
||||
result.push_str(&encoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn encode_with_attachments(packet: &Packet) -> Vec<Vec<u8>> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
let encoded = encode(packet);
|
||||
result.push(encoded.into_bytes());
|
||||
|
||||
for attachment in &packet.attachments {
|
||||
result.push(attachment.clone());
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn decode(input: &str) -> Result<Packet, PacketError> {
|
||||
if input.is_empty() {
|
||||
return Err(PacketError::Empty);
|
||||
}
|
||||
|
||||
let mut chars = input.chars().peekable();
|
||||
|
||||
let type_char = chars.next().ok_or(PacketError::Empty)?;
|
||||
let packet_type = PacketType::try_from(type_char)?;
|
||||
|
||||
let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck) {
|
||||
let mut count_str = String::new();
|
||||
while let Some(&c) = chars.peek() {
|
||||
if c == '-' {
|
||||
chars.next();
|
||||
break;
|
||||
}
|
||||
if c.is_ascii_digit() {
|
||||
count_str.push(c);
|
||||
chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
count_str.parse::<usize>().unwrap_or(0)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let remaining: String = chars.collect();
|
||||
|
||||
let (namespace, rest) = if let Some(after_slash) = remaining.strip_prefix('/') {
|
||||
// Check if this is a custom namespace (has a comma separating namespace from data/id)
|
||||
// or if '/' is just the root namespace prefix followed immediately by data
|
||||
if let Some(comma_pos) = after_slash.find(',') {
|
||||
let ns = format!("/{}", &after_slash[..comma_pos]);
|
||||
let rest = after_slash[comma_pos + 1..].to_string();
|
||||
(ns, rest)
|
||||
} else if after_slash.starts_with('[')
|
||||
|| after_slash.starts_with(|c: char| c.is_ascii_digit())
|
||||
|| after_slash.is_empty()
|
||||
{
|
||||
// '/[' means '/' is the root namespace and '[' starts the data
|
||||
// '/<digits>' means root namespace followed by ack id
|
||||
// '/' alone means disconnect on root namespace
|
||||
("/".to_string(), after_slash.to_string())
|
||||
} else {
|
||||
// Non-root namespace without data (e.g., disconnect on custom namespace)
|
||||
(remaining, String::new())
|
||||
}
|
||||
} else {
|
||||
("/".to_string(), remaining)
|
||||
};
|
||||
|
||||
let (id, data_str) = parse_id_and_data(&rest);
|
||||
|
||||
let data = if data_str.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(serde_json::from_str(&data_str)?)
|
||||
};
|
||||
|
||||
Ok(Packet {
|
||||
packet_type,
|
||||
namespace,
|
||||
data,
|
||||
id,
|
||||
attachments: Vec::new(),
|
||||
// Store attachment_count for binary packets; actual attachments come via decode_with_attachments
|
||||
expected_attachments: if attachment_count > 0 { Some(attachment_count) } else { None },
|
||||
})
|
||||
}
|
||||
|
||||
pub fn decode_with_attachments(
|
||||
main_packet: &str,
|
||||
attachments: Vec<Vec<u8>>,
|
||||
) -> Result<Packet, PacketError> {
|
||||
let mut packet = decode(main_packet)?;
|
||||
|
||||
let expected = packet.expected_attachments.unwrap_or(0);
|
||||
if expected != attachments.len() {
|
||||
return Err(PacketError::InvalidAttachmentCount);
|
||||
}
|
||||
|
||||
packet.attachments = attachments;
|
||||
packet.expected_attachments = None;
|
||||
|
||||
if packet.has_binary() {
|
||||
if let Some(ref data) = packet.data {
|
||||
packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
fn parse_id_and_data(input: &str) -> (Option<u64>, String) {
|
||||
let mut id_str = String::new();
|
||||
let mut chars = input.chars().peekable();
|
||||
|
||||
while let Some(&c) = chars.peek() {
|
||||
if c.is_ascii_digit() {
|
||||
id_str.push(c);
|
||||
chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let id = if id_str.is_empty() {
|
||||
None
|
||||
} else {
|
||||
id_str.parse::<u64>().ok()
|
||||
};
|
||||
|
||||
let data: String = chars.collect();
|
||||
|
||||
(id, data)
|
||||
}
|
||||
|
||||
/// Replace binary values in the data with { "_placeholder": true, "num": N } placeholders.
|
||||
/// This is used when encoding binary events/acks for transmission over text-based transports.
|
||||
fn replace_binary_with_placeholders(value: &Value, total_attachments: usize) -> Value {
|
||||
match value {
|
||||
Value::Array(arr) => {
|
||||
let mut placeholder_idx = total_attachments; // Start from known count
|
||||
let new_arr: Vec<Value> = arr
|
||||
.iter()
|
||||
.map(|v| replace_binary_with_placeholders_inner(v, &mut placeholder_idx))
|
||||
.collect();
|
||||
Value::Array(new_arr)
|
||||
}
|
||||
Value::Object(map) => {
|
||||
let mut placeholder_idx = total_attachments;
|
||||
let mut new_map = serde_json::Map::new();
|
||||
for (k, v) in map {
|
||||
new_map.insert(
|
||||
k.clone(),
|
||||
replace_binary_with_placeholders_inner(v, &mut placeholder_idx),
|
||||
);
|
||||
}
|
||||
Value::Object(new_map)
|
||||
}
|
||||
_ => value.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut usize) -> Value {
|
||||
match value {
|
||||
Value::Array(arr) => {
|
||||
let new_arr: Vec<Value> = arr
|
||||
.iter()
|
||||
.map(|v| replace_binary_with_placeholders_inner(v, placeholder_idx))
|
||||
.collect();
|
||||
Value::Array(new_arr)
|
||||
}
|
||||
Value::Object(map) => {
|
||||
let mut new_map = serde_json::Map::new();
|
||||
for (k, v) in map {
|
||||
new_map.insert(
|
||||
k.clone(),
|
||||
replace_binary_with_placeholders_inner(v, placeholder_idx),
|
||||
);
|
||||
}
|
||||
Value::Object(new_map)
|
||||
}
|
||||
// Binary data would be represented as base64 strings in the initial data;
|
||||
// in the Socket.IO protocol, binary attachments are separate and referenced by placeholder.
|
||||
// This function handles the case where the data structure itself contains placeholder markers.
|
||||
_ => value.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_placeholders_with_binary(value: &Value, attachments: &[Vec<u8>]) -> Value {
|
||||
match value {
|
||||
Value::Object(map) => {
|
||||
// Check if this is a placeholder object: { "_placeholder": true, "num": N }
|
||||
if let (Some(Value::Bool(true)), Some(Value::Number(num))) =
|
||||
(map.get("_placeholder"), map.get("num"))
|
||||
{
|
||||
if let Some(idx) = num.as_u64() {
|
||||
if let Some(attachment) = attachments.get(idx as usize) {
|
||||
return Value::String(base64::Engine::encode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
attachment,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut new_map = serde_json::Map::new();
|
||||
for (k, v) in map {
|
||||
new_map.insert(k.clone(), replace_placeholders_with_binary(v, attachments));
|
||||
}
|
||||
Value::Object(new_map)
|
||||
}
|
||||
Value::Array(arr) => Value::Array(
|
||||
arr.iter()
|
||||
.map(|v| replace_placeholders_with_binary(v, attachments))
|
||||
.collect(),
|
||||
),
|
||||
_ => value.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_encode_connect() {
|
||||
let packet = Packet::connect("/", None);
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "0");
|
||||
|
||||
let packet = Packet::connect("/admin", Some(json!({"sid": "abc"})));
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "0/admin,{\"sid\":\"abc\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_event() {
|
||||
let packet = Packet::event("/", json!(["foo"]), None);
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "2[\"foo\"]");
|
||||
|
||||
let packet = Packet::event("/admin", json!(["bar"]), None);
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "2/admin,[\"bar\"]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_event_with_ack() {
|
||||
let packet = Packet::event("/", json!(["foo"]), Some(12));
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "212[\"foo\"]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_ack() {
|
||||
let packet = Packet::ack("/", json!([]), 12);
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "312[]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_disconnect() {
|
||||
let packet = Packet::disconnect("/");
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "1");
|
||||
|
||||
let packet = Packet::disconnect("/admin");
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "1/admin,");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_connect_error() {
|
||||
let packet = Packet::connect_error("/", "Not authorized");
|
||||
let encoded = encode(&packet);
|
||||
assert_eq!(encoded, "4{\"message\":\"Not authorized\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_connect() {
|
||||
let packet = decode("0").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Connect);
|
||||
assert_eq!(packet.namespace, "/");
|
||||
assert!(packet.data.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_connect_with_namespace() {
|
||||
let packet = decode("0/admin,{\"sid\":\"abc\"}").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Connect);
|
||||
assert_eq!(packet.namespace, "/admin");
|
||||
assert!(packet.data.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_event() {
|
||||
let packet = decode("2[\"foo\"]").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Event);
|
||||
assert_eq!(packet.namespace, "/");
|
||||
assert_eq!(packet.data, Some(json!(["foo"])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_event_with_namespace() {
|
||||
let packet = decode("2/admin,[\"bar\"]").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Event);
|
||||
assert_eq!(packet.namespace, "/admin");
|
||||
assert_eq!(packet.data, Some(json!(["bar"])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_event_with_ack() {
|
||||
let packet = decode("212[\"foo\"]").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Event);
|
||||
assert_eq!(packet.id, Some(12));
|
||||
assert_eq!(packet.data, Some(json!(["foo"])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_ack() {
|
||||
let packet = decode("312[]").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Ack);
|
||||
assert_eq!(packet.id, Some(12));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_disconnect() {
|
||||
let packet = decode("1").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Disconnect);
|
||||
assert_eq!(packet.namespace, "/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_disconnect_with_namespace() {
|
||||
let packet = decode("1/admin,").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::Disconnect);
|
||||
assert_eq!(packet.namespace, "/admin");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_binary_event_attachment_count() {
|
||||
let packet = decode("51-[\"baz\",{\"_placeholder\":true,\"num\":0}]").unwrap();
|
||||
assert_eq!(packet.packet_type, PacketType::BinaryEvent);
|
||||
assert_eq!(packet.expected_attachments, Some(1));
|
||||
assert_eq!(packet.namespace, "/");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::engine::packet::Packet as EnginePacket;
|
||||
use crate::engine::packet::PacketData as EnginePacketData;
|
||||
use crate::engine::server::{EngineConfig, EngineServer};
|
||||
use crate::engine::session::SessionStore;
|
||||
use crate::socket::adapter::{Adapter, LocalAdapter};
|
||||
use crate::socket::namespace::NamespaceManager;
|
||||
use crate::socket::packet::{Packet, PacketType};
|
||||
use crate::socket::parser;
|
||||
use crate::socket::socket::Socket;
|
||||
|
||||
pub struct SocketServer {
|
||||
pub engine: Arc<EngineServer>,
|
||||
pub namespaces: Arc<NamespaceManager>,
|
||||
pub adapter: Arc<dyn Adapter>,
|
||||
socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||
}
|
||||
|
||||
impl SocketServer {
|
||||
pub fn new(config: EngineConfig) -> Self {
|
||||
SocketServerBuilder::new(config).build()
|
||||
}
|
||||
|
||||
pub fn builder(config: EngineConfig) -> SocketServerBuilder {
|
||||
SocketServerBuilder::new(config)
|
||||
}
|
||||
|
||||
pub fn of(&self, path: impl Into<String>) -> Arc<crate::socket::namespace::Namespace> {
|
||||
self.namespaces.get_or_create_namespace(&path.into())
|
||||
}
|
||||
|
||||
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
|
||||
self.engine.clone().run_http(addr).await
|
||||
}
|
||||
|
||||
pub fn register_socket(&self, sid: String, tx: mpsc::Sender<Packet>) {
|
||||
self.socket_txs.insert(sid, tx);
|
||||
}
|
||||
|
||||
pub fn unregister_socket(&self, sid: &str) {
|
||||
self.socket_txs.remove(sid);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SocketServerBuilder {
|
||||
config: EngineConfig,
|
||||
adapter: Option<Arc<dyn Adapter>>,
|
||||
}
|
||||
|
||||
impl SocketServerBuilder {
|
||||
pub fn new(config: EngineConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
adapter: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn adapter(mut self, adapter: Arc<dyn Adapter>) -> Self {
|
||||
self.adapter = Some(adapter);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> SocketServer {
|
||||
let namespaces = Arc::new(NamespaceManager::new());
|
||||
let socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>> = Arc::new(DashMap::new());
|
||||
let engine_store = SessionStore::new();
|
||||
|
||||
let namespaces_clone = namespaces.clone();
|
||||
let socket_txs_clone = socket_txs.clone();
|
||||
let engine_store_clone = engine_store.clone();
|
||||
|
||||
let adapter: Arc<dyn Adapter> = self.adapter.unwrap_or_else(|| {
|
||||
let ns_clone = namespaces.clone();
|
||||
let send_fn = move |engine_sid: &str, packet: &Packet| {
|
||||
if let Some(ns) = ns_clone.get_namespace(&packet.namespace) {
|
||||
if let Some(socket) = ns.get_socket_by_engine_sid(engine_sid) {
|
||||
socket.send_packet(packet).map_err(|e| e.to_string())
|
||||
} else {
|
||||
Err(format!(
|
||||
"Socket with engine_sid {} not found in namespace {}",
|
||||
engine_sid, packet.namespace
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(format!("Namespace {} not found", packet.namespace))
|
||||
}
|
||||
};
|
||||
Arc::new(LocalAdapter::new(send_fn))
|
||||
});
|
||||
|
||||
let adapter_clone = adapter.clone();
|
||||
let engine = Arc::new(EngineServer::with_store(
|
||||
self.config,
|
||||
engine_store,
|
||||
move |sid, engine_packet| {
|
||||
let namespaces = namespaces_clone.clone();
|
||||
let socket_txs = socket_txs_clone.clone();
|
||||
let engine_store = engine_store_clone.clone();
|
||||
let adapter = adapter_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
handle_engine_message(
|
||||
sid, engine_packet, &namespaces, &socket_txs, &engine_store, &adapter,
|
||||
).await;
|
||||
});
|
||||
},
|
||||
));
|
||||
|
||||
let server = SocketServer {
|
||||
engine,
|
||||
namespaces,
|
||||
adapter,
|
||||
socket_txs,
|
||||
};
|
||||
|
||||
for ns in server.namespaces.all_namespaces() {
|
||||
let adapter_ref = server.adapter.clone();
|
||||
tokio::spawn(async move {
|
||||
ns.set_adapter(adapter_ref).await;
|
||||
});
|
||||
}
|
||||
|
||||
server
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_engine_message(
|
||||
engine_sid: String,
|
||||
engine_packet: EnginePacket,
|
||||
namespaces: &Arc<NamespaceManager>,
|
||||
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||
engine_store: &SessionStore,
|
||||
adapter: &Arc<dyn Adapter>,
|
||||
) {
|
||||
if let EnginePacketData::Text(ref text) = engine_packet.data {
|
||||
if let Ok(socket_packet) = parser::decode(text) {
|
||||
match socket_packet.packet_type {
|
||||
PacketType::Connect => {
|
||||
handle_connect(&engine_sid, &socket_packet, namespaces, socket_txs, engine_store, adapter).await;
|
||||
}
|
||||
PacketType::Disconnect => {
|
||||
handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs);
|
||||
}
|
||||
PacketType::Event => {
|
||||
handle_event(&engine_sid, &socket_packet, namespaces);
|
||||
}
|
||||
PacketType::Ack => {
|
||||
handle_ack(&engine_sid, &socket_packet);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
engine_sid: &str,
|
||||
packet: &Packet,
|
||||
namespaces: &Arc<NamespaceManager>,
|
||||
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||
engine_store: &SessionStore,
|
||||
adapter: &Arc<dyn Adapter>,
|
||||
) {
|
||||
// Validate namespace path to prevent DoS via arbitrary namespace creation
|
||||
if !crate::socket::namespace::is_valid_namespace(&packet.namespace) {
|
||||
tracing::warn!("Rejected connect with invalid namespace: {}", packet.namespace);
|
||||
return;
|
||||
}
|
||||
|
||||
let namespace = namespaces.get_or_create_namespace(&packet.namespace);
|
||||
|
||||
// Ensure newly created namespaces get the shared adapter
|
||||
{
|
||||
let ns_adapter = namespace.adapter.read().await;
|
||||
if ns_adapter.is_none() {
|
||||
drop(ns_adapter);
|
||||
let adapter_ref = adapter.clone();
|
||||
let ns_clone = namespace.clone();
|
||||
tokio::spawn(async move {
|
||||
ns_clone.set_adapter(adapter_ref).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let socket_sid = crate::engine::session::generate_sid();
|
||||
let (tx, mut rx) = mpsc::channel::<Packet>(256);
|
||||
socket_txs.insert(socket_sid.clone(), tx.clone());
|
||||
|
||||
let socket = Arc::new(Socket::new(
|
||||
socket_sid.clone(),
|
||||
packet.namespace.clone(),
|
||||
engine_sid.to_string(),
|
||||
tx,
|
||||
));
|
||||
|
||||
// Run connect handler and add to namespace.
|
||||
// If the handler rejects, clean up and do NOT send a Connect response.
|
||||
if let Err(msg) = namespace.add_socket(socket.clone()).await {
|
||||
tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg);
|
||||
socket_txs.remove(&socket_sid);
|
||||
return;
|
||||
}
|
||||
|
||||
// Connect handler passed — spawn forwarding task
|
||||
let engine_store_clone = engine_store.clone();
|
||||
let engine_sid_clone = engine_sid.to_string();
|
||||
let socket_sid_clone = socket_sid.clone();
|
||||
let socket_txs_clone = socket_txs.clone();
|
||||
let namespace_clone = namespace.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(socket_packet) = rx.recv().await {
|
||||
let encoded = parser::encode(&socket_packet);
|
||||
let engine_packet = EnginePacket::message_text(encoded);
|
||||
|
||||
if let Some(session) = engine_store_clone.get(&engine_sid_clone) {
|
||||
let mut s = session.write().await;
|
||||
if s.state == crate::engine::session::SessionState::Closed {
|
||||
break;
|
||||
}
|
||||
s.push_packet(engine_packet);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Forwarding task ended — ensure socket is cleaned up from namespace
|
||||
socket_txs_clone.remove(&socket_sid_clone);
|
||||
namespace_clone.remove_socket_by_sid(&socket_sid_clone).await;
|
||||
});
|
||||
|
||||
// Send Connect response (only after handler passed)
|
||||
let response = Packet::connect(
|
||||
&socket.namespace,
|
||||
Some(serde_json::json!({ "sid": &socket.sid })),
|
||||
);
|
||||
|
||||
if socket.send_packet(&response).is_err() {
|
||||
tracing::warn!("Failed to send connect response to socket {}", socket.sid);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_disconnect(
|
||||
engine_sid: &str,
|
||||
packet: &Packet,
|
||||
namespaces: &Arc<NamespaceManager>,
|
||||
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||
) {
|
||||
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
|
||||
// Look up socket by engine_sid, then remove by socket_sid
|
||||
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
|
||||
socket_txs.remove(&socket.sid);
|
||||
let socket_sid = socket.sid.clone();
|
||||
let ns_clone = namespace.clone();
|
||||
tokio::spawn(async move {
|
||||
ns_clone.remove_socket_by_sid(&socket_sid).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_event(
|
||||
engine_sid: &str,
|
||||
packet: &Packet,
|
||||
namespaces: &Arc<NamespaceManager>,
|
||||
) {
|
||||
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
|
||||
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
|
||||
if let Some(ref data) = packet.data {
|
||||
if let Some(arr) = data.as_array() {
|
||||
if let Some(event) = arr.first().and_then(|v| v.as_str()) {
|
||||
let event_data = if arr.len() > 1 {
|
||||
serde_json::Value::Array(arr[1..].to_vec())
|
||||
} else {
|
||||
serde_json::Value::Null
|
||||
};
|
||||
|
||||
let namespace_clone = namespace.clone();
|
||||
let event = event.to_string();
|
||||
let socket_clone = socket.clone();
|
||||
tokio::spawn(async move {
|
||||
namespace_clone
|
||||
.handle_event(&socket_clone, &event, &event_data)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_ack(engine_sid: &str, packet: &Packet) {
|
||||
tracing::debug!(
|
||||
"Received ACK from {} for namespace {} with id {:?}",
|
||||
engine_sid,
|
||||
packet.namespace,
|
||||
packet.id
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
|
||||
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
|
||||
|
||||
pub struct InMemorySessionStore {
|
||||
sessions: Arc<DashMap<String, SessionInfo>>,
|
||||
}
|
||||
|
||||
impl InMemorySessionStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemorySessionStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn now_millis() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionStoreTrait for InMemorySessionStore {
|
||||
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
|
||||
let info = SessionInfo {
|
||||
sid: sid.to_string(),
|
||||
transport: transport.to_string(),
|
||||
state: "connecting".to_string(),
|
||||
server_id: server_id.to_string(),
|
||||
created_at: now_millis(),
|
||||
last_ping: now_millis(),
|
||||
};
|
||||
self.sessions.insert(sid.to_string(), info);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
|
||||
Ok(self.sessions.get(sid).map(|r| r.value().clone()))
|
||||
}
|
||||
|
||||
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
|
||||
if let Some(mut entry) = self.sessions.get_mut(sid) {
|
||||
entry.value_mut().state = state.to_string();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(SessionError::NotFound(sid.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
|
||||
if let Some(mut entry) = self.sessions.get_mut(sid) {
|
||||
entry.value_mut().transport = transport.to_string();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(SessionError::NotFound(sid.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
|
||||
if let Some(mut entry) = self.sessions.get_mut(sid) {
|
||||
entry.value_mut().last_ping = now_millis();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(SessionError::NotFound(sid.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
|
||||
self.sessions.remove(sid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
|
||||
Ok(self.sessions.contains_key(sid))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
pub mod memory;
|
||||
pub mod redis;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SessionError {
|
||||
#[error("Redis error: {0}")]
|
||||
Redis(String),
|
||||
#[error("Session not found: {0}")]
|
||||
NotFound(String),
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
#[error("Session expired: {0}")]
|
||||
Expired(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SessionInfo {
|
||||
pub sid: String,
|
||||
pub transport: String,
|
||||
pub state: String,
|
||||
pub server_id: String,
|
||||
pub created_at: u64,
|
||||
pub last_ping: u64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SessionStoreTrait: Send + Sync + 'static {
|
||||
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError>;
|
||||
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError>;
|
||||
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError>;
|
||||
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError>;
|
||||
async fn update_ping(&self, sid: &str) -> Result<(), SessionError>;
|
||||
async fn remove(&self, sid: &str) -> Result<(), SessionError>;
|
||||
async fn exists(&self, sid: &str) -> Result<bool, SessionError>;
|
||||
}
|
||||
|
||||
pub use memory::InMemorySessionStore;
|
||||
pub use redis::RedisSessionStore;
|
||||
@@ -0,0 +1,164 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use fred::prelude::*;
|
||||
|
||||
use crate::socket::message_bus::redis::RedisMessageBus;
|
||||
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
|
||||
|
||||
fn now_millis() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
const DEFAULT_TTL_SECS: u64 = 60;
|
||||
const KEY_PREFIX: &str = "socket.io:session";
|
||||
|
||||
pub struct RedisSessionStore {
|
||||
client: Client,
|
||||
ttl_secs: u64,
|
||||
}
|
||||
|
||||
impl RedisSessionStore {
|
||||
pub fn new(bus: &RedisMessageBus, ttl_secs: Option<u64>) -> Self {
|
||||
Self {
|
||||
client: bus.client().clone(),
|
||||
ttl_secs: ttl_secs.unwrap_or(DEFAULT_TTL_SECS),
|
||||
}
|
||||
}
|
||||
|
||||
fn key(&self, sid: &str) -> String {
|
||||
format!("{}:{}", KEY_PREFIX, sid)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionStoreTrait for RedisSessionStore {
|
||||
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
|
||||
let key = self.key(sid);
|
||||
let now = now_millis();
|
||||
|
||||
// Batch all fields in a single HSET call for efficiency
|
||||
let fields: Vec<(&str, String)> = vec![
|
||||
("sid", sid.to_string()),
|
||||
("transport", transport.to_string()),
|
||||
("state", "connecting".to_string()),
|
||||
("server_id", server_id.to_string()),
|
||||
("created_at", now.to_string()),
|
||||
("last_ping", now.to_string()),
|
||||
];
|
||||
self.client
|
||||
.hset::<(), _, _>(&key, fields)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
self.client
|
||||
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
|
||||
let key = self.key(sid);
|
||||
|
||||
// Use hgetall directly — if the key doesn't exist Redis returns an empty map.
|
||||
// This avoids the TOCTOU race between EXISTS and HGETALL.
|
||||
let values: std::collections::HashMap<String, String> = self.client
|
||||
.hgetall::<std::collections::HashMap<String, String>, _>(&key)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
if values.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let info = SessionInfo {
|
||||
sid: values.get("sid").cloned().unwrap_or_default(),
|
||||
transport: values.get("transport").cloned().unwrap_or_default(),
|
||||
state: values.get("state").cloned().unwrap_or_default(),
|
||||
server_id: values.get("server_id").cloned().unwrap_or_default(),
|
||||
created_at: values.get("created_at").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
|
||||
last_ping: values.get("last_ping").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
|
||||
};
|
||||
|
||||
Ok(Some(info))
|
||||
}
|
||||
|
||||
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
|
||||
let key = self.key(sid);
|
||||
|
||||
// Use HSET (not HSETNX) to overwrite existing fields
|
||||
self.client
|
||||
.hset::<(), _, _>(&key, ("state", state))
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
self.client
|
||||
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
|
||||
let key = self.key(sid);
|
||||
|
||||
// Use HSET (not HSETNX) to overwrite existing fields
|
||||
self.client
|
||||
.hset::<(), _, _>(&key, ("transport", transport))
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
self.client
|
||||
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
|
||||
let key = self.key(sid);
|
||||
let now = now_millis();
|
||||
|
||||
// Use HSET (not HSETNX) to overwrite existing fields
|
||||
self.client
|
||||
.hset::<(), _, _>(&key, ("last_ping", now.to_string()))
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
self.client
|
||||
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
|
||||
let key = self.key(sid);
|
||||
|
||||
self.client
|
||||
.del::<(), _>(&key)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
|
||||
let key = self.key(sid);
|
||||
|
||||
let exists: bool = self.client
|
||||
.exists::<bool, _>(&key)
|
||||
.await
|
||||
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||
|
||||
Ok(exists)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::socket::packet::Packet;
|
||||
|
||||
pub struct Socket {
|
||||
pub sid: String,
|
||||
pub namespace: String,
|
||||
pub engine_sid: String,
|
||||
ack_id: AtomicU64,
|
||||
tx: mpsc::Sender<Packet>,
|
||||
}
|
||||
|
||||
impl Socket {
|
||||
pub fn new(
|
||||
sid: String,
|
||||
namespace: String,
|
||||
engine_sid: String,
|
||||
tx: mpsc::Sender<Packet>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sid,
|
||||
namespace,
|
||||
engine_sid,
|
||||
ack_id: AtomicU64::new(0),
|
||||
tx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_ack_id(&self) -> u64 {
|
||||
self.ack_id.fetch_add(1, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn send_packet(&self, packet: &Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||
self.tx.try_send(packet.clone())
|
||||
}
|
||||
|
||||
pub fn emit(&self, event: impl Into<String>, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||
let packet = Packet::event(
|
||||
&self.namespace,
|
||||
serde_json::json!([event.into(), data]),
|
||||
None,
|
||||
);
|
||||
self.send_packet(&packet)
|
||||
}
|
||||
|
||||
pub fn emit_with_ack(
|
||||
&self,
|
||||
event: impl Into<String>,
|
||||
data: serde_json::Value,
|
||||
) -> Result<u64, mpsc::error::TrySendError<Packet>> {
|
||||
let ack_id = self.next_ack_id();
|
||||
let packet = Packet::event(
|
||||
&self.namespace,
|
||||
serde_json::json!([event.into(), data]),
|
||||
Some(ack_id),
|
||||
);
|
||||
self.send_packet(&packet)?;
|
||||
Ok(ack_id)
|
||||
}
|
||||
|
||||
pub fn disconnect(&self) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||
let packet = Packet::disconnect(&self.namespace);
|
||||
self.send_packet(&packet)
|
||||
}
|
||||
|
||||
pub fn send_ack(&self, id: u64, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||
let packet = Packet::ack(&self.namespace, data, id);
|
||||
self.send_packet(&packet)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user