feat(auth): add authentication protocol definitions and build configuration
- Add TokenClaims message for JWT payload structure with user id, issuer, timestamps, and scopes - Implement IssueTokenRequest/Response for creating access and refresh tokens with TTL support - Create RefreshTokenRequest/Response for token rotation functionality - Define RevokeTokenRequest/Response with support for single token or user-wide revocation - Add VerifyTokenRequest/Response for validating JWT tokens with detailed claims information - Implement signing key distribution system with GetSigningKeysRequest/Response - Create TokenService gRPC service with IssueToken, RefreshToken, RevokeToken, VerifyToken, and GetSigningKeys methods - Add build.rs configuration to compile proto files using tonic_prost_build - Include channel, channel_settings, member, and permission protocol definitions for IM services - Generate Rust code bindings through pb/core.rs and pb/im.rs modules
This commit is contained in:
Generated
+3610
File diff suppressed because it is too large
Load Diff
+45
@@ -0,0 +1,45 @@
|
|||||||
|
[package]
|
||||||
|
name = "imks"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "lib.rs"
|
||||||
|
name = "imks"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
path = "main.rs"
|
||||||
|
name = "imks"
|
||||||
|
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
tonic = "0.14.6"
|
||||||
|
prost = "0.14.3"
|
||||||
|
prost-types = "0.14"
|
||||||
|
tonic-build = "0.14.6"
|
||||||
|
tonic-health = "0.14.6"
|
||||||
|
tonic-prost = "0.14.6"
|
||||||
|
tokio = { version = "1.52.3", features = ["full"] }
|
||||||
|
actix-web = { version = "4.13.0", features = [] }
|
||||||
|
actix-ws = { version = "0.4.0", features = [] }
|
||||||
|
actix-rt = "2"
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = { version = "1" }
|
||||||
|
base64 = "0.22"
|
||||||
|
rand = "0.9"
|
||||||
|
wtransport = "0.7"
|
||||||
|
dashmap = "6"
|
||||||
|
thiserror = "2"
|
||||||
|
async-trait = "0.1"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
fred = { version = "10", features = ["subscriber-client"] }
|
||||||
|
async-nats = "0.38"
|
||||||
|
uuid = { version = "1", features = ["v4"] }
|
||||||
|
futures-util = "0.3"
|
||||||
|
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
tonic-prost-build = "0.14.6"
|
||||||
|
walkdir = "2.5.0"
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let proto_dir = Path::new("proto/core");
|
||||||
|
let protos: Vec<_> = walkdir::WalkDir::new(proto_dir)
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|e| e.ok())
|
||||||
|
.filter(|e| e.path().extension().is_some_and(|ext| ext == "proto"))
|
||||||
|
.map(|e| e.path().to_owned())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for proto in &protos {
|
||||||
|
println!("cargo:rerun-if-changed={}", proto.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
let includes = vec![proto_dir.to_path_buf()];
|
||||||
|
|
||||||
|
tonic_prost_build::configure()
|
||||||
|
.build_server(false)
|
||||||
|
.build_client(true)
|
||||||
|
.compile_protos(&protos, &includes)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
+239
@@ -0,0 +1,239 @@
|
|||||||
|
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
|
||||||
|
|
||||||
|
use crate::engine::packet::{Packet, PacketData, PacketError, PacketType};
|
||||||
|
|
||||||
|
const RECORD_SEPARATOR: char = '\x1e';
|
||||||
|
|
||||||
|
pub fn encode_packet(packet: &Packet) -> String {
|
||||||
|
let type_char = packet.packet_type as u8 + b'0';
|
||||||
|
let type_str = type_char as char;
|
||||||
|
|
||||||
|
match &packet.data {
|
||||||
|
PacketData::Text(s) => format!("{type_str}{s}"),
|
||||||
|
PacketData::Binary(b) => format!("{type_str}b{}", BASE64.encode(b)),
|
||||||
|
PacketData::Empty => type_str.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode_packet_binary_ws(packet: &Packet) -> Vec<u8> {
|
||||||
|
let type_byte = packet.packet_type as u8 + b'0';
|
||||||
|
|
||||||
|
match &packet.data {
|
||||||
|
PacketData::Text(s) => {
|
||||||
|
let mut buf = Vec::with_capacity(1 + s.len());
|
||||||
|
buf.push(type_byte);
|
||||||
|
buf.extend_from_slice(s.as_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
PacketData::Binary(b) => {
|
||||||
|
let mut buf = Vec::with_capacity(1 + b.len());
|
||||||
|
buf.push(type_byte);
|
||||||
|
buf.extend_from_slice(b);
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
PacketData::Empty => vec![type_byte],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode_packet(input: &str) -> Result<Packet, PacketError> {
|
||||||
|
let mut chars = input.chars();
|
||||||
|
let type_char = chars.next().ok_or(PacketError::Empty)?;
|
||||||
|
let packet_type = PacketType::try_from(type_char)?;
|
||||||
|
let rest: String = chars.collect();
|
||||||
|
|
||||||
|
if rest.is_empty() {
|
||||||
|
return Ok(Packet {
|
||||||
|
packet_type,
|
||||||
|
data: PacketData::Empty,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(b64_data) = rest.strip_prefix('b') {
|
||||||
|
let decoded = BASE64.decode(b64_data)?;
|
||||||
|
return Ok(Packet {
|
||||||
|
packet_type,
|
||||||
|
data: PacketData::Binary(decoded),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Packet {
|
||||||
|
packet_type,
|
||||||
|
data: PacketData::Text(rest),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode a WebSocket binary frame into a Packet.
|
||||||
|
/// Handles both text-encoded packets (UTF-8 payload after type byte)
|
||||||
|
/// and binary packets (raw binary payload after type byte).
|
||||||
|
pub fn decode_packet_ws(input: &[u8]) -> Result<Packet, PacketError> {
|
||||||
|
if input.is_empty() {
|
||||||
|
return Err(PacketError::Empty);
|
||||||
|
}
|
||||||
|
|
||||||
|
let type_byte = input[0];
|
||||||
|
let packet_type = PacketType::try_from(type_byte.wrapping_sub(b'0'))?;
|
||||||
|
let rest = &input[1..];
|
||||||
|
|
||||||
|
if rest.is_empty() {
|
||||||
|
return Ok(Packet {
|
||||||
|
packet_type,
|
||||||
|
data: PacketData::Empty,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try UTF-8 first; if it fails, treat as binary data
|
||||||
|
match String::from_utf8(rest.to_vec()) {
|
||||||
|
Ok(text) => Ok(Packet {
|
||||||
|
packet_type,
|
||||||
|
data: PacketData::Text(text),
|
||||||
|
}),
|
||||||
|
Err(_) => Ok(Packet {
|
||||||
|
packet_type,
|
||||||
|
data: PacketData::Binary(rest.to_vec()),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode_payload(packets: &[Packet]) -> String {
|
||||||
|
packets
|
||||||
|
.iter()
|
||||||
|
.map(encode_packet)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(&RECORD_SEPARATOR.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode_payload(input: &str) -> Result<Vec<Packet>, PacketError> {
|
||||||
|
input
|
||||||
|
.split(RECORD_SEPARATOR)
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.map(decode_packet)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode_webtransport_header(payload_len: usize, is_binary: bool) -> Vec<u8> {
|
||||||
|
let binary_bit: u8 = if is_binary { 0x80 } else { 0x00 };
|
||||||
|
|
||||||
|
if payload_len <= 125 {
|
||||||
|
vec![binary_bit | (payload_len as u8)]
|
||||||
|
} else if payload_len <= 65535 {
|
||||||
|
let mut header = vec![binary_bit | 126];
|
||||||
|
header.extend_from_slice(&(payload_len as u16).to_be_bytes());
|
||||||
|
header
|
||||||
|
} else {
|
||||||
|
let mut header = vec![binary_bit | 127];
|
||||||
|
header.extend_from_slice(&(payload_len as u64).to_be_bytes());
|
||||||
|
header
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode_webtransport_header(header: &[u8]) -> Option<(usize, bool)> {
|
||||||
|
if header.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let first = header[0];
|
||||||
|
let is_binary = (first & 0x80) != 0;
|
||||||
|
let len_indicator = first & 0x7f;
|
||||||
|
|
||||||
|
if len_indicator <= 125 {
|
||||||
|
Some((len_indicator as usize, is_binary))
|
||||||
|
} else if len_indicator == 126 {
|
||||||
|
if header.len() < 3 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let len = u16::from_be_bytes([header[1], header[2]]) as usize;
|
||||||
|
Some((len, is_binary))
|
||||||
|
} else {
|
||||||
|
if header.len() < 9 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let len = u64::from_be_bytes([
|
||||||
|
header[1], header[2], header[3], header[4], header[5], header[6], header[7], header[8],
|
||||||
|
]) as usize;
|
||||||
|
Some((len, is_binary))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_decode_text_packet() {
|
||||||
|
let packet = Packet::message_text("hello");
|
||||||
|
let encoded = encode_packet(&packet);
|
||||||
|
assert_eq!(encoded, "4hello");
|
||||||
|
|
||||||
|
let decoded = decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_decode_binary_packet() {
|
||||||
|
let packet = Packet::message_binary(vec![1, 2, 3, 4]);
|
||||||
|
let encoded = encode_packet(&packet);
|
||||||
|
assert_eq!(encoded, "4bAQIDBA==");
|
||||||
|
|
||||||
|
let decoded = decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded.data, PacketData::Binary(vec![1, 2, 3, 4]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_decode_payload() {
|
||||||
|
let packets = vec![
|
||||||
|
Packet::message_text("hello"),
|
||||||
|
Packet::ping(""),
|
||||||
|
Packet::message_text("world"),
|
||||||
|
];
|
||||||
|
let encoded = encode_payload(&packets);
|
||||||
|
assert_eq!(encoded, "4hello\x1e2\x1e4world");
|
||||||
|
|
||||||
|
let decoded = decode_payload(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.len(), 3);
|
||||||
|
assert_eq!(decoded[0].packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded[1].packet_type, PacketType::Ping);
|
||||||
|
assert_eq!(decoded[2].packet_type, PacketType::Message);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_webtransport_header() {
|
||||||
|
let header = encode_webtransport_header(6, false);
|
||||||
|
assert_eq!(header, vec![0x06]);
|
||||||
|
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
|
||||||
|
assert_eq!(len, 6);
|
||||||
|
assert!(!is_binary);
|
||||||
|
|
||||||
|
let header = encode_webtransport_header(200, true);
|
||||||
|
assert_eq!(header.len(), 3);
|
||||||
|
let (len, is_binary) = decode_webtransport_header(&header).unwrap();
|
||||||
|
assert_eq!(len, 200);
|
||||||
|
assert!(is_binary);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_packet_ws_text() {
|
||||||
|
let input = b"4hello";
|
||||||
|
let decoded = decode_packet_ws(input).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded.data, PacketData::Text("hello".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_packet_ws_binary() {
|
||||||
|
// Type byte 4 (Message) + raw binary payload (non-UTF-8)
|
||||||
|
let input: Vec<u8> = vec![b'4', 0x80, 0xFF, 0x00, 0x01];
|
||||||
|
let decoded = decode_packet_ws(&input).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded.data, PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_packet_ws_empty() {
|
||||||
|
let input = b"4";
|
||||||
|
let decoded = decode_packet_ws(input).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded.data, PacketData::Empty);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::engine::packet::Packet;
|
||||||
|
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||||
|
|
||||||
|
pub struct HeartbeatManager {
|
||||||
|
store: SessionStore,
|
||||||
|
ping_interval: u64,
|
||||||
|
ping_timeout: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HeartbeatManager {
|
||||||
|
pub fn new(store: SessionStore, ping_interval: u64, ping_timeout: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
ping_interval,
|
||||||
|
ping_timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
|
||||||
|
let this = self.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
this.run().await;
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(&self) {
|
||||||
|
let mut interval = tokio::time::interval(Duration::from_millis(self.ping_interval));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
self.check_sessions().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_sessions(&self) {
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
let timeout_duration = Duration::from_millis(self.ping_interval + self.ping_timeout);
|
||||||
|
|
||||||
|
let mut to_remove = Vec::new();
|
||||||
|
|
||||||
|
for entry in self.store.sessions.iter() {
|
||||||
|
let sid = entry.key().clone();
|
||||||
|
let session = entry.value().clone();
|
||||||
|
|
||||||
|
let (state, last_ping, transport) = {
|
||||||
|
let s = session.read().await;
|
||||||
|
(s.state, s.last_ping, s.transport)
|
||||||
|
};
|
||||||
|
|
||||||
|
if state == SessionState::Closed {
|
||||||
|
to_remove.push(sid);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if now.duration_since(last_ping) > timeout_duration {
|
||||||
|
tracing::warn!("Session {} timed out", sid);
|
||||||
|
to_remove.push(sid);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For polling sessions: buffer a ping packet for the next GET request.
|
||||||
|
// WS/WT sessions rely on their own dedicated ping tasks; the timeout
|
||||||
|
// check above already serves as the safety net for all transports.
|
||||||
|
if state == SessionState::Open && transport == TransportType::Polling {
|
||||||
|
let mut s = session.write().await;
|
||||||
|
s.buffer_packet(Packet::ping(""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for sid in to_remove {
|
||||||
|
self.store.remove(&sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
pub mod codec;
|
||||||
|
pub mod heartbeat;
|
||||||
|
pub mod packet;
|
||||||
|
pub mod polling;
|
||||||
|
pub mod server;
|
||||||
|
pub mod session;
|
||||||
|
pub mod upgrade;
|
||||||
|
pub mod websocket;
|
||||||
|
pub mod webtransport;
|
||||||
|
|
||||||
|
pub use packet::{HandshakeData, Packet, PacketData, PacketType};
|
||||||
|
pub use server::{EngineConfig, EngineServer};
|
||||||
|
pub use session::{SessionState, SessionStore, TransportType};
|
||||||
@@ -0,0 +1,151 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum PacketType {
|
||||||
|
Open = 0,
|
||||||
|
Close = 1,
|
||||||
|
Ping = 2,
|
||||||
|
Pong = 3,
|
||||||
|
Message = 4,
|
||||||
|
Upgrade = 5,
|
||||||
|
Noop = 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<u8> for PacketType {
|
||||||
|
type Error = PacketError;
|
||||||
|
|
||||||
|
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
0 => Ok(Self::Open),
|
||||||
|
1 => Ok(Self::Close),
|
||||||
|
2 => Ok(Self::Ping),
|
||||||
|
3 => Ok(Self::Pong),
|
||||||
|
4 => Ok(Self::Message),
|
||||||
|
5 => Ok(Self::Upgrade),
|
||||||
|
6 => Ok(Self::Noop),
|
||||||
|
_ => Err(PacketError::InvalidType(value)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<char> for PacketType {
|
||||||
|
type Error = PacketError;
|
||||||
|
|
||||||
|
fn try_from(value: char) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
'0' => Ok(Self::Open),
|
||||||
|
'1' => Ok(Self::Close),
|
||||||
|
'2' => Ok(Self::Ping),
|
||||||
|
'3' => Ok(Self::Pong),
|
||||||
|
'4' => Ok(Self::Message),
|
||||||
|
'5' => Ok(Self::Upgrade),
|
||||||
|
'6' => Ok(Self::Noop),
|
||||||
|
_ => Err(PacketError::InvalidTypeChar(value)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PacketData {
|
||||||
|
Text(String),
|
||||||
|
Binary(Vec<u8>),
|
||||||
|
Empty,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Packet {
|
||||||
|
pub packet_type: PacketType,
|
||||||
|
pub data: PacketData,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Packet {
|
||||||
|
pub fn open(handshake: &HandshakeData) -> Self {
|
||||||
|
let data = serde_json::to_string(handshake)
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
tracing::error!("Failed to serialize handshake data: {}", e);
|
||||||
|
"{}".to_string()
|
||||||
|
});
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Open,
|
||||||
|
data: PacketData::Text(data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn close() -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Close,
|
||||||
|
data: PacketData::Empty,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ping(data: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Ping,
|
||||||
|
data: PacketData::Text(data.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pong(data: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Pong,
|
||||||
|
data: PacketData::Text(data.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn message_text(data: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Message,
|
||||||
|
data: PacketData::Text(data.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn message_binary(data: Vec<u8>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Message,
|
||||||
|
data: PacketData::Binary(data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn upgrade() -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Upgrade,
|
||||||
|
data: PacketData::Empty,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn noop() -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Noop,
|
||||||
|
data: PacketData::Empty,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct HandshakeData {
|
||||||
|
pub sid: String,
|
||||||
|
pub upgrades: Vec<String>,
|
||||||
|
#[serde(rename = "pingInterval")]
|
||||||
|
pub ping_interval: u64,
|
||||||
|
#[serde(rename = "pingTimeout")]
|
||||||
|
pub ping_timeout: u64,
|
||||||
|
#[serde(rename = "maxPayload")]
|
||||||
|
pub max_payload: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum PacketError {
|
||||||
|
#[error("invalid packet type: {0}")]
|
||||||
|
InvalidType(u8),
|
||||||
|
#[error("invalid packet type char: {0}")]
|
||||||
|
InvalidTypeChar(char),
|
||||||
|
#[error("empty packet")]
|
||||||
|
Empty,
|
||||||
|
#[error("invalid base64: {0}")]
|
||||||
|
InvalidBase64(#[from] base64::DecodeError),
|
||||||
|
#[error("invalid utf8: {0}")]
|
||||||
|
InvalidUtf8(#[from] std::string::FromUtf8Error),
|
||||||
|
#[error("serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
}
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use actix_web::{web, HttpRequest, HttpResponse};
|
||||||
|
|
||||||
|
use crate::engine::codec;
|
||||||
|
use crate::engine::packet::{Packet, PacketType};
|
||||||
|
use crate::engine::server::EngineConfig;
|
||||||
|
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
pub struct PollingQuery {
|
||||||
|
#[serde(rename = "EIO")]
|
||||||
|
pub eio: Option<String>,
|
||||||
|
pub transport: Option<String>,
|
||||||
|
pub sid: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn polling_get(
|
||||||
|
_req: HttpRequest,
|
||||||
|
query: web::Query<PollingQuery>,
|
||||||
|
store: web::Data<SessionStore>,
|
||||||
|
config: web::Data<EngineConfig>,
|
||||||
|
_on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
|
||||||
|
) -> HttpResponse {
|
||||||
|
if query.eio.as_deref() != Some("4") {
|
||||||
|
return HttpResponse::BadRequest().body("invalid EIO version");
|
||||||
|
}
|
||||||
|
|
||||||
|
if query.transport.as_deref() != Some("polling") {
|
||||||
|
return HttpResponse::BadRequest().body("invalid transport");
|
||||||
|
}
|
||||||
|
|
||||||
|
let sid = match &query.sid {
|
||||||
|
Some(sid) => sid.clone(),
|
||||||
|
None => {
|
||||||
|
return handle_handshake(&store, &config).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = match store.get(&sid) {
|
||||||
|
Some(s) => s,
|
||||||
|
None => return HttpResponse::BadRequest().body("unknown session"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check session state and take any buffered pending packets
|
||||||
|
let notify = {
|
||||||
|
let mut session_guard = session.write().await;
|
||||||
|
|
||||||
|
if session_guard.state == SessionState::Closed {
|
||||||
|
return HttpResponse::BadRequest().body("session closed");
|
||||||
|
}
|
||||||
|
|
||||||
|
let pending = session_guard.take_pending();
|
||||||
|
if !pending.is_empty() {
|
||||||
|
let payload = codec::encode_payload(&pending);
|
||||||
|
return HttpResponse::Ok()
|
||||||
|
.content_type("text/plain; charset=UTF-8")
|
||||||
|
.body(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
session_guard.notify.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let timeout = Duration::from_millis(config.ping_interval + config.ping_timeout);
|
||||||
|
let _result = tokio::time::timeout(timeout, notify.notified()).await;
|
||||||
|
|
||||||
|
// Re-verify session still exists after wait (may have been removed by heartbeat)
|
||||||
|
if store.get(&sid).is_none() {
|
||||||
|
return HttpResponse::BadRequest().body("session closed");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut session_guard = session.write().await;
|
||||||
|
let packets = session_guard.take_pending();
|
||||||
|
|
||||||
|
if packets.is_empty() {
|
||||||
|
let noop = codec::encode_packet(&Packet::noop());
|
||||||
|
HttpResponse::Ok()
|
||||||
|
.content_type("text/plain; charset=UTF-8")
|
||||||
|
.body(noop)
|
||||||
|
} else {
|
||||||
|
let payload = codec::encode_payload(&packets);
|
||||||
|
HttpResponse::Ok()
|
||||||
|
.content_type("text/plain; charset=UTF-8")
|
||||||
|
.body(payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn polling_post(
|
||||||
|
_req: HttpRequest,
|
||||||
|
body: web::Bytes,
|
||||||
|
query: web::Query<PollingQuery>,
|
||||||
|
store: web::Data<SessionStore>,
|
||||||
|
config: web::Data<EngineConfig>,
|
||||||
|
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
|
||||||
|
) -> HttpResponse {
|
||||||
|
if query.eio.as_deref() != Some("4") {
|
||||||
|
return HttpResponse::BadRequest().body("invalid EIO version");
|
||||||
|
}
|
||||||
|
|
||||||
|
if query.transport.as_deref() != Some("polling") {
|
||||||
|
return HttpResponse::BadRequest().body("invalid transport");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check payload size BEFORE attempting to decode
|
||||||
|
if body.len() > config.max_payload {
|
||||||
|
return HttpResponse::PayloadTooLarge().body("payload too large");
|
||||||
|
}
|
||||||
|
|
||||||
|
let sid = match &query.sid {
|
||||||
|
Some(sid) => sid,
|
||||||
|
None => return HttpResponse::BadRequest().body("missing sid"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = match store.get(sid) {
|
||||||
|
Some(s) => s,
|
||||||
|
None => return HttpResponse::BadRequest().body("unknown session"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let body_str = match std::str::from_utf8(&body) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => return HttpResponse::BadRequest().body("invalid utf8"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let packets = match codec::decode_payload(body_str) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => return HttpResponse::BadRequest().body("invalid payload"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut session_guard = session.write().await;
|
||||||
|
|
||||||
|
for packet in packets {
|
||||||
|
match packet.packet_type {
|
||||||
|
PacketType::Pong => {
|
||||||
|
session_guard.update_ping();
|
||||||
|
}
|
||||||
|
PacketType::Message => {
|
||||||
|
let on_msg = on_message.get_ref().clone();
|
||||||
|
let sid_owned = sid.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
on_msg(sid_owned, packet);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
PacketType::Close => {
|
||||||
|
session_guard.set_state(SessionState::Closed);
|
||||||
|
store.remove(sid);
|
||||||
|
return HttpResponse::Ok().body("ok");
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HttpResponse::Ok().body("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_handshake(store: &SessionStore, config: &EngineConfig) -> HttpResponse {
|
||||||
|
let sid = crate::engine::session::generate_sid();
|
||||||
|
|
||||||
|
let handshake = crate::engine::packet::HandshakeData {
|
||||||
|
sid: sid.clone(),
|
||||||
|
upgrades: vec!["websocket".to_string()],
|
||||||
|
ping_interval: config.ping_interval,
|
||||||
|
ping_timeout: config.ping_timeout,
|
||||||
|
max_payload: config.max_payload,
|
||||||
|
};
|
||||||
|
|
||||||
|
let _rx = store.create(sid.clone(), TransportType::Polling);
|
||||||
|
|
||||||
|
if let Some(session) = store.get(&sid) {
|
||||||
|
let mut session = session.write().await;
|
||||||
|
session.set_state(SessionState::Open);
|
||||||
|
}
|
||||||
|
|
||||||
|
let packet = Packet::open(&handshake);
|
||||||
|
let payload = codec::encode_packet(&packet);
|
||||||
|
|
||||||
|
HttpResponse::Ok()
|
||||||
|
.content_type("text/plain; charset=UTF-8")
|
||||||
|
.body(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn configure_polling(cfg: &mut web::ServiceConfig) {
|
||||||
|
cfg.route("/engine.io/", web::get().to(polling_get))
|
||||||
|
.route("/engine.io/", web::post().to(polling_post));
|
||||||
|
}
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use actix_web::{web, App, HttpServer};
|
||||||
|
|
||||||
|
use crate::engine::heartbeat::HeartbeatManager;
|
||||||
|
use crate::engine::packet::Packet;
|
||||||
|
use crate::engine::session::SessionStore;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct EngineConfig {
|
||||||
|
pub ping_interval: u64,
|
||||||
|
pub ping_timeout: u64,
|
||||||
|
pub max_payload: usize,
|
||||||
|
pub path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for EngineConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
ping_interval: 25000,
|
||||||
|
ping_timeout: 20000,
|
||||||
|
max_payload: 1_000_000,
|
||||||
|
path: "/engine.io/".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EngineServer {
|
||||||
|
pub config: EngineConfig,
|
||||||
|
pub store: SessionStore,
|
||||||
|
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EngineServer {
|
||||||
|
pub fn new(
|
||||||
|
config: EngineConfig,
|
||||||
|
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
store: SessionStore::new(),
|
||||||
|
on_message: Arc::new(on_message),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_store(
|
||||||
|
config: EngineConfig,
|
||||||
|
store: SessionStore,
|
||||||
|
on_message: impl Fn(String, Packet) + Send + Sync + 'static,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
store,
|
||||||
|
on_message: Arc::new(on_message),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
|
||||||
|
let store = self.store.clone();
|
||||||
|
let config = self.config.clone();
|
||||||
|
let on_message = self.on_message.clone();
|
||||||
|
|
||||||
|
// Start heartbeat manager to clean up stale sessions
|
||||||
|
let heartbeat = Arc::new(HeartbeatManager::new(
|
||||||
|
store.clone(),
|
||||||
|
config.ping_interval,
|
||||||
|
config.ping_timeout,
|
||||||
|
));
|
||||||
|
let heartbeat_handle = heartbeat.start();
|
||||||
|
|
||||||
|
tracing::info!("Engine.IO HTTP server listening on {}", addr);
|
||||||
|
|
||||||
|
let result = HttpServer::new(move || {
|
||||||
|
App::new()
|
||||||
|
.app_data(web::Data::new(store.clone()))
|
||||||
|
.app_data(web::Data::new(config.clone()))
|
||||||
|
.app_data(web::Data::new(on_message.clone()))
|
||||||
|
.route(
|
||||||
|
"/engine.io/",
|
||||||
|
web::get().to(crate::engine::polling::polling_get),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/engine.io/",
|
||||||
|
web::post().to(crate::engine::polling::polling_post),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/engine.io/",
|
||||||
|
web::get().to(crate::engine::websocket::websocket_handler),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.bind(addr)?
|
||||||
|
.run()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
heartbeat_handle.abort();
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_webtransport(
|
||||||
|
&self,
|
||||||
|
port: u16,
|
||||||
|
cert_path: &str,
|
||||||
|
key_path: &str,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
crate::engine::webtransport::run_webtransport_server(
|
||||||
|
port,
|
||||||
|
cert_path,
|
||||||
|
key_path,
|
||||||
|
self.store.clone(),
|
||||||
|
self.config.clone(),
|
||||||
|
self.on_message.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tokio::sync::{mpsc, Notify};
|
||||||
|
|
||||||
|
use crate::engine::packet::Packet;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum TransportType {
|
||||||
|
Polling,
|
||||||
|
WebSocket,
|
||||||
|
WebTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TransportType {
|
||||||
|
pub fn as_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Polling => "polling",
|
||||||
|
Self::WebSocket => "websocket",
|
||||||
|
Self::WebTransport => "webtransport",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum SessionState {
|
||||||
|
Connecting,
|
||||||
|
Open,
|
||||||
|
Upgrading,
|
||||||
|
Closing,
|
||||||
|
Closed,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Session {
|
||||||
|
pub sid: String,
|
||||||
|
pub transport: TransportType,
|
||||||
|
pub state: SessionState,
|
||||||
|
pub created_at: Instant,
|
||||||
|
pub last_ping: Instant,
|
||||||
|
pub tx: mpsc::Sender<Packet>,
|
||||||
|
pub pending_packets: Vec<Packet>,
|
||||||
|
pub notify: Arc<Notify>,
|
||||||
|
pub upgrade_tx: Option<mpsc::Sender<Packet>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session {
|
||||||
|
pub fn new(sid: String, transport: TransportType) -> (Self, mpsc::Receiver<Packet>) {
|
||||||
|
let (tx, rx) = mpsc::channel(256);
|
||||||
|
let session = Self {
|
||||||
|
sid,
|
||||||
|
transport,
|
||||||
|
state: SessionState::Connecting,
|
||||||
|
created_at: Instant::now(),
|
||||||
|
last_ping: Instant::now(),
|
||||||
|
tx,
|
||||||
|
pending_packets: Vec::new(),
|
||||||
|
notify: Arc::new(Notify::new()),
|
||||||
|
upgrade_tx: None,
|
||||||
|
};
|
||||||
|
(session, rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a packet through the mpsc channel (for WS/WT transport consumption).
|
||||||
|
pub fn send_packet(&self, packet: Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||||
|
self.tx.try_send(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Push a packet using the appropriate mechanism for the current transport.
|
||||||
|
/// Polling: buffer in pending_packets + notify waiting GET request.
|
||||||
|
/// WS/WT: try mpsc channel first; if full, buffer as fallback + notify.
|
||||||
|
pub fn push_packet(&mut self, packet: Packet) {
|
||||||
|
if self.transport == TransportType::Polling {
|
||||||
|
self.pending_packets.push(packet);
|
||||||
|
self.notify.notify_one();
|
||||||
|
} else {
|
||||||
|
if self.tx.try_send(packet.clone()).is_err() {
|
||||||
|
self.pending_packets.push(packet);
|
||||||
|
self.notify.notify_one();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffer a packet in pending_packets and notify any waiting polling request.
|
||||||
|
pub fn buffer_packet(&mut self, packet: Packet) {
|
||||||
|
self.pending_packets.push(packet);
|
||||||
|
self.notify.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn take_pending(&mut self) -> Vec<Packet> {
|
||||||
|
std::mem::take(&mut self.pending_packets)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_ping(&mut self) {
|
||||||
|
self.last_ping = Instant::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_transport(&mut self, transport: TransportType) {
|
||||||
|
self.transport = transport;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_state(&mut self, state: SessionState) {
|
||||||
|
self.state = state;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SessionStore {
|
||||||
|
pub sessions: Arc<DashMap<String, Arc<tokio::sync::RwLock<Session>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sessions: Arc::new(DashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new session. Returns the mpsc receiver for transport-level packet consumption.
|
||||||
|
/// Logs a warning if the SID collides with an existing session (extremely unlikely with crypto RNG).
|
||||||
|
pub fn create(&self, sid: String, transport: TransportType) -> mpsc::Receiver<Packet> {
|
||||||
|
let (session, rx) = Session::new(sid.clone(), transport);
|
||||||
|
let old = self
|
||||||
|
.sessions
|
||||||
|
.insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session)));
|
||||||
|
if old.is_some() {
|
||||||
|
tracing::warn!("Session ID collision for SID {}, replacing existing session", sid);
|
||||||
|
}
|
||||||
|
rx
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, sid: &str) -> Option<Arc<tokio::sync::RwLock<Session>>> {
|
||||||
|
self.sessions.get(sid).map(|r| r.value().clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove(&self, sid: &str) {
|
||||||
|
self.sessions.remove(sid);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn exists(&self, sid: &str) -> bool {
|
||||||
|
self.sessions.contains_key(sid)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.sessions.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.sessions.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SessionStore {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a random session ID using a cryptographically secure RNG.
|
||||||
|
/// rand 0.9's default RNG (ChaCha8Rng seeded from OsRng) is crypto-secure.
|
||||||
|
pub fn generate_sid() -> String {
|
||||||
|
use rand::Rng;
|
||||||
|
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-";
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
(0..20)
|
||||||
|
.map(|_| {
|
||||||
|
let idx = rng.random_range(0..CHARSET.len());
|
||||||
|
CHARSET[idx] as char
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
use crate::engine::packet::Packet;
|
||||||
|
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||||
|
|
||||||
|
pub async fn handle_upgrade_probe(
|
||||||
|
store: &SessionStore,
|
||||||
|
sid: &str,
|
||||||
|
) -> Result<Packet, UpgradeError> {
|
||||||
|
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
|
||||||
|
let mut session = session.write().await;
|
||||||
|
|
||||||
|
if session.state == SessionState::Closed {
|
||||||
|
return Err(UpgradeError::SessionClosed);
|
||||||
|
}
|
||||||
|
|
||||||
|
session.set_state(SessionState::Upgrading);
|
||||||
|
Ok(Packet::pong("probe"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_upgrade_complete(
|
||||||
|
store: &SessionStore,
|
||||||
|
sid: &str,
|
||||||
|
new_transport: TransportType,
|
||||||
|
) -> Result<(), UpgradeError> {
|
||||||
|
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
|
||||||
|
let mut session = session.write().await;
|
||||||
|
|
||||||
|
session.set_transport(new_transport);
|
||||||
|
session.set_state(SessionState::Open);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_noop_to_pending_polling(
|
||||||
|
store: &SessionStore,
|
||||||
|
sid: &str,
|
||||||
|
) -> Result<(), UpgradeError> {
|
||||||
|
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
|
||||||
|
let mut session = session.write().await;
|
||||||
|
|
||||||
|
session.buffer_packet(Packet::noop());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum UpgradeError {
|
||||||
|
#[error("session not found")]
|
||||||
|
SessionNotFound,
|
||||||
|
#[error("session closed")]
|
||||||
|
SessionClosed,
|
||||||
|
#[error("invalid state for upgrade")]
|
||||||
|
InvalidState,
|
||||||
|
}
|
||||||
@@ -0,0 +1,254 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use actix_web::{web, HttpRequest, HttpResponse};
|
||||||
|
use actix_ws::Message;
|
||||||
|
|
||||||
|
use crate::engine::codec;
|
||||||
|
use crate::engine::packet::{Packet, PacketData, PacketType};
|
||||||
|
use crate::engine::server::EngineConfig;
|
||||||
|
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
pub struct WsQuery {
|
||||||
|
#[serde(rename = "EIO")]
|
||||||
|
pub eio: Option<String>,
|
||||||
|
pub transport: Option<String>,
|
||||||
|
pub sid: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn websocket_handler(
|
||||||
|
req: HttpRequest,
|
||||||
|
body: web::Payload,
|
||||||
|
query: web::Query<WsQuery>,
|
||||||
|
store: web::Data<SessionStore>,
|
||||||
|
config: web::Data<EngineConfig>,
|
||||||
|
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
if query.eio.as_deref() != Some("4") {
|
||||||
|
return Ok(HttpResponse::BadRequest().body("invalid EIO version"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if query.transport.as_deref() != Some("websocket") {
|
||||||
|
return Ok(HttpResponse::BadRequest().body("invalid transport"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (response, mut ws_session, mut msg_stream) = actix_ws::handle(&req, body)?;
|
||||||
|
|
||||||
|
let sid = query.sid.clone();
|
||||||
|
|
||||||
|
let is_upgrade = sid.as_ref().map(|s| store.exists(s)).unwrap_or(false);
|
||||||
|
|
||||||
|
// Create or reuse session, obtaining the mpsc receiver for the forwarding task
|
||||||
|
let (session_sid, mut session_rx) = if let Some(ref sid) = sid {
|
||||||
|
if is_upgrade {
|
||||||
|
// Upgrade: session already exists, replace its channel and drain pending packets
|
||||||
|
let session_arc = store.get(sid).unwrap();
|
||||||
|
let (new_tx, new_rx) = tokio::sync::mpsc::channel(256);
|
||||||
|
{
|
||||||
|
let mut s = session_arc.write().await;
|
||||||
|
// Swap tx atomically: old_tx will be dropped, closing its channel.
|
||||||
|
// Any packets in the old rx are consumed by the old send_handle,
|
||||||
|
// which then exits when it sees the channel close.
|
||||||
|
// Drain pending_packets (from polling buffering) into new channel.
|
||||||
|
let pending = s.take_pending();
|
||||||
|
for packet in pending {
|
||||||
|
let _ = new_tx.try_send(packet);
|
||||||
|
}
|
||||||
|
s.tx = new_tx;
|
||||||
|
s.set_transport(TransportType::WebSocket);
|
||||||
|
}
|
||||||
|
(sid.clone(), new_rx)
|
||||||
|
} else {
|
||||||
|
// Reconnect with known SID: create new session
|
||||||
|
let rx = store.create(sid.clone(), TransportType::WebSocket);
|
||||||
|
if let Some(s) = store.get(sid) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.set_state(SessionState::Open);
|
||||||
|
}
|
||||||
|
(sid.clone(), rx)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// New connection: generate SID and create session
|
||||||
|
let new_sid = crate::engine::session::generate_sid();
|
||||||
|
let rx = store.create(new_sid.clone(), TransportType::WebSocket);
|
||||||
|
if let Some(s) = store.get(&new_sid) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.set_state(SessionState::Open);
|
||||||
|
}
|
||||||
|
(new_sid, rx)
|
||||||
|
};
|
||||||
|
|
||||||
|
let handshake = crate::engine::packet::HandshakeData {
|
||||||
|
sid: session_sid.clone(),
|
||||||
|
upgrades: vec![],
|
||||||
|
ping_interval: config.ping_interval,
|
||||||
|
ping_timeout: config.ping_timeout,
|
||||||
|
max_payload: config.max_payload,
|
||||||
|
};
|
||||||
|
|
||||||
|
let open_packet = Packet::open(&handshake);
|
||||||
|
let open_msg = codec::encode_packet(&open_packet);
|
||||||
|
if ws_session.text(open_msg).await.is_err() {
|
||||||
|
tracing::warn!("Failed to send open packet to WebSocket session {}", session_sid);
|
||||||
|
store.remove(&session_sid);
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
let store_clone = store.get_ref().clone();
|
||||||
|
let on_message_clone = on_message.get_ref().clone();
|
||||||
|
let sid_clone = session_sid.clone();
|
||||||
|
let ws_session_clone = ws_session.clone();
|
||||||
|
let max_payload = config.max_payload;
|
||||||
|
|
||||||
|
// Task 1: Forward engine session packets → WebSocket (reads from session mpsc rx)
|
||||||
|
let sid_for_send = session_sid.clone();
|
||||||
|
let store_for_send = store.get_ref().clone();
|
||||||
|
let mut ws_for_send = ws_session.clone();
|
||||||
|
let send_handle = actix_rt::spawn(async move {
|
||||||
|
while let Some(packet) = session_rx.recv().await {
|
||||||
|
let encoded = codec::encode_packet(&packet);
|
||||||
|
if ws_for_send.text(encoded).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Session channel closed — clean up
|
||||||
|
store_for_send.remove(&sid_for_send);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Task 2: Read incoming WebSocket messages → dispatch
|
||||||
|
let recv_handle = actix_rt::spawn(async move {
|
||||||
|
let mut ws_session = ws_session_clone;
|
||||||
|
while let Some(Ok(msg)) = msg_stream.recv().await {
|
||||||
|
match msg {
|
||||||
|
Message::Text(text) => {
|
||||||
|
if let Ok(packet) = codec::decode_packet(&text) {
|
||||||
|
match packet.packet_type {
|
||||||
|
PacketType::Ping => {
|
||||||
|
if let PacketData::Text(ref data) = packet.data {
|
||||||
|
if data == "probe" {
|
||||||
|
let pong = Packet::pong("probe");
|
||||||
|
let pong_msg = codec::encode_packet(&pong);
|
||||||
|
let _ = ws_session.text(pong_msg).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let pong = Packet::pong("");
|
||||||
|
let pong_msg = codec::encode_packet(&pong);
|
||||||
|
let _ = ws_session.text(pong_msg).await;
|
||||||
|
}
|
||||||
|
PacketType::Pong => {
|
||||||
|
if let Some(s) = store_clone.get(&sid_clone) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.update_ping();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PacketType::Upgrade => {
|
||||||
|
if let Some(s) = store_clone.get(&sid_clone) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.set_transport(TransportType::WebSocket);
|
||||||
|
s.set_state(SessionState::Open);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PacketType::Message => {
|
||||||
|
let on_msg = on_message_clone.clone();
|
||||||
|
let sid = sid_clone.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
on_msg(sid, packet);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
PacketType::Close => {
|
||||||
|
if let Some(s) = store_clone.get(&sid_clone) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.set_state(SessionState::Closed);
|
||||||
|
}
|
||||||
|
store_clone.remove(&sid_clone);
|
||||||
|
let _ = ws_session.close(None).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Binary(bin) => {
|
||||||
|
// Enforce max payload size for binary frames
|
||||||
|
if bin.len() > max_payload {
|
||||||
|
tracing::warn!(
|
||||||
|
"Binary payload too large ({}) for session {}",
|
||||||
|
bin.len(),
|
||||||
|
sid_clone
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(packet) = codec::decode_packet_ws(&bin) {
|
||||||
|
if packet.packet_type == PacketType::Message {
|
||||||
|
let on_msg = on_message_clone.clone();
|
||||||
|
let sid = sid_clone.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
on_msg(sid, packet);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Close(_) => {
|
||||||
|
if let Some(s) = store_clone.get(&sid_clone) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.set_state(SessionState::Closed);
|
||||||
|
}
|
||||||
|
store_clone.remove(&sid_clone);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Task 3: Heartbeat ping sender
|
||||||
|
let sid_for_ping = session_sid.clone();
|
||||||
|
let store_for_ping = store.get_ref().clone();
|
||||||
|
let mut ws_for_ping = ws_session.clone();
|
||||||
|
let ping_interval = config.ping_interval;
|
||||||
|
let ping_handle = tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
if let Some(s) = store_for_ping.get(&sid_for_ping) {
|
||||||
|
let session_state = {
|
||||||
|
let s = s.read().await;
|
||||||
|
s.state
|
||||||
|
};
|
||||||
|
|
||||||
|
if session_state == SessionState::Closed {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ping = Packet::ping("");
|
||||||
|
let ping_msg = codec::encode_packet(&ping);
|
||||||
|
if ws_for_ping.text(ping_msg).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait for any task to finish, then clean up
|
||||||
|
actix_rt::spawn(async move {
|
||||||
|
// actix_rt::spawn returns JoinHandle which is compatible with tokio::select!
|
||||||
|
tokio::select! {
|
||||||
|
_ = send_handle => {},
|
||||||
|
_ = recv_handle => {},
|
||||||
|
_ = ping_handle => {},
|
||||||
|
}
|
||||||
|
store.remove(&session_sid);
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn configure_websocket(cfg: &mut web::ServiceConfig) {
|
||||||
|
cfg.route("/engine.io/", web::get().to(websocket_handler));
|
||||||
|
}
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use wtransport::{Connection, Endpoint, ServerConfig, Identity};
|
||||||
|
|
||||||
|
use crate::engine::codec;
|
||||||
|
use crate::engine::packet::{Packet, PacketType};
|
||||||
|
use crate::engine::server::EngineConfig;
|
||||||
|
use crate::engine::session::{SessionState, SessionStore, TransportType};
|
||||||
|
|
||||||
|
pub async fn run_webtransport_server(
|
||||||
|
port: u16,
|
||||||
|
cert_path: &str,
|
||||||
|
key_path: &str,
|
||||||
|
store: SessionStore,
|
||||||
|
config: EngineConfig,
|
||||||
|
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let identity = Identity::load_pemfiles(cert_path, key_path).await?;
|
||||||
|
|
||||||
|
let server_config = ServerConfig::builder()
|
||||||
|
.with_bind_default(port)
|
||||||
|
.with_identity(identity)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let server = Endpoint::server(server_config)?;
|
||||||
|
|
||||||
|
tracing::info!("WebTransport server listening on UDP port {}", port);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let incoming = server.accept().await;
|
||||||
|
|
||||||
|
let store = store.clone();
|
||||||
|
let config = config.clone();
|
||||||
|
let on_message = on_message.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
match handle_webtransport_session(incoming, store, config, on_message).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("WebTransport session error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_webtransport_session(
|
||||||
|
incoming: wtransport::endpoint::IncomingSession,
|
||||||
|
store: SessionStore,
|
||||||
|
config: EngineConfig,
|
||||||
|
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let request = incoming.await?;
|
||||||
|
let connection = request.accept().await?;
|
||||||
|
|
||||||
|
let sid = crate::engine::session::generate_sid();
|
||||||
|
let mut rx = store.create(sid.clone(), TransportType::WebTransport);
|
||||||
|
|
||||||
|
if let Some(s) = store.get(&sid) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.set_state(SessionState::Open);
|
||||||
|
}
|
||||||
|
|
||||||
|
let handshake = crate::engine::packet::HandshakeData {
|
||||||
|
sid: sid.clone(),
|
||||||
|
upgrades: vec![],
|
||||||
|
ping_interval: config.ping_interval,
|
||||||
|
ping_timeout: config.ping_timeout,
|
||||||
|
max_payload: config.max_payload,
|
||||||
|
};
|
||||||
|
|
||||||
|
let open_packet = Packet::open(&handshake);
|
||||||
|
send_wt_packet(&connection, &open_packet).await?;
|
||||||
|
|
||||||
|
let store_clone = store.clone();
|
||||||
|
let sid_clone = sid.clone();
|
||||||
|
let on_message_clone = on_message.clone();
|
||||||
|
let connection_recv = connection.clone();
|
||||||
|
let max_payload = config.max_payload;
|
||||||
|
|
||||||
|
// Reuse buffer across recv iterations instead of allocating 65KB each time
|
||||||
|
let recv_handle = tokio::spawn(async move {
|
||||||
|
let mut buf = vec![0u8; 65536];
|
||||||
|
loop {
|
||||||
|
match connection_recv.accept_bi().await {
|
||||||
|
Ok((mut send, mut recv)) => {
|
||||||
|
// Reset buffer length for the next read without deallocating
|
||||||
|
buf.resize(65536, 0);
|
||||||
|
match recv.read(&mut buf).await {
|
||||||
|
Ok(Some(n)) => {
|
||||||
|
if n > max_payload {
|
||||||
|
tracing::warn!(
|
||||||
|
"WebTransport payload too large ({}) for session {}",
|
||||||
|
n,
|
||||||
|
sid_clone
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) {
|
||||||
|
match packet.packet_type {
|
||||||
|
PacketType::Ping => {
|
||||||
|
let pong = Packet::pong("");
|
||||||
|
if send_wt_packet_on_stream(&mut send, &pong)
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PacketType::Pong => {
|
||||||
|
if let Some(s) = store_clone.get(&sid_clone) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.update_ping();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PacketType::Message => {
|
||||||
|
let on_msg = on_message_clone.clone();
|
||||||
|
let sid = sid_clone.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
on_msg(sid, packet);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
PacketType::Close => {
|
||||||
|
if let Some(s) = store_clone.get(&sid_clone) {
|
||||||
|
let mut s = s.write().await;
|
||||||
|
s.set_state(SessionState::Closed);
|
||||||
|
}
|
||||||
|
store_clone.remove(&sid_clone);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => break,
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
|
||||||
|
});
|
||||||
|
|
||||||
|
let connection_send = connection.clone();
|
||||||
|
let send_handle = tokio::spawn(async move {
|
||||||
|
while let Some(packet) = rx.recv().await {
|
||||||
|
if send_wt_packet(&connection_send, &packet).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let connection_ping = connection.clone();
|
||||||
|
let store_ping = store.clone();
|
||||||
|
let sid_ping = sid.clone();
|
||||||
|
let ping_interval = config.ping_interval;
|
||||||
|
let ping_handle = tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(std::time::Duration::from_millis(ping_interval));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
if let Some(s) = store_ping.get(&sid_ping) {
|
||||||
|
let state = {
|
||||||
|
let s = s.read().await;
|
||||||
|
s.state
|
||||||
|
};
|
||||||
|
|
||||||
|
if state == SessionState::Closed {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ping = Packet::ping("");
|
||||||
|
if send_wt_packet(&connection_ping, &ping).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = recv_handle => {},
|
||||||
|
_ = send_handle => {},
|
||||||
|
_ = ping_handle => {},
|
||||||
|
}
|
||||||
|
|
||||||
|
store.remove(&sid);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_wt_packet(
|
||||||
|
connection: &Connection,
|
||||||
|
packet: &Packet,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let (mut send, _recv) = connection.open_bi().await?.await?;
|
||||||
|
let encoded = codec::encode_packet_binary_ws(packet);
|
||||||
|
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
|
||||||
|
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
|
||||||
|
|
||||||
|
send.write_all(&header).await?;
|
||||||
|
send.write_all(&encoded).await?;
|
||||||
|
send.finish().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_wt_packet_on_stream(
|
||||||
|
send: &mut wtransport::SendStream,
|
||||||
|
packet: &Packet,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let encoded = codec::encode_packet_binary_ws(packet);
|
||||||
|
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
|
||||||
|
let header = codec::encode_webtransport_header(encoded.len(), is_binary);
|
||||||
|
|
||||||
|
send.write_all(&header).await?;
|
||||||
|
send.write_all(&encoded).await?;
|
||||||
|
send.finish().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use imks::engine::server::EngineConfig;
|
||||||
|
use imks::socket::server::SocketServer;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(
|
||||||
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let config = EngineConfig::default();
|
||||||
|
let socket_server = Arc::new(SocketServer::new(config));
|
||||||
|
|
||||||
|
let addr = "0.0.0.0:3000";
|
||||||
|
tracing::info!("Starting Socket.IO server on {}", addr);
|
||||||
|
|
||||||
|
tokio::runtime::Runtime::new()
|
||||||
|
.expect("Failed to create Tokio runtime")
|
||||||
|
.block_on(async {
|
||||||
|
let namespace = socket_server.of("/");
|
||||||
|
namespace
|
||||||
|
.on_connect(|socket, _auth| {
|
||||||
|
tracing::info!(
|
||||||
|
"Socket {} connected (engine: {})",
|
||||||
|
socket.sid,
|
||||||
|
socket.engine_sid
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
socket_server.run_http(addr).await.expect("Server error");
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
include!(concat!(env!("OUT_DIR"), "/appks.core.v1.rs"));
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package appks.core.v1;
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// JWT Payload
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
message TokenClaims {
|
||||||
|
string sub = 1; // user id (uuid)
|
||||||
|
string iss = 2; // issuer (e.g. "appks")
|
||||||
|
int64 iat = 3; // issued at (unix seconds)
|
||||||
|
int64 exp = 4; // expires at (unix seconds)
|
||||||
|
string jti = 5; // unique token id (for revocation)
|
||||||
|
string scope = 6; // space-separated scopes
|
||||||
|
map<string, string> extra = 7; // extensible fields (workspace_id, role, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Issue (appks REST API → core)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
message IssueTokenRequest {
|
||||||
|
string user_id = 1;
|
||||||
|
int64 ttl_secs = 2; // access token lifetime
|
||||||
|
repeated string scopes = 3;
|
||||||
|
map<string, string> extra = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message IssueTokenResponse {
|
||||||
|
string access_token = 1; // JWT
|
||||||
|
string refresh_token = 2; // opaque, stored in Redis
|
||||||
|
int64 expires_at = 3;
|
||||||
|
string key_id = 4; // kid header for the signing key
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Refresh
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
message RefreshTokenRequest {
|
||||||
|
string refresh_token = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RefreshTokenResponse {
|
||||||
|
string access_token = 1;
|
||||||
|
string refresh_token = 2; // rotated
|
||||||
|
int64 expires_at = 3;
|
||||||
|
string key_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Revoke
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
message RevokeTokenRequest {
|
||||||
|
oneof target {
|
||||||
|
string jti = 1; // revoke single token
|
||||||
|
string user_id = 2; // revoke all tokens for user
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message RevokeTokenResponse {
|
||||||
|
int32 revoked_count = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Verify (imks → core, RPC 模式)
|
||||||
|
// imks 把客户端携带的 JWT 发给 core 验证
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
message VerifyTokenRequest {
|
||||||
|
string token = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message VerifyTokenResponse {
|
||||||
|
bool valid = 1;
|
||||||
|
TokenClaims claims = 2; // only set when valid = true
|
||||||
|
string reason = 3; // "expired", "revoked", "invalid_signature", etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Key Distribution (imks → core, 本地验证模式)
|
||||||
|
// imks 拉取公钥/解密密钥,本地验证 JWT,无需每次 RPC
|
||||||
|
// 密钥窗口 3h,imks 定期刷新
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
message SigningKey {
|
||||||
|
string kid = 1; // key id (matches JWT header kid)
|
||||||
|
string algorithm = 2; // "HS256", "RS256", "EdDSA", ...
|
||||||
|
string key_material = 3; // 对称: base64 secret / 非对称: PEM public key
|
||||||
|
int64 issued_at = 4; // 签发时间
|
||||||
|
int64 expires_at = 5; // 过期时间 (issued_at + 3h window)
|
||||||
|
bool active = 6; // 是否为当前活跃签名密钥
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetSigningKeysRequest {
|
||||||
|
// 空 = 返回所有未过期密钥
|
||||||
|
// 非空 = 只返回指定 kid 的密钥
|
||||||
|
string kid = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetSigningKeysResponse {
|
||||||
|
repeated SigningKey keys = 1; // 可能同时有多个有效密钥(滚动窗口)
|
||||||
|
int64 next_rotation_at = 2; // 下次密钥轮换时间,imks 据此安排刷新
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Service
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
service TokenService {
|
||||||
|
// --- 令牌生命周期 (appks REST handler 调用) ---
|
||||||
|
rpc IssueToken(IssueTokenRequest) returns (IssueTokenResponse);
|
||||||
|
rpc RefreshToken(RefreshTokenRequest) returns (RefreshTokenResponse);
|
||||||
|
rpc RevokeToken(RevokeTokenRequest) returns (RevokeTokenResponse);
|
||||||
|
|
||||||
|
// --- imks 验证 (RPC 模式) ---
|
||||||
|
rpc VerifyToken(VerifyTokenRequest) returns (VerifyTokenResponse);
|
||||||
|
|
||||||
|
// --- imks 密钥拉取 (本地验证模式) ---
|
||||||
|
// imks 启动时拉取,之后根据 next_rotation_at 定期刷新
|
||||||
|
rpc GetSigningKeys(GetSigningKeysRequest) returns (GetSigningKeysResponse);
|
||||||
|
}
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package appks.im.v1;
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
|
// Channel management service for the IM microservice.
|
||||||
|
// Provides CRUD for channels and categories, plus channel statistics.
|
||||||
|
|
||||||
|
|
||||||
|
enum ChannelType {
|
||||||
|
CHANNEL_TYPE_UNSPECIFIED = 0;
|
||||||
|
CHANNEL_TYPE_PUBLIC = 1;
|
||||||
|
CHANNEL_TYPE_PRIVATE = 2;
|
||||||
|
CHANNEL_TYPE_DIRECT = 3;
|
||||||
|
CHANNEL_TYPE_GROUP = 4;
|
||||||
|
CHANNEL_TYPE_REPO = 5;
|
||||||
|
CHANNEL_TYPE_SYSTEM = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ChannelKind {
|
||||||
|
CHANNEL_KIND_UNSPECIFIED = 0;
|
||||||
|
CHANNEL_KIND_TEXT = 1;
|
||||||
|
CHANNEL_KIND_VOICE = 2;
|
||||||
|
CHANNEL_KIND_STAGE = 3;
|
||||||
|
CHANNEL_KIND_FORUM = 4;
|
||||||
|
CHANNEL_KIND_ANNOUNCEMENT = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Visibility {
|
||||||
|
VISIBILITY_UNSPECIFIED = 0;
|
||||||
|
VISIBILITY_PUBLIC = 1;
|
||||||
|
VISIBILITY_PRIVATE = 2;
|
||||||
|
VISIBILITY_INTERNAL = 3;
|
||||||
|
VISIBILITY_WORKSPACE = 4;
|
||||||
|
VISIBILITY_PROTECTED = 5;
|
||||||
|
VISIBILITY_HIDDEN = 6;
|
||||||
|
VISIBILITY_SECRET = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message Channel {
|
||||||
|
string id = 1;
|
||||||
|
string workspace_id = 2;
|
||||||
|
optional string category_id = 3;
|
||||||
|
optional string parent_channel_id = 4;
|
||||||
|
string name = 5;
|
||||||
|
optional string topic = 6;
|
||||||
|
optional string description = 7;
|
||||||
|
ChannelType channel_type = 8;
|
||||||
|
ChannelKind channel_kind = 9;
|
||||||
|
Visibility visibility = 10;
|
||||||
|
int32 position = 11;
|
||||||
|
bool nsfw = 12;
|
||||||
|
bool read_only = 13;
|
||||||
|
bool archived = 14;
|
||||||
|
optional string created_by = 15;
|
||||||
|
optional int32 rate_limit_per_user = 16;
|
||||||
|
optional google.protobuf.Timestamp archived_at = 17;
|
||||||
|
optional string last_message_id = 18;
|
||||||
|
optional google.protobuf.Timestamp last_message_at = 19;
|
||||||
|
google.protobuf.Timestamp created_at = 20;
|
||||||
|
google.protobuf.Timestamp updated_at = 21;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ChannelStats {
|
||||||
|
string channel_id = 1;
|
||||||
|
int32 members_count = 2;
|
||||||
|
int32 messages_count = 3;
|
||||||
|
int32 threads_count = 4;
|
||||||
|
int32 reactions_count = 5;
|
||||||
|
int32 mentions_count = 6;
|
||||||
|
int32 files_count = 7;
|
||||||
|
optional google.protobuf.Timestamp last_activity_at = 8;
|
||||||
|
google.protobuf.Timestamp updated_at = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ChannelCategory {
|
||||||
|
string id = 1;
|
||||||
|
string workspace_id = 2;
|
||||||
|
string name = 3;
|
||||||
|
int32 position = 4;
|
||||||
|
bool collapsed = 5;
|
||||||
|
google.protobuf.Timestamp created_at = 6;
|
||||||
|
google.protobuf.Timestamp updated_at = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message GetChannelRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetChannelResponse {
|
||||||
|
Channel channel = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListChannelsRequest {
|
||||||
|
string workspace_name = 1;
|
||||||
|
optional string category_id = 2;
|
||||||
|
optional ChannelType channel_type = 3;
|
||||||
|
optional ChannelKind channel_kind = 4;
|
||||||
|
int32 limit = 5;
|
||||||
|
int32 offset = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListChannelsResponse {
|
||||||
|
repeated Channel channels = 1;
|
||||||
|
int32 total = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CreateChannelRequest {
|
||||||
|
string workspace_name = 1;
|
||||||
|
string name = 2;
|
||||||
|
optional string topic = 3;
|
||||||
|
optional string description = 4;
|
||||||
|
optional string channel_type = 5;
|
||||||
|
optional string channel_kind = 6;
|
||||||
|
optional string visibility = 7;
|
||||||
|
optional string category_id = 8;
|
||||||
|
optional string parent_channel_id = 9;
|
||||||
|
optional string created_by = 10;
|
||||||
|
optional int32 rate_limit_per_user = 11;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CreateChannelResponse {
|
||||||
|
Channel channel = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message UpdateChannelRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
optional string name = 2;
|
||||||
|
optional string topic = 3;
|
||||||
|
optional string description = 4;
|
||||||
|
optional string visibility = 5;
|
||||||
|
optional int32 position = 6;
|
||||||
|
optional bool nsfw = 7;
|
||||||
|
optional bool read_only = 8;
|
||||||
|
optional bool archived = 9;
|
||||||
|
optional string category_id = 10;
|
||||||
|
optional int32 rate_limit_per_user = 11;
|
||||||
|
}
|
||||||
|
|
||||||
|
message UpdateChannelResponse {
|
||||||
|
Channel channel = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeleteChannelRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeleteChannelResponse {}
|
||||||
|
|
||||||
|
message GetChannelStatsRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetChannelStatsResponse {
|
||||||
|
ChannelStats stats = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListCategoriesRequest {
|
||||||
|
string workspace_name = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListCategoriesResponse {
|
||||||
|
repeated ChannelCategory categories = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CreateCategoryRequest {
|
||||||
|
string workspace_name = 1;
|
||||||
|
string name = 2;
|
||||||
|
optional int32 position = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CreateCategoryResponse {
|
||||||
|
ChannelCategory category = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message UpdateCategoryRequest {
|
||||||
|
string category_id = 1;
|
||||||
|
optional string name = 2;
|
||||||
|
optional int32 position = 3;
|
||||||
|
optional bool collapsed = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message UpdateCategoryResponse {
|
||||||
|
ChannelCategory category = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeleteCategoryRequest {
|
||||||
|
string category_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeleteCategoryResponse {}
|
||||||
|
|
||||||
|
|
||||||
|
service ChannelService {
|
||||||
|
rpc GetChannel(GetChannelRequest) returns (GetChannelResponse);
|
||||||
|
rpc ListChannels(ListChannelsRequest) returns (ListChannelsResponse);
|
||||||
|
rpc CreateChannel(CreateChannelRequest) returns (CreateChannelResponse);
|
||||||
|
rpc UpdateChannel(UpdateChannelRequest) returns (UpdateChannelResponse);
|
||||||
|
rpc DeleteChannel(DeleteChannelRequest) returns (DeleteChannelResponse);
|
||||||
|
rpc GetChannelStats(GetChannelStatsRequest) returns (GetChannelStatsResponse);
|
||||||
|
rpc ListCategories(ListCategoriesRequest) returns (ListCategoriesResponse);
|
||||||
|
rpc CreateCategory(CreateCategoryRequest) returns (CreateCategoryResponse);
|
||||||
|
rpc UpdateCategory(UpdateCategoryRequest) returns (UpdateCategoryResponse);
|
||||||
|
rpc DeleteCategory(DeleteCategoryRequest) returns (DeleteCategoryResponse);
|
||||||
|
}
|
||||||
@@ -0,0 +1,401 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package appks.im.v1;
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
|
|
||||||
|
message ChannelRole {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string name = 3;
|
||||||
|
repeated string permissions = 4;
|
||||||
|
bool assignable = 5;
|
||||||
|
google.protobuf.Timestamp created_at = 6;
|
||||||
|
google.protobuf.Timestamp updated_at = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListChannelRolesRequest { string channel_id = 1; }
|
||||||
|
message ListChannelRolesResponse { repeated ChannelRole roles = 1; }
|
||||||
|
|
||||||
|
message CreateChannelRoleRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string name = 2;
|
||||||
|
repeated string permissions = 3;
|
||||||
|
bool assignable = 4;
|
||||||
|
}
|
||||||
|
message CreateChannelRoleResponse { ChannelRole role = 1; }
|
||||||
|
|
||||||
|
message UpdateChannelRoleRequest {
|
||||||
|
string role_id = 1;
|
||||||
|
optional string name = 2;
|
||||||
|
repeated string permissions = 3;
|
||||||
|
optional bool assignable = 4;
|
||||||
|
}
|
||||||
|
message UpdateChannelRoleResponse { ChannelRole role = 1; }
|
||||||
|
|
||||||
|
message DeleteChannelRoleRequest { string role_id = 1; }
|
||||||
|
message DeleteChannelRoleResponse {}
|
||||||
|
|
||||||
|
service ChannelRoleService {
|
||||||
|
rpc ListChannelRoles(ListChannelRolesRequest) returns (ListChannelRolesResponse);
|
||||||
|
rpc CreateChannelRole(CreateChannelRoleRequest) returns (CreateChannelRoleResponse);
|
||||||
|
rpc UpdateChannelRole(UpdateChannelRoleRequest) returns (UpdateChannelRoleResponse);
|
||||||
|
rpc DeleteChannelRole(DeleteChannelRoleRequest) returns (DeleteChannelRoleResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ChannelInvitation {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string invited_by = 3;
|
||||||
|
string invited_user_id = 4;
|
||||||
|
string role = 5;
|
||||||
|
string status = 6;
|
||||||
|
google.protobuf.Timestamp created_at = 7;
|
||||||
|
google.protobuf.Timestamp updated_at = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListInvitationsRequest { string channel_id = 1; }
|
||||||
|
message ListInvitationsResponse { repeated ChannelInvitation invitations = 1; }
|
||||||
|
|
||||||
|
message CreateInvitationRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string invited_user_id = 2;
|
||||||
|
string role = 3;
|
||||||
|
}
|
||||||
|
message CreateInvitationResponse { ChannelInvitation invitation = 1; }
|
||||||
|
|
||||||
|
message AcceptInvitationRequest { string invitation_id = 1; }
|
||||||
|
message AcceptInvitationResponse { ChannelInvitation invitation = 1; }
|
||||||
|
|
||||||
|
message RevokeInvitationRequest { string invitation_id = 1; }
|
||||||
|
message RevokeInvitationResponse {}
|
||||||
|
|
||||||
|
service ChannelInvitationService {
|
||||||
|
rpc ListInvitations(ListInvitationsRequest) returns (ListInvitationsResponse);
|
||||||
|
rpc CreateInvitation(CreateInvitationRequest) returns (CreateInvitationResponse);
|
||||||
|
rpc AcceptInvitation(AcceptInvitationRequest) returns (AcceptInvitationResponse);
|
||||||
|
rpc RevokeInvitation(RevokeInvitationRequest) returns (RevokeInvitationResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ChannelWebhook {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string name = 3;
|
||||||
|
string url = 4;
|
||||||
|
string secret = 5;
|
||||||
|
repeated string events = 6;
|
||||||
|
bool active = 7;
|
||||||
|
google.protobuf.Timestamp created_at = 8;
|
||||||
|
google.protobuf.Timestamp updated_at = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListWebhooksRequest { string channel_id = 1; }
|
||||||
|
message ListWebhooksResponse { repeated ChannelWebhook webhooks = 1; }
|
||||||
|
|
||||||
|
message CreateWebhookRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string name = 2;
|
||||||
|
string url = 3;
|
||||||
|
optional string secret = 4;
|
||||||
|
repeated string events = 5;
|
||||||
|
}
|
||||||
|
message CreateWebhookResponse { ChannelWebhook webhook = 1; }
|
||||||
|
|
||||||
|
message UpdateWebhookRequest {
|
||||||
|
string webhook_id = 1;
|
||||||
|
optional string name = 2;
|
||||||
|
optional string url = 3;
|
||||||
|
optional string secret = 4;
|
||||||
|
repeated string events = 5;
|
||||||
|
optional bool active = 6;
|
||||||
|
}
|
||||||
|
message UpdateWebhookResponse { ChannelWebhook webhook = 1; }
|
||||||
|
|
||||||
|
message DeleteWebhookRequest { string webhook_id = 1; }
|
||||||
|
message DeleteWebhookResponse {}
|
||||||
|
|
||||||
|
service ChannelWebhookService {
|
||||||
|
rpc ListWebhooks(ListWebhooksRequest) returns (ListWebhooksResponse);
|
||||||
|
rpc CreateWebhook(CreateWebhookRequest) returns (CreateWebhookResponse);
|
||||||
|
rpc UpdateWebhook(UpdateWebhookRequest) returns (UpdateWebhookResponse);
|
||||||
|
rpc DeleteWebhook(DeleteWebhookRequest) returns (DeleteWebhookResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ChannelSlashCommand {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string command = 3;
|
||||||
|
string description = 4;
|
||||||
|
string request_url = 5;
|
||||||
|
repeated string scopes = 6;
|
||||||
|
google.protobuf.Timestamp created_at = 7;
|
||||||
|
google.protobuf.Timestamp updated_at = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListSlashCommandsRequest { string channel_id = 1; }
|
||||||
|
message ListSlashCommandsResponse { repeated ChannelSlashCommand commands = 1; }
|
||||||
|
|
||||||
|
message CreateSlashCommandRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string command = 2;
|
||||||
|
string description = 3;
|
||||||
|
string request_url = 4;
|
||||||
|
repeated string scopes = 5;
|
||||||
|
}
|
||||||
|
message CreateSlashCommandResponse { ChannelSlashCommand command = 1; }
|
||||||
|
|
||||||
|
message UpdateSlashCommandRequest {
|
||||||
|
string command_id = 1;
|
||||||
|
optional string description = 2;
|
||||||
|
optional string request_url = 3;
|
||||||
|
repeated string scopes = 4;
|
||||||
|
}
|
||||||
|
message UpdateSlashCommandResponse { ChannelSlashCommand command = 1; }
|
||||||
|
|
||||||
|
message DeleteSlashCommandRequest { string command_id = 1; }
|
||||||
|
message DeleteSlashCommandResponse {}
|
||||||
|
|
||||||
|
service ChannelSlashCommandService {
|
||||||
|
rpc ListSlashCommands(ListSlashCommandsRequest) returns (ListSlashCommandsResponse);
|
||||||
|
rpc CreateSlashCommand(CreateSlashCommandRequest) returns (CreateSlashCommandResponse);
|
||||||
|
rpc UpdateSlashCommand(UpdateSlashCommandRequest) returns (UpdateSlashCommandResponse);
|
||||||
|
rpc DeleteSlashCommand(DeleteSlashCommandRequest) returns (DeleteSlashCommandResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ChannelRepoLink {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string repo_id = 3;
|
||||||
|
string link_type = 4;
|
||||||
|
repeated string events = 5;
|
||||||
|
google.protobuf.Timestamp created_at = 6;
|
||||||
|
google.protobuf.Timestamp updated_at = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListRepoLinksRequest { string channel_id = 1; }
|
||||||
|
message ListRepoLinksResponse { repeated ChannelRepoLink links = 1; }
|
||||||
|
|
||||||
|
message CreateRepoLinkRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string repo_id = 2;
|
||||||
|
string link_type = 3;
|
||||||
|
repeated string events = 4;
|
||||||
|
}
|
||||||
|
message CreateRepoLinkResponse { ChannelRepoLink link = 1; }
|
||||||
|
|
||||||
|
message DeleteRepoLinkRequest { string link_id = 1; }
|
||||||
|
message DeleteRepoLinkResponse {}
|
||||||
|
|
||||||
|
service ChannelRepoLinkService {
|
||||||
|
rpc ListRepoLinks(ListRepoLinksRequest) returns (ListRepoLinksResponse);
|
||||||
|
rpc CreateRepoLink(CreateRepoLinkRequest) returns (CreateRepoLinkResponse);
|
||||||
|
rpc DeleteRepoLink(DeleteRepoLinkRequest) returns (DeleteRepoLinkResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ImIntegration {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string provider = 3;
|
||||||
|
string external_channel_id = 4;
|
||||||
|
string sync_direction = 5;
|
||||||
|
bool active = 6;
|
||||||
|
google.protobuf.Timestamp created_at = 7;
|
||||||
|
google.protobuf.Timestamp updated_at = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListIntegrationsRequest { string channel_id = 1; }
|
||||||
|
message ListIntegrationsResponse { repeated ImIntegration integrations = 1; }
|
||||||
|
|
||||||
|
message CreateIntegrationRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string provider = 2;
|
||||||
|
string external_channel_id = 3;
|
||||||
|
string sync_direction = 4;
|
||||||
|
}
|
||||||
|
message CreateIntegrationResponse { ImIntegration integration = 1; }
|
||||||
|
|
||||||
|
message UpdateIntegrationRequest {
|
||||||
|
string integration_id = 1;
|
||||||
|
optional string sync_direction = 2;
|
||||||
|
optional bool active = 3;
|
||||||
|
}
|
||||||
|
message UpdateIntegrationResponse { ImIntegration integration = 1; }
|
||||||
|
|
||||||
|
message DeleteIntegrationRequest { string integration_id = 1; }
|
||||||
|
message DeleteIntegrationResponse {}
|
||||||
|
|
||||||
|
service ImIntegrationService {
|
||||||
|
rpc ListIntegrations(ListIntegrationsRequest) returns (ListIntegrationsResponse);
|
||||||
|
rpc CreateIntegration(CreateIntegrationRequest) returns (CreateIntegrationResponse);
|
||||||
|
rpc UpdateIntegration(UpdateIntegrationRequest) returns (UpdateIntegrationResponse);
|
||||||
|
rpc DeleteIntegration(DeleteIntegrationRequest) returns (DeleteIntegrationResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message CustomEmoji {
|
||||||
|
string id = 1;
|
||||||
|
string workspace_id = 2;
|
||||||
|
string name = 3;
|
||||||
|
string image_url = 4;
|
||||||
|
google.protobuf.Timestamp created_at = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListCustomEmojisRequest { string workspace_id = 1; }
|
||||||
|
message ListCustomEmojisResponse { repeated CustomEmoji emojis = 1; }
|
||||||
|
|
||||||
|
message CreateCustomEmojiRequest {
|
||||||
|
string workspace_id = 1;
|
||||||
|
string name = 2;
|
||||||
|
string image_url = 3;
|
||||||
|
}
|
||||||
|
message CreateCustomEmojiResponse { CustomEmoji emoji = 1; }
|
||||||
|
|
||||||
|
message DeleteCustomEmojiRequest { string emoji_id = 1; }
|
||||||
|
message DeleteCustomEmojiResponse {}
|
||||||
|
|
||||||
|
service CustomEmojiService {
|
||||||
|
rpc ListCustomEmojis(ListCustomEmojisRequest) returns (ListCustomEmojisResponse);
|
||||||
|
rpc CreateCustomEmoji(CreateCustomEmojiRequest) returns (CreateCustomEmojiResponse);
|
||||||
|
rpc DeleteCustomEmoji(DeleteCustomEmojiRequest) returns (DeleteCustomEmojiResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ForumTag {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string name = 3;
|
||||||
|
bool moderated = 4;
|
||||||
|
int32 position = 5;
|
||||||
|
google.protobuf.Timestamp created_at = 6;
|
||||||
|
google.protobuf.Timestamp updated_at = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListForumTagsRequest { string channel_id = 1; }
|
||||||
|
message ListForumTagsResponse { repeated ForumTag tags = 1; }
|
||||||
|
|
||||||
|
message CreateForumTagRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string name = 2;
|
||||||
|
bool moderated = 3;
|
||||||
|
optional int32 position = 4;
|
||||||
|
}
|
||||||
|
message CreateForumTagResponse { ForumTag tag = 1; }
|
||||||
|
|
||||||
|
message UpdateForumTagRequest {
|
||||||
|
string tag_id = 1;
|
||||||
|
optional string name = 2;
|
||||||
|
optional bool moderated = 3;
|
||||||
|
optional int32 position = 4;
|
||||||
|
}
|
||||||
|
message UpdateForumTagResponse { ForumTag tag = 1; }
|
||||||
|
|
||||||
|
message DeleteForumTagRequest { string tag_id = 1; }
|
||||||
|
message DeleteForumTagResponse {}
|
||||||
|
|
||||||
|
service ForumTagService {
|
||||||
|
rpc ListForumTags(ListForumTagsRequest) returns (ListForumTagsResponse);
|
||||||
|
rpc CreateForumTag(CreateForumTagRequest) returns (CreateForumTagResponse);
|
||||||
|
rpc UpdateForumTag(UpdateForumTagRequest) returns (UpdateForumTagResponse);
|
||||||
|
rpc DeleteForumTag(DeleteForumTagRequest) returns (DeleteForumTagResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message VoiceParticipant {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string user_id = 3;
|
||||||
|
bool muted = 4;
|
||||||
|
bool deafened = 5;
|
||||||
|
google.protobuf.Timestamp joined_at = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListVoiceParticipantsRequest { string channel_id = 1; }
|
||||||
|
message ListVoiceParticipantsResponse { repeated VoiceParticipant participants = 1; }
|
||||||
|
|
||||||
|
message UpdateVoiceStateRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
optional bool muted = 3;
|
||||||
|
optional bool deafened = 4;
|
||||||
|
}
|
||||||
|
message UpdateVoiceStateResponse { VoiceParticipant participant = 1; }
|
||||||
|
|
||||||
|
service VoiceService {
|
||||||
|
rpc ListVoiceParticipants(ListVoiceParticipantsRequest) returns (ListVoiceParticipantsResponse);
|
||||||
|
rpc UpdateVoiceState(UpdateVoiceStateRequest) returns (UpdateVoiceStateResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message Stage {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string topic = 3;
|
||||||
|
string privacy_level = 4;
|
||||||
|
bool discoverable = 5;
|
||||||
|
google.protobuf.Timestamp started_at = 6;
|
||||||
|
google.protobuf.Timestamp ended_at = 7;
|
||||||
|
google.protobuf.Timestamp created_at = 8;
|
||||||
|
google.protobuf.Timestamp updated_at = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetStageRequest { string channel_id = 1; }
|
||||||
|
message GetStageResponse { Stage stage = 1; }
|
||||||
|
|
||||||
|
message CreateStageRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string topic = 2;
|
||||||
|
string privacy_level = 3;
|
||||||
|
bool discoverable = 4;
|
||||||
|
}
|
||||||
|
message CreateStageResponse { Stage stage = 1; }
|
||||||
|
|
||||||
|
message UpdateStageRequest {
|
||||||
|
string stage_id = 1;
|
||||||
|
optional string topic = 2;
|
||||||
|
optional string privacy_level = 3;
|
||||||
|
optional bool discoverable = 4;
|
||||||
|
}
|
||||||
|
message UpdateStageResponse { Stage stage = 1; }
|
||||||
|
|
||||||
|
message DeleteStageRequest { string stage_id = 1; }
|
||||||
|
message DeleteStageResponse {}
|
||||||
|
|
||||||
|
service StageService {
|
||||||
|
rpc GetStage(GetStageRequest) returns (GetStageResponse);
|
||||||
|
rpc CreateStage(CreateStageRequest) returns (CreateStageResponse);
|
||||||
|
rpc UpdateStage(UpdateStageRequest) returns (UpdateStageResponse);
|
||||||
|
rpc DeleteStage(DeleteStageRequest) returns (DeleteStageResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ChannelAuditEvent {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string actor_id = 3;
|
||||||
|
string event_type = 4;
|
||||||
|
string target_type = 5;
|
||||||
|
string target_id = 6;
|
||||||
|
optional string old_value = 7;
|
||||||
|
optional string new_value = 8;
|
||||||
|
google.protobuf.Timestamp created_at = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListChannelEventsRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
int32 limit = 2;
|
||||||
|
int32 offset = 3;
|
||||||
|
}
|
||||||
|
message ListChannelEventsResponse {
|
||||||
|
repeated ChannelAuditEvent events = 1;
|
||||||
|
int32 total = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
service ChannelAuditService {
|
||||||
|
rpc ListChannelEvents(ListChannelEventsRequest) returns (ListChannelEventsResponse);
|
||||||
|
}
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package appks.im.v1;
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
|
// Member management service for the IM microservice.
|
||||||
|
// Provides CRUD for channel members, join/leave, and membership checks.
|
||||||
|
|
||||||
|
enum Role {
|
||||||
|
ROLE_UNSPECIFIED = 0;
|
||||||
|
ROLE_OWNER = 1;
|
||||||
|
ROLE_ADMIN = 2;
|
||||||
|
ROLE_MAINTAINER = 3;
|
||||||
|
ROLE_MODERATOR = 4;
|
||||||
|
ROLE_MEMBER = 5;
|
||||||
|
ROLE_CONTRIBUTOR = 6;
|
||||||
|
ROLE_VIEWER = 7;
|
||||||
|
ROLE_GUEST = 8;
|
||||||
|
ROLE_BOT = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum MemberStatus {
|
||||||
|
MEMBER_STATUS_UNSPECIFIED = 0;
|
||||||
|
MEMBER_STATUS_ACTIVE = 1;
|
||||||
|
MEMBER_STATUS_INVITED = 2;
|
||||||
|
MEMBER_STATUS_LEFT = 3;
|
||||||
|
MEMBER_STATUS_KICKED = 4;
|
||||||
|
MEMBER_STATUS_BANNED = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ChannelMember {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string user_id = 3;
|
||||||
|
string role = 4;
|
||||||
|
string status = 5;
|
||||||
|
bool muted = 6;
|
||||||
|
bool pinned = 7;
|
||||||
|
optional string last_read_message_id = 8;
|
||||||
|
optional google.protobuf.Timestamp last_read_at = 9;
|
||||||
|
optional google.protobuf.Timestamp joined_at = 10;
|
||||||
|
optional google.protobuf.Timestamp left_at = 11;
|
||||||
|
google.protobuf.Timestamp created_at = 12;
|
||||||
|
google.protobuf.Timestamp updated_at = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ListMembersRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
optional string status = 2;
|
||||||
|
int32 limit = 3;
|
||||||
|
int32 offset = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListMembersResponse {
|
||||||
|
repeated ChannelMember members = 1;
|
||||||
|
int32 total = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message InviteMemberRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
optional string role = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message InviteMemberResponse {
|
||||||
|
ChannelMember member = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message UpdateMemberRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
optional string role = 3;
|
||||||
|
optional bool muted = 4;
|
||||||
|
optional bool pinned = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message UpdateMemberResponse {
|
||||||
|
ChannelMember member = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message KickMemberRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message KickMemberResponse {}
|
||||||
|
|
||||||
|
message JoinChannelRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message JoinChannelResponse {
|
||||||
|
ChannelMember member = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message LeaveChannelRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message LeaveChannelResponse {}
|
||||||
|
|
||||||
|
message IsMemberRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message IsMemberResponse {
|
||||||
|
bool is_member = 1;
|
||||||
|
string role = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
service MemberService {
|
||||||
|
rpc ListMembers(ListMembersRequest) returns (ListMembersResponse);
|
||||||
|
rpc InviteMember(InviteMemberRequest) returns (InviteMemberResponse);
|
||||||
|
rpc UpdateMember(UpdateMemberRequest) returns (UpdateMemberResponse);
|
||||||
|
rpc KickMember(KickMemberRequest) returns (KickMemberResponse);
|
||||||
|
rpc JoinChannel(JoinChannelRequest) returns (JoinChannelResponse);
|
||||||
|
rpc LeaveChannel(LeaveChannelRequest) returns (LeaveChannelResponse);
|
||||||
|
rpc IsMember(IsMemberRequest) returns (IsMemberResponse);
|
||||||
|
}
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package appks.im.v1;
|
||||||
|
|
||||||
|
// IM-specific permissions for channel operations.
|
||||||
|
// Separate from the general Permission enum used for repo/workspace access.
|
||||||
|
enum ImPermission {
|
||||||
|
IM_PERMISSION_UNSPECIFIED = 0;
|
||||||
|
IM_PERMISSION_READ_CHANNEL = 1;
|
||||||
|
IM_PERMISSION_SEND_MESSAGE = 2;
|
||||||
|
IM_PERMISSION_MANAGE_THREADS = 3;
|
||||||
|
IM_PERMISSION_MANAGE_REACTIONS = 4;
|
||||||
|
IM_PERMISSION_MANAGE_PINS = 5;
|
||||||
|
IM_PERMISSION_INVITE_MEMBERS = 6;
|
||||||
|
IM_PERMISSION_KICK_MEMBERS = 7;
|
||||||
|
IM_PERMISSION_MANAGE_CHANNEL = 8;
|
||||||
|
IM_PERMISSION_MANAGE_ROLES = 9;
|
||||||
|
IM_PERMISSION_MANAGE_WEBHOOKS = 10;
|
||||||
|
IM_PERMISSION_MANAGE_EMOJIS = 11;
|
||||||
|
IM_PERMISSION_VIEW_AUDIT_LOG = 12;
|
||||||
|
IM_PERMISSION_MANAGE_INTEGRATIONS = 13;
|
||||||
|
IM_PERMISSION_SEND_TTS = 14;
|
||||||
|
IM_PERMISSION_USE_SLASH_COMMANDS = 15;
|
||||||
|
IM_PERMISSION_ATTACH_FILES = 16;
|
||||||
|
IM_PERMISSION_MENTION_EVERYONE = 17;
|
||||||
|
IM_PERMISSION_MANAGE_MESSAGES = 18;
|
||||||
|
IM_PERMISSION_ADMIN = 19;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message PermissionOverwrite {
|
||||||
|
string id = 1;
|
||||||
|
string channel_id = 2;
|
||||||
|
string target_type = 3;
|
||||||
|
string target_id = 4;
|
||||||
|
repeated ImPermission allow = 5;
|
||||||
|
repeated ImPermission deny = 6;
|
||||||
|
string created_at = 7;
|
||||||
|
string updated_at = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message CheckPermissionRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
ImPermission permission = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CheckPermissionResponse {
|
||||||
|
bool allowed = 1;
|
||||||
|
string role = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetPermissionsRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetPermissionsResponse {
|
||||||
|
repeated ImPermission permissions = 1;
|
||||||
|
string role = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SetPermissionOverwriteRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string target_type = 2;
|
||||||
|
string target_id = 3;
|
||||||
|
repeated ImPermission allow = 4;
|
||||||
|
repeated ImPermission deny = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SetPermissionOverwriteResponse {
|
||||||
|
PermissionOverwrite overwrite = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetPermissionOverwritesRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetPermissionOverwritesResponse {
|
||||||
|
repeated PermissionOverwrite overwrites = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeletePermissionOverwriteRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string target_type = 2;
|
||||||
|
string target_id = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DeletePermissionOverwriteResponse {}
|
||||||
|
|
||||||
|
message ResolveChannelRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ResolveChannelResponse {
|
||||||
|
string channel_id = 1;
|
||||||
|
string workspace_id = 2;
|
||||||
|
string name = 3;
|
||||||
|
string visibility = 4;
|
||||||
|
string channel_type = 5;
|
||||||
|
bool read_only = 6;
|
||||||
|
bool archived = 7;
|
||||||
|
optional string created_by = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message EnsureReadableRequest {
|
||||||
|
string channel_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message EnsureReadableResponse {
|
||||||
|
bool allowed = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
service PermissionService {
|
||||||
|
rpc CheckPermission(CheckPermissionRequest) returns (CheckPermissionResponse);
|
||||||
|
rpc GetPermissions(GetPermissionsRequest) returns (GetPermissionsResponse);
|
||||||
|
rpc SetPermissionOverwrite(SetPermissionOverwriteRequest) returns (SetPermissionOverwriteResponse);
|
||||||
|
rpc GetPermissionOverwrites(GetPermissionOverwritesRequest) returns (GetPermissionOverwritesResponse);
|
||||||
|
rpc DeletePermissionOverwrite(DeletePermissionOverwriteRequest) returns (DeletePermissionOverwriteResponse);
|
||||||
|
rpc ResolveChannel(ResolveChannelRequest) returns (ResolveChannelResponse);
|
||||||
|
rpc EnsureReadable(EnsureReadableRequest) returns (EnsureReadableResponse);
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, SocketInfo};
|
||||||
|
use crate::socket::packet::Packet;
|
||||||
|
|
||||||
|
pub struct LocalAdapter {
|
||||||
|
server_id: String,
|
||||||
|
rooms: Arc<DashMap<String, HashSet<String>>>,
|
||||||
|
socket_rooms: Arc<DashMap<String, HashSet<String>>>,
|
||||||
|
/// socket_sid → engine_sid
|
||||||
|
pub socket_sids: Arc<DashMap<String, String>>,
|
||||||
|
/// socket_sid → namespace path
|
||||||
|
socket_namespace: Arc<DashMap<String, String>>,
|
||||||
|
send_fn: Arc<dyn Fn(&str, &Packet) -> Result<(), String> + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalAdapter {
|
||||||
|
pub fn new(
|
||||||
|
send_fn: impl Fn(&str, &Packet) -> Result<(), String> + Send + Sync + 'static,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
server_id: Uuid::new_v4().to_string(),
|
||||||
|
rooms: Arc::new(DashMap::new()),
|
||||||
|
socket_rooms: Arc::new(DashMap::new()),
|
||||||
|
socket_sids: Arc::new(DashMap::new()),
|
||||||
|
socket_namespace: Arc::new(DashMap::new()),
|
||||||
|
send_fn: Arc::new(send_fn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn room_key(ns: &str, room: &str) -> String {
|
||||||
|
format!("{}:{}", ns, room)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Collect socket SIDs matching the broadcast options, scoped to the given namespace.
|
||||||
|
fn collect_matching_sids(&self, opts: &BroadcastOptions, namespace: &str) -> Vec<String> {
|
||||||
|
if opts.rooms.is_empty() {
|
||||||
|
// Broadcast to all sockets in this namespace only
|
||||||
|
self.socket_sids
|
||||||
|
.iter()
|
||||||
|
.filter(|e| {
|
||||||
|
self.socket_namespace
|
||||||
|
.get(e.key())
|
||||||
|
.map(|ns| ns.value() == namespace)
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
.map(|e| e.key().clone())
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
let mut sids = HashSet::new();
|
||||||
|
for room in &opts.rooms {
|
||||||
|
let key = Self::room_key(namespace, room);
|
||||||
|
if let Some(entry) = self.rooms.get(&key) {
|
||||||
|
for sid in entry.value() {
|
||||||
|
sids.insert(sid.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sids.into_iter().collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Adapter for LocalAdapter {
|
||||||
|
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
|
||||||
|
let namespace = &packet.namespace;
|
||||||
|
let sids = self.collect_matching_sids(opts, namespace);
|
||||||
|
for sid in &sids {
|
||||||
|
if opts.except.contains(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// socket_sids maps socket SID -> engine SID
|
||||||
|
if let Some(entry) = self.socket_sids.get(sid) {
|
||||||
|
let engine_sid = entry.value();
|
||||||
|
let result = (self.send_fn)(engine_sid, packet);
|
||||||
|
if let Err(e) = result {
|
||||||
|
tracing::warn!("Failed to broadcast to {}: {}", sid, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn register(&self, socket_sid: &str, engine_sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
|
||||||
|
self.socket_namespace.insert(socket_sid.to_string(), ns.to_string());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unregister(&self, socket_sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
self.del_all(socket_sid, ns).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
let key = Self::room_key(ns, room);
|
||||||
|
self.rooms.entry(key).or_insert_with(HashSet::new).value_mut().insert(sid.to_string());
|
||||||
|
self.socket_rooms.entry(sid.to_string()).or_insert_with(HashSet::new).value_mut().insert(room.to_string());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
let key = Self::room_key(ns, room);
|
||||||
|
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
|
||||||
|
room_sids.value_mut().remove(sid);
|
||||||
|
if room_sids.value_mut().is_empty() {
|
||||||
|
drop(room_sids);
|
||||||
|
self.rooms.remove(&key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(mut rooms) = self.socket_rooms.get_mut(sid) {
|
||||||
|
rooms.value_mut().remove(room);
|
||||||
|
if rooms.value_mut().is_empty() {
|
||||||
|
drop(rooms);
|
||||||
|
self.socket_rooms.remove(sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
|
||||||
|
for room in &rooms {
|
||||||
|
let key = Self::room_key(ns, room);
|
||||||
|
if let Some(mut room_sids) = self.rooms.get_mut(&key) {
|
||||||
|
room_sids.value_mut().remove(sid);
|
||||||
|
if room_sids.value_mut().is_empty() {
|
||||||
|
drop(room_sids);
|
||||||
|
self.rooms.remove(&key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.socket_sids.remove(sid);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
|
||||||
|
// fetch_sockets needs namespace context; use an empty namespace to match all
|
||||||
|
// (this method is typically called for inspection, not delivery)
|
||||||
|
let sids: Vec<String> = if opts.rooms.is_empty() {
|
||||||
|
self.socket_sids.iter().map(|e| e.key().clone()).collect()
|
||||||
|
} else {
|
||||||
|
let mut sids_set = HashSet::new();
|
||||||
|
for room in &opts.rooms {
|
||||||
|
for entry in self.rooms.iter() {
|
||||||
|
if entry.key().ends_with(&format!(":{}", room)) {
|
||||||
|
for sid in entry.value() {
|
||||||
|
sids_set.insert(sid.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sids_set.into_iter().collect()
|
||||||
|
};
|
||||||
|
let mut result = Vec::new();
|
||||||
|
for sid in &sids {
|
||||||
|
if opts.except.contains(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if self.socket_sids.contains_key(sid) {
|
||||||
|
let namespace = self.socket_namespace
|
||||||
|
.get(sid)
|
||||||
|
.map(|r| r.value().clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let rooms = self.socket_rooms
|
||||||
|
.get(sid)
|
||||||
|
.map(|r| r.value().clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
result.push(SocketInfo {
|
||||||
|
sid: sid.clone(),
|
||||||
|
namespace,
|
||||||
|
rooms,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
|
||||||
|
Ok(self.socket_rooms
|
||||||
|
.get(sid)
|
||||||
|
.map(|r| r.value().clone())
|
||||||
|
.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn server_id(&self) -> &str {
|
||||||
|
&self.server_id
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&self) -> Result<(), AdapterError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
pub mod local;
|
||||||
|
pub mod redis;
|
||||||
|
pub mod nats;
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
use crate::socket::packet::Packet;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum AdapterError {
|
||||||
|
#[error("Redis error: {0}")]
|
||||||
|
Redis(String),
|
||||||
|
#[error("NATS error: {0}")]
|
||||||
|
Nats(String),
|
||||||
|
#[error("Message bus error: {0}")]
|
||||||
|
MessageBus(String),
|
||||||
|
#[error("Serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
#[error("Room error: {0}")]
|
||||||
|
Room(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||||
|
pub struct BroadcastOptions {
|
||||||
|
pub rooms: HashSet<String>,
|
||||||
|
pub except: HashSet<String>,
|
||||||
|
pub flags: BroadcastFlags,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||||
|
pub struct BroadcastFlags {
|
||||||
|
pub local_only: bool,
|
||||||
|
pub broadcast: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct SocketInfo {
|
||||||
|
pub sid: String,
|
||||||
|
pub namespace: String,
|
||||||
|
pub rooms: HashSet<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
|
||||||
|
pub enum BusMessage {
|
||||||
|
Broadcast {
|
||||||
|
namespace: String,
|
||||||
|
packet: String,
|
||||||
|
opts: BroadcastOptions,
|
||||||
|
server_id: String,
|
||||||
|
},
|
||||||
|
SocketJoin {
|
||||||
|
namespace: String,
|
||||||
|
sid: String,
|
||||||
|
room: String,
|
||||||
|
server_id: String,
|
||||||
|
},
|
||||||
|
SocketLeave {
|
||||||
|
namespace: String,
|
||||||
|
sid: String,
|
||||||
|
room: String,
|
||||||
|
server_id: String,
|
||||||
|
},
|
||||||
|
SocketDisconnect {
|
||||||
|
namespace: String,
|
||||||
|
sid: String,
|
||||||
|
server_id: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Adapter: Send + Sync + 'static {
|
||||||
|
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError>;
|
||||||
|
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
|
||||||
|
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
|
||||||
|
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError>;
|
||||||
|
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError>;
|
||||||
|
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError>;
|
||||||
|
fn server_id(&self) -> &str;
|
||||||
|
async fn close(&self) -> Result<(), AdapterError>;
|
||||||
|
|
||||||
|
/// Register a socket SID → engine SID mapping in the adapter.
|
||||||
|
/// Must be called when a socket first connects, before any room operations.
|
||||||
|
/// The `ns` parameter is the namespace path this socket belongs to.
|
||||||
|
async fn register(&self, _socket_sid: &str, _engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unregister a socket from the adapter, removing all local mappings.
|
||||||
|
async fn unregister(&self, _socket_sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub use local::LocalAdapter;
|
||||||
|
pub use redis::RedisAdapter;
|
||||||
|
pub use nats::NatsAdapter;
|
||||||
@@ -0,0 +1,302 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
|
||||||
|
use crate::socket::message_bus::MessageBus;
|
||||||
|
use crate::socket::packet::Packet;
|
||||||
|
use crate::socket::parser;
|
||||||
|
use crate::socket::socket::Socket;
|
||||||
|
|
||||||
|
/// Handle incoming bus messages from other servers.
|
||||||
|
/// Only performs local dispatch — no remote state writes needed.
|
||||||
|
async fn handle_bus_message(
|
||||||
|
msg: BusMessage,
|
||||||
|
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||||
|
server_id: &str,
|
||||||
|
) {
|
||||||
|
match msg {
|
||||||
|
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
|
||||||
|
if sender_id == server_id {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if let Ok(decoded_packet) = parser::decode(&packet) {
|
||||||
|
on_local_broadcast(&decoded_packet, &opts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// NATS adapter manages room state locally; cross-server join/leave/disconnect
|
||||||
|
// are informational only and don't require duplicate state writes.
|
||||||
|
BusMessage::SocketJoin { server_id: sender_id, .. }
|
||||||
|
| BusMessage::SocketLeave { server_id: sender_id, .. }
|
||||||
|
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
|
||||||
|
if sender_id == server_id {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// NATS-based adapter that manages room state locally and uses NATS
|
||||||
|
/// for cross-server broadcast only. Does NOT depend on Redis.
|
||||||
|
pub struct NatsAdapter {
|
||||||
|
message_bus: Arc<dyn MessageBus>,
|
||||||
|
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
|
||||||
|
socket_rooms: DashMap<String, HashSet<String>>,
|
||||||
|
rooms: DashMap<String, HashSet<String>>,
|
||||||
|
/// socket_sid → engine_sid mapping for local delivery
|
||||||
|
socket_sids: DashMap<String, String>,
|
||||||
|
sockets: DashMap<String, Arc<Socket>>,
|
||||||
|
server_id: String,
|
||||||
|
namespace: String,
|
||||||
|
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NatsAdapter {
|
||||||
|
pub fn new(
|
||||||
|
message_bus: Arc<dyn MessageBus>,
|
||||||
|
server_id: String,
|
||||||
|
namespace: String,
|
||||||
|
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
message_bus,
|
||||||
|
server_id,
|
||||||
|
namespace,
|
||||||
|
on_local_broadcast,
|
||||||
|
room_subscribers: DashMap::new(),
|
||||||
|
socket_rooms: DashMap::new(),
|
||||||
|
rooms: DashMap::new(),
|
||||||
|
socket_sids: DashMap::new(),
|
||||||
|
sockets: DashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(&self) -> Result<(), AdapterError> {
|
||||||
|
let channels = ["broadcast", "join", "leave", "disconnect"];
|
||||||
|
let prefix = format!("socket.io:{}:", self.namespace);
|
||||||
|
|
||||||
|
for channel_type in channels {
|
||||||
|
let subject = format!("{}{}", prefix, channel_type);
|
||||||
|
match self.message_bus.subscribe(&subject).await {
|
||||||
|
Ok(rx) => {
|
||||||
|
self.room_subscribers.insert(channel_type.to_string(), rx);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.spawn_listener();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_listener(&self) {
|
||||||
|
let server_id = self.server_id.clone();
|
||||||
|
let on_local_broadcast = self.on_local_broadcast.clone();
|
||||||
|
|
||||||
|
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
|
||||||
|
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
|
||||||
|
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
|
||||||
|
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(data) = async { join_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Adapter for NatsAdapter {
|
||||||
|
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
|
||||||
|
if opts.flags.local_only {
|
||||||
|
(self.on_local_broadcast)(packet, opts);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = BusMessage::Broadcast {
|
||||||
|
namespace: self.namespace.clone(),
|
||||||
|
packet: parser::encode(packet),
|
||||||
|
opts: opts.clone(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:broadcast", self.namespace), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
(self.on_local_broadcast)(packet, opts);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn register(&self, socket_sid: &str, engine_sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||||
|
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||||
|
self.socket_rooms
|
||||||
|
.entry(sid.to_string())
|
||||||
|
.and_modify(|set| { set.insert(room.to_string()); })
|
||||||
|
.or_insert_with(|| HashSet::from([room.to_string()]));
|
||||||
|
|
||||||
|
self.rooms
|
||||||
|
.entry(room.to_string())
|
||||||
|
.and_modify(|set| { set.insert(sid.to_string()); })
|
||||||
|
.or_insert_with(|| HashSet::from([sid.to_string()]));
|
||||||
|
|
||||||
|
let msg = BusMessage::SocketJoin {
|
||||||
|
namespace: self.namespace.clone(),
|
||||||
|
sid: sid.to_string(),
|
||||||
|
room: room.to_string(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:join", self.namespace), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn del(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||||
|
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
|
||||||
|
entry.value_mut().remove(room);
|
||||||
|
}
|
||||||
|
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||||
|
self.socket_rooms.remove(sid);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||||
|
entry.value_mut().remove(sid);
|
||||||
|
}
|
||||||
|
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||||
|
self.rooms.remove(room);
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = BusMessage::SocketLeave {
|
||||||
|
namespace: self.namespace.clone(),
|
||||||
|
sid: sid.to_string(),
|
||||||
|
room: room.to_string(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:leave", self.namespace), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn del_all(&self, sid: &str, _ns: &str) -> Result<(), AdapterError> {
|
||||||
|
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
|
||||||
|
for room in &rooms {
|
||||||
|
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||||
|
entry.value_mut().remove(sid);
|
||||||
|
}
|
||||||
|
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||||
|
self.rooms.remove(room);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.socket_sids.remove(sid);
|
||||||
|
self.sockets.remove(sid);
|
||||||
|
|
||||||
|
let msg = BusMessage::SocketDisconnect {
|
||||||
|
namespace: self.namespace.clone(),
|
||||||
|
sid: sid.to_string(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:disconnect", self.namespace), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
|
||||||
|
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
|
||||||
|
self.socket_sids.iter().map(|e| e.key().clone()).collect()
|
||||||
|
} else {
|
||||||
|
let mut sids = HashSet::new();
|
||||||
|
for room in &opts.rooms {
|
||||||
|
if let Some(entry) = self.rooms.get(room) {
|
||||||
|
sids.extend(entry.value().iter().cloned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sids
|
||||||
|
};
|
||||||
|
|
||||||
|
for sid in target_sids {
|
||||||
|
if opts.except.contains(&sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
|
||||||
|
result.push(SocketInfo {
|
||||||
|
sid: sid.clone(),
|
||||||
|
namespace: self.namespace.clone(),
|
||||||
|
rooms,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
|
||||||
|
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn server_id(&self) -> &str {
|
||||||
|
&self.server_id
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&self) -> Result<(), AdapterError> {
|
||||||
|
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,344 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use fred::clients::Client;
|
||||||
|
use fred::interfaces::{KeysInterface, SetsInterface};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo};
|
||||||
|
use crate::socket::message_bus::MessageBus;
|
||||||
|
use crate::socket::packet::Packet;
|
||||||
|
use crate::socket::parser;
|
||||||
|
use crate::socket::socket::Socket;
|
||||||
|
|
||||||
|
const KEY_PREFIX_ROOMS: &str = "socket.io:rooms";
|
||||||
|
const KEY_PREFIX_SOCKET_ROOMS: &str = "socket.io:socket_rooms";
|
||||||
|
|
||||||
|
fn room_key(ns: &str, room: &str) -> String {
|
||||||
|
format!("{}:{}:{}", KEY_PREFIX_ROOMS, ns, room)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn socket_rooms_key(ns: &str, sid: &str) -> String {
|
||||||
|
format!("{}:{}:{}", KEY_PREFIX_SOCKET_ROOMS, ns, sid)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle incoming bus messages from other servers.
|
||||||
|
/// Only performs local state updates — the remote server already wrote to Redis.
|
||||||
|
async fn handle_bus_message(
|
||||||
|
msg: BusMessage,
|
||||||
|
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||||
|
server_id: &str,
|
||||||
|
) {
|
||||||
|
match msg {
|
||||||
|
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => {
|
||||||
|
if sender_id == server_id {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if let Ok(decoded_packet) = parser::decode(&packet) {
|
||||||
|
on_local_broadcast(&decoded_packet, &opts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
BusMessage::SocketJoin { server_id: sender_id, .. }
|
||||||
|
| BusMessage::SocketLeave { server_id: sender_id, .. }
|
||||||
|
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => {
|
||||||
|
// Skip messages from this server; remote server already updated Redis
|
||||||
|
if sender_id == server_id {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// No duplicate Redis writes — the sender already persisted the state change
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RedisAdapter {
|
||||||
|
message_bus: Arc<dyn MessageBus>,
|
||||||
|
redis_client: Client,
|
||||||
|
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
|
||||||
|
socket_rooms: DashMap<String, HashSet<String>>,
|
||||||
|
rooms: DashMap<String, HashSet<String>>,
|
||||||
|
sockets: DashMap<String, Arc<Socket>>,
|
||||||
|
server_id: String,
|
||||||
|
namespace: String,
|
||||||
|
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RedisAdapter {
|
||||||
|
pub fn new(
|
||||||
|
message_bus: Arc<dyn MessageBus>,
|
||||||
|
redis_client: Client,
|
||||||
|
server_id: String,
|
||||||
|
namespace: String,
|
||||||
|
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
message_bus,
|
||||||
|
redis_client,
|
||||||
|
server_id,
|
||||||
|
namespace,
|
||||||
|
on_local_broadcast,
|
||||||
|
room_subscribers: DashMap::new(),
|
||||||
|
socket_rooms: DashMap::new(),
|
||||||
|
rooms: DashMap::new(),
|
||||||
|
sockets: DashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn init(&self) -> Result<(), AdapterError> {
|
||||||
|
let channels = ["broadcast", "join", "leave", "disconnect"];
|
||||||
|
let prefix = format!("socket.io:{}:", self.namespace);
|
||||||
|
|
||||||
|
for channel_type in channels {
|
||||||
|
let channel = format!("{}{}", prefix, channel_type);
|
||||||
|
match self.message_bus.subscribe(&channel).await {
|
||||||
|
Ok(rx) => {
|
||||||
|
self.room_subscribers.insert(channel_type.to_string(), rx);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(AdapterError::MessageBus(e.to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.spawn_listener();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_listener(&self) {
|
||||||
|
let server_id = self.server_id.clone();
|
||||||
|
let on_local_broadcast = self.on_local_broadcast.clone();
|
||||||
|
|
||||||
|
let mut broadcast_rx = self.room_subscribers.remove("broadcast").map(|(_, rx)| rx);
|
||||||
|
let mut join_rx = self.room_subscribers.remove("join").map(|(_, rx)| rx);
|
||||||
|
let mut leave_rx = self.room_subscribers.remove("leave").map(|(_, rx)| rx);
|
||||||
|
let mut disconnect_rx = self.room_subscribers.remove("disconnect").map(|(_, rx)| rx);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
Some(data) = async { broadcast_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(data) = async { join_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(data) = async { leave_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(data) = async { disconnect_rx.as_mut()?.recv().await } => {
|
||||||
|
if let Ok(msg) = serde_json::from_slice::<BusMessage>(&data) {
|
||||||
|
handle_bus_message(msg, &on_local_broadcast, &server_id).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Adapter for RedisAdapter {
|
||||||
|
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> {
|
||||||
|
if opts.flags.local_only {
|
||||||
|
(self.on_local_broadcast)(packet, opts);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = BusMessage::Broadcast {
|
||||||
|
namespace: packet.namespace.clone(),
|
||||||
|
packet: parser::encode(packet),
|
||||||
|
opts: opts.clone(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:broadcast", packet.namespace), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
(self.on_local_broadcast)(packet, opts);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
let rk = room_key(ns, room);
|
||||||
|
let srk = socket_rooms_key(ns, sid);
|
||||||
|
|
||||||
|
self.redis_client
|
||||||
|
.sadd::<(), _, _>(&rk, sid)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.redis_client
|
||||||
|
.sadd::<(), _, _>(&srk, room)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.socket_rooms
|
||||||
|
.entry(sid.to_string())
|
||||||
|
.and_modify(|set| { set.insert(room.to_string()); })
|
||||||
|
.or_insert_with(|| HashSet::from([room.to_string()]));
|
||||||
|
|
||||||
|
self.rooms
|
||||||
|
.entry(room.to_string())
|
||||||
|
.and_modify(|set| { set.insert(sid.to_string()); })
|
||||||
|
.or_insert_with(|| HashSet::from([sid.to_string()]));
|
||||||
|
|
||||||
|
let msg = BusMessage::SocketJoin {
|
||||||
|
namespace: ns.to_string(),
|
||||||
|
sid: sid.to_string(),
|
||||||
|
room: room.to_string(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:join", ns), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
let rk = room_key(ns, room);
|
||||||
|
let srk = socket_rooms_key(ns, sid);
|
||||||
|
|
||||||
|
self.redis_client
|
||||||
|
.srem::<(), _, _>(&rk, sid)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.redis_client
|
||||||
|
.srem::<(), _, _>(&srk, room)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
|
||||||
|
entry.value_mut().remove(room);
|
||||||
|
}
|
||||||
|
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||||
|
self.socket_rooms.remove(sid);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||||
|
entry.value_mut().remove(sid);
|
||||||
|
}
|
||||||
|
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||||
|
self.rooms.remove(room);
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = BusMessage::SocketLeave {
|
||||||
|
namespace: ns.to_string(),
|
||||||
|
sid: sid.to_string(),
|
||||||
|
room: room.to_string(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:leave", ns), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError> {
|
||||||
|
if let Some((_, rooms)) = self.socket_rooms.remove(sid) {
|
||||||
|
for room in &rooms {
|
||||||
|
if let Some(mut entry) = self.rooms.get_mut(room) {
|
||||||
|
entry.value_mut().remove(sid);
|
||||||
|
}
|
||||||
|
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) {
|
||||||
|
self.rooms.remove(room);
|
||||||
|
}
|
||||||
|
|
||||||
|
let rk = room_key(ns, room);
|
||||||
|
if let Err(e) = self.redis_client.srem::<(), _, _>(&rk, sid).await {
|
||||||
|
tracing::warn!("Redis SREM room error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let srk = socket_rooms_key(ns, sid);
|
||||||
|
self.redis_client
|
||||||
|
.del::<(), _>(&srk)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.sockets.remove(sid);
|
||||||
|
|
||||||
|
let msg = BusMessage::SocketDisconnect {
|
||||||
|
namespace: ns.to_string(),
|
||||||
|
sid: sid.to_string(),
|
||||||
|
server_id: self.server_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let payload = serde_json::to_vec(&msg)
|
||||||
|
.map_err(|e| AdapterError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
self.message_bus
|
||||||
|
.publish(&format!("socket.io:{}:disconnect", ns), &payload)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
|
||||||
|
let target_sids: HashSet<String> = if opts.rooms.is_empty() {
|
||||||
|
self.sockets.iter().map(|e| e.key().clone()).collect()
|
||||||
|
} else {
|
||||||
|
let mut sids = HashSet::new();
|
||||||
|
for room in &opts.rooms {
|
||||||
|
if let Some(entry) = self.rooms.get(room) {
|
||||||
|
sids.extend(entry.value().iter().cloned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sids
|
||||||
|
};
|
||||||
|
|
||||||
|
for sid in target_sids {
|
||||||
|
if opts.except.contains(&sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default();
|
||||||
|
result.push(SocketInfo {
|
||||||
|
sid: sid.clone(),
|
||||||
|
namespace: self.namespace.clone(),
|
||||||
|
rooms,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
|
||||||
|
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn server_id(&self) -> &str {
|
||||||
|
&self.server_id
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&self) -> Result<(), AdapterError> {
|
||||||
|
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
pub mod redis;
|
||||||
|
pub mod nats;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum MessageBusError {
|
||||||
|
#[error("Redis error: {0}")]
|
||||||
|
Redis(String),
|
||||||
|
#[error("NATS error: {0}")]
|
||||||
|
Nats(String),
|
||||||
|
#[error("Connection closed")]
|
||||||
|
ConnectionClosed,
|
||||||
|
#[error("Channel not found: {0}")]
|
||||||
|
ChannelNotFound(String),
|
||||||
|
#[error("Serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait MessageBus: Send + Sync + 'static {
|
||||||
|
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError>;
|
||||||
|
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError>;
|
||||||
|
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError>;
|
||||||
|
async fn close(&self) -> Result<(), MessageBusError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub use redis::RedisMessageBus;
|
||||||
|
pub use nats::NatsMessageBus;
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tokio::sync::{mpsc, watch};
|
||||||
|
|
||||||
|
use crate::socket::message_bus::{MessageBus, MessageBusError};
|
||||||
|
|
||||||
|
pub struct NatsMessageBus {
|
||||||
|
client: async_nats::Client,
|
||||||
|
shutdowns: DashMap<String, watch::Sender<bool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NatsMessageBus {
|
||||||
|
pub async fn new(nats_url: &str) -> Result<Self, MessageBusError> {
|
||||||
|
let client = async_nats::connect(nats_url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
|
||||||
|
Ok(Self {
|
||||||
|
client,
|
||||||
|
shutdowns: DashMap::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl MessageBus for NatsMessageBus {
|
||||||
|
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
|
||||||
|
self.client
|
||||||
|
.publish(channel.to_string(), message.to_vec().into())
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
|
||||||
|
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
|
||||||
|
|
||||||
|
let mut subscriber = self.client
|
||||||
|
.subscribe(channel.to_string())
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Nats(e.to_string()))?;
|
||||||
|
|
||||||
|
let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
|
||||||
|
self.shutdowns.insert(channel.to_string(), shutdown_tx);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = shutdown_rx.changed() => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
message = subscriber.next() => {
|
||||||
|
match message {
|
||||||
|
Some(msg) => {
|
||||||
|
let data = msg.payload.to_vec();
|
||||||
|
if tx.send(data).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Err(e) = subscriber.unsubscribe().await {
|
||||||
|
tracing::warn!("NATS unsubscribe error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
|
||||||
|
if let Some((_, tx)) = self.shutdowns.remove(channel) {
|
||||||
|
let _ = tx.send(true);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&self) -> Result<(), MessageBusError> {
|
||||||
|
// Signal all subscribers to shutdown
|
||||||
|
self.shutdowns.iter().for_each(|entry| {
|
||||||
|
let _ = entry.value().send(true);
|
||||||
|
});
|
||||||
|
self.shutdowns.clear();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use fred::clients::{Client, SubscriberClient};
|
||||||
|
use fred::interfaces::{ClientLike, EventInterface, PubsubInterface};
|
||||||
|
use fred::prelude::*;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use crate::socket::message_bus::{MessageBus, MessageBusError};
|
||||||
|
|
||||||
|
pub struct RedisMessageBus {
|
||||||
|
client: Client,
|
||||||
|
subscriber: SubscriberClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RedisMessageBus {
|
||||||
|
pub async fn new(redis_url: &str) -> Result<Self, MessageBusError> {
|
||||||
|
let config = Config::from_url(redis_url)
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
let client = Client::new(config.clone(), None, None, None);
|
||||||
|
let subscriber = SubscriberClient::new(config, None, None, None);
|
||||||
|
|
||||||
|
// connect() starts the connection task; result is checked by wait_for_connect()
|
||||||
|
let _ = client.connect().await;
|
||||||
|
let _ = subscriber.connect().await;
|
||||||
|
|
||||||
|
client
|
||||||
|
.wait_for_connect()
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
subscriber
|
||||||
|
.wait_for_connect()
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(Self { client, subscriber })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn client(&self) -> &Client {
|
||||||
|
&self.client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl MessageBus for RedisMessageBus {
|
||||||
|
async fn publish(&self, channel: &str, message: &[u8]) -> Result<(), MessageBusError> {
|
||||||
|
self.client
|
||||||
|
.publish::<(), _, Vec<u8>>(channel, message.to_vec())
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
|
||||||
|
let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
|
||||||
|
|
||||||
|
self.subscriber
|
||||||
|
.subscribe(channel.to_string())
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
let subscriber = self.subscriber.clone();
|
||||||
|
let channel_owned = channel.to_string();
|
||||||
|
let mut message_rx = subscriber.message_rx();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Ok(message) = message_rx.recv().await {
|
||||||
|
if &message.channel == &channel_owned {
|
||||||
|
let data: Vec<u8> = FromValue::from_value(message.value)
|
||||||
|
.unwrap_or_default();
|
||||||
|
if tx.send(data).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unsubscribe(&self, channel: &str) -> Result<(), MessageBusError> {
|
||||||
|
self.subscriber
|
||||||
|
.unsubscribe(channel.to_string())
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&self) -> Result<(), MessageBusError> {
|
||||||
|
self.client
|
||||||
|
.quit()
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
self.subscriber
|
||||||
|
.quit()
|
||||||
|
.await
|
||||||
|
.map_err(|e| MessageBusError::Redis(e.to_string()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
pub mod adapter;
|
||||||
|
pub mod message_bus;
|
||||||
|
pub mod namespace;
|
||||||
|
pub mod packet;
|
||||||
|
pub mod parser;
|
||||||
|
pub mod server;
|
||||||
|
pub mod session_store;
|
||||||
|
pub mod socket;
|
||||||
|
|
||||||
|
pub use adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, RedisAdapter, NatsAdapter, SocketInfo};
|
||||||
|
pub use message_bus::{MessageBus, MessageBusError, RedisMessageBus, NatsMessageBus};
|
||||||
|
pub use namespace::{is_valid_namespace, Namespace, NamespaceManager};
|
||||||
|
pub use packet::{Packet, PacketType};
|
||||||
|
pub use server::{SocketServer, SocketServerBuilder};
|
||||||
|
pub use session_store::{InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait};
|
||||||
|
pub use socket::Socket;
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use crate::socket::adapter::{Adapter, BroadcastOptions, BroadcastFlags};
|
||||||
|
use crate::socket::packet::Packet;
|
||||||
|
use crate::socket::socket::Socket;
|
||||||
|
|
||||||
|
pub type EventHandler = Arc<dyn Fn(&Socket, &serde_json::Value) + Send + Sync>;
|
||||||
|
type ConnectHandler = Arc<dyn Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync>;
|
||||||
|
|
||||||
|
pub struct Namespace {
|
||||||
|
pub path: String,
|
||||||
|
/// Primary storage: socket_sid → Socket
|
||||||
|
sockets: DashMap<String, Arc<Socket>>,
|
||||||
|
/// Reverse index: engine_sid → socket_sid (for engine-level lookups)
|
||||||
|
engine_to_socket: DashMap<String, String>,
|
||||||
|
handlers: RwLock<HashMap<String, Vec<EventHandler>>>,
|
||||||
|
connect_handler: RwLock<Option<ConnectHandler>>,
|
||||||
|
pub(crate) adapter: RwLock<Option<Arc<dyn Adapter>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Namespace {
|
||||||
|
pub fn new(path: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
path: path.into(),
|
||||||
|
sockets: DashMap::new(),
|
||||||
|
engine_to_socket: DashMap::new(),
|
||||||
|
handlers: RwLock::new(HashMap::new()),
|
||||||
|
connect_handler: RwLock::new(None),
|
||||||
|
adapter: RwLock::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set_adapter(&self, adapter: Arc<dyn Adapter>) {
|
||||||
|
let mut guard = self.adapter.write().await;
|
||||||
|
*guard = Some(adapter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a socket to this namespace. Returns Err if the connect handler rejects.
|
||||||
|
pub async fn add_socket(&self, socket: Arc<Socket>) -> Result<(), String> {
|
||||||
|
// Run connect handler before adding to storage
|
||||||
|
let handler = self.connect_handler.read().await;
|
||||||
|
if let Some(ref h) = *handler {
|
||||||
|
h(&socket, None)?;
|
||||||
|
}
|
||||||
|
drop(handler);
|
||||||
|
|
||||||
|
let socket_sid = socket.sid.clone();
|
||||||
|
let engine_sid = socket.engine_sid.clone();
|
||||||
|
|
||||||
|
// Register with adapter (socket_sid → engine_sid mapping)
|
||||||
|
let adapter = self.adapter.read().await;
|
||||||
|
if let Some(ref adapter) = *adapter {
|
||||||
|
if let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await {
|
||||||
|
tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store socket by socket_sid, plus reverse index
|
||||||
|
self.sockets.insert(socket_sid.clone(), socket);
|
||||||
|
self.engine_to_socket.insert(engine_sid, socket_sid);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a socket by its socket SID.
|
||||||
|
pub async fn remove_socket_by_sid(&self, socket_sid: &str) {
|
||||||
|
if let Some((_, socket)) = self.sockets.remove(socket_sid) {
|
||||||
|
self.engine_to_socket.remove(&socket.engine_sid);
|
||||||
|
|
||||||
|
let adapter = self.adapter.read().await;
|
||||||
|
if let Some(ref adapter) = *adapter {
|
||||||
|
if let Err(e) = adapter.del_all(socket_sid, &self.path).await {
|
||||||
|
tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a socket by its engine SID (for engine-level disconnections).
|
||||||
|
pub async fn remove_socket(&self, engine_sid: &str) {
|
||||||
|
if let Some((_, socket_sid)) = self.engine_to_socket.remove(engine_sid) {
|
||||||
|
self.remove_socket_by_sid(&socket_sid).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a socket by its socket SID.
|
||||||
|
pub fn get_socket(&self, socket_sid: &str) -> Option<Arc<Socket>> {
|
||||||
|
self.sockets.get(socket_sid).map(|r| r.value().clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a socket by its engine SID (reverse lookup).
|
||||||
|
pub fn get_socket_by_engine_sid(&self, engine_sid: &str) -> Option<Arc<Socket>> {
|
||||||
|
self.engine_to_socket
|
||||||
|
.get(engine_sid)
|
||||||
|
.and_then(|entry| self.sockets.get(entry.value()).map(|r| r.value().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn socket_count(&self) -> usize {
|
||||||
|
self.sockets.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn on_event(&self, event: impl Into<String>, handler: EventHandler) {
|
||||||
|
let mut handlers = self.handlers.write().await;
|
||||||
|
handlers.entry(event.into()).or_default().push(handler);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn on_connect<F>(&self, handler: F)
|
||||||
|
where
|
||||||
|
F: Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let mut connect_handler = self.connect_handler.write().await;
|
||||||
|
*connect_handler = Some(Arc::new(handler));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn emit(&self, event: impl Into<String>, data: serde_json::Value) {
|
||||||
|
let event_name = event.into();
|
||||||
|
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
|
||||||
|
|
||||||
|
let adapter = self.adapter.read().await;
|
||||||
|
if let Some(ref adapter) = *adapter {
|
||||||
|
let opts = BroadcastOptions::default();
|
||||||
|
if let Err(e) = adapter.broadcast(&packet, &opts).await {
|
||||||
|
tracing::warn!("Adapter broadcast error: {}", e);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.emit_local(&packet);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn emit_to_room(&self, room: &str, event: impl Into<String>, data: serde_json::Value) {
|
||||||
|
let event_name = event.into();
|
||||||
|
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
|
||||||
|
|
||||||
|
let adapter = self.adapter.read().await;
|
||||||
|
if let Some(ref adapter) = *adapter {
|
||||||
|
let opts = BroadcastOptions {
|
||||||
|
rooms: HashSet::from([room.to_string()]),
|
||||||
|
except: HashSet::new(),
|
||||||
|
flags: BroadcastFlags::default(),
|
||||||
|
};
|
||||||
|
if let Err(e) = adapter.broadcast(&packet, &opts).await {
|
||||||
|
tracing::warn!("Adapter broadcast to room error: {}", e);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.emit_local(&packet);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn emit_local(&self, packet: &Packet) {
|
||||||
|
for entry in self.sockets.iter() {
|
||||||
|
let socket = entry.value();
|
||||||
|
if socket.send_packet(packet).is_err() {
|
||||||
|
tracing::warn!("Failed to send event to socket {}", socket.sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn emit_to(&self, socket_sid: &str, event: impl Into<String>, data: serde_json::Value) {
|
||||||
|
if let Some(socket) = self.get_socket(socket_sid) {
|
||||||
|
let event_name = event.into();
|
||||||
|
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
|
||||||
|
if socket.send_packet(&packet).is_err() {
|
||||||
|
tracing::warn!("Failed to send event to socket {}", socket.sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_event(&self, socket: &Socket, event: &str, data: &serde_json::Value) {
|
||||||
|
let handlers = self.handlers.read().await;
|
||||||
|
if let Some(event_handlers) = handlers.get(event) {
|
||||||
|
for handler in event_handlers {
|
||||||
|
handler(socket, data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct NamespaceManager {
|
||||||
|
namespaces: DashMap<String, Arc<Namespace>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NamespaceManager {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let manager = Self {
|
||||||
|
namespaces: DashMap::new(),
|
||||||
|
};
|
||||||
|
manager.create_namespace("/");
|
||||||
|
manager
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_namespace(&self, path: impl Into<String>) -> Arc<Namespace> {
|
||||||
|
let path = path.into();
|
||||||
|
let namespace = Arc::new(Namespace::new(&path));
|
||||||
|
self.namespaces.insert(path.clone(), namespace.clone());
|
||||||
|
namespace
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_namespace(&self, path: &str) -> Option<Arc<Namespace>> {
|
||||||
|
self.namespaces.get(path).map(|r| r.value().clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_or_create_namespace(&self, path: &str) -> Arc<Namespace> {
|
||||||
|
if let Some(ns) = self.get_namespace(path) {
|
||||||
|
ns
|
||||||
|
} else {
|
||||||
|
self.create_namespace(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_namespace(&self, path: &str) {
|
||||||
|
self.namespaces.remove(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn namespace_count(&self) -> usize {
|
||||||
|
self.namespaces.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn all_namespaces(&self) -> Vec<Arc<Namespace>> {
|
||||||
|
self.namespaces.iter().map(|e| e.value().clone()).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NamespaceManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate a namespace path. Returns true if the path is valid.
|
||||||
|
/// Rules: must start with '/', max 256 chars, no control characters.
|
||||||
|
pub fn is_valid_namespace(path: &str) -> bool {
|
||||||
|
!path.is_empty()
|
||||||
|
&& path.starts_with('/')
|
||||||
|
&& path.len() <= 256
|
||||||
|
&& !path.chars().any(|c| c.is_control())
|
||||||
|
}
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum PacketType {
|
||||||
|
Connect = 0,
|
||||||
|
Disconnect = 1,
|
||||||
|
Event = 2,
|
||||||
|
Ack = 3,
|
||||||
|
ConnectError = 4,
|
||||||
|
BinaryEvent = 5,
|
||||||
|
BinaryAck = 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<u8> for PacketType {
|
||||||
|
type Error = PacketError;
|
||||||
|
|
||||||
|
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
0 => Ok(Self::Connect),
|
||||||
|
1 => Ok(Self::Disconnect),
|
||||||
|
2 => Ok(Self::Event),
|
||||||
|
3 => Ok(Self::Ack),
|
||||||
|
4 => Ok(Self::ConnectError),
|
||||||
|
5 => Ok(Self::BinaryEvent),
|
||||||
|
6 => Ok(Self::BinaryAck),
|
||||||
|
_ => Err(PacketError::InvalidType(value)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<char> for PacketType {
|
||||||
|
type Error = PacketError;
|
||||||
|
|
||||||
|
fn try_from(value: char) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
'0' => Ok(Self::Connect),
|
||||||
|
'1' => Ok(Self::Disconnect),
|
||||||
|
'2' => Ok(Self::Event),
|
||||||
|
'3' => Ok(Self::Ack),
|
||||||
|
'4' => Ok(Self::ConnectError),
|
||||||
|
'5' => Ok(Self::BinaryEvent),
|
||||||
|
'6' => Ok(Self::BinaryAck),
|
||||||
|
_ => Err(PacketError::InvalidTypeChar(value)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Packet {
|
||||||
|
pub packet_type: PacketType,
|
||||||
|
pub namespace: String,
|
||||||
|
pub data: Option<Value>,
|
||||||
|
pub id: Option<u64>,
|
||||||
|
pub attachments: Vec<Vec<u8>>,
|
||||||
|
/// Expected number of binary attachments (set during decode for binary packets).
|
||||||
|
/// Used to validate attachment count before assembling the full packet.
|
||||||
|
pub expected_attachments: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Packet {
|
||||||
|
pub fn connect(namespace: impl Into<String>, data: Option<Value>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Connect,
|
||||||
|
namespace: namespace.into(),
|
||||||
|
data,
|
||||||
|
id: None,
|
||||||
|
attachments: Vec::new(),
|
||||||
|
expected_attachments: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn disconnect(namespace: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Disconnect,
|
||||||
|
namespace: namespace.into(),
|
||||||
|
data: None,
|
||||||
|
id: None,
|
||||||
|
attachments: Vec::new(),
|
||||||
|
expected_attachments: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn event(namespace: impl Into<String>, data: Value, id: Option<u64>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Event,
|
||||||
|
namespace: namespace.into(),
|
||||||
|
data: Some(data),
|
||||||
|
id,
|
||||||
|
attachments: Vec::new(),
|
||||||
|
expected_attachments: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ack(namespace: impl Into<String>, data: Value, id: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::Ack,
|
||||||
|
namespace: namespace.into(),
|
||||||
|
data: Some(data),
|
||||||
|
id: Some(id),
|
||||||
|
attachments: Vec::new(),
|
||||||
|
expected_attachments: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connect_error(namespace: impl Into<String>, message: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::ConnectError,
|
||||||
|
namespace: namespace.into(),
|
||||||
|
data: Some(serde_json::json!({ "message": message.into() })),
|
||||||
|
id: None,
|
||||||
|
attachments: Vec::new(),
|
||||||
|
expected_attachments: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binary_event(
|
||||||
|
namespace: impl Into<String>,
|
||||||
|
data: Value,
|
||||||
|
id: Option<u64>,
|
||||||
|
attachments: Vec<Vec<u8>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::BinaryEvent,
|
||||||
|
namespace: namespace.into(),
|
||||||
|
data: Some(data),
|
||||||
|
id,
|
||||||
|
attachments,
|
||||||
|
expected_attachments: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binary_ack(
|
||||||
|
namespace: impl Into<String>,
|
||||||
|
data: Value,
|
||||||
|
id: u64,
|
||||||
|
attachments: Vec<Vec<u8>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
packet_type: PacketType::BinaryAck,
|
||||||
|
namespace: namespace.into(),
|
||||||
|
data: Some(data),
|
||||||
|
id: Some(id),
|
||||||
|
attachments,
|
||||||
|
expected_attachments: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_binary(&self) -> bool {
|
||||||
|
!self.attachments.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn attachment_count(&self) -> usize {
|
||||||
|
self.attachments.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum PacketError {
|
||||||
|
#[error("invalid packet type: {0}")]
|
||||||
|
InvalidType(u8),
|
||||||
|
#[error("invalid packet type char: {0}")]
|
||||||
|
InvalidTypeChar(char),
|
||||||
|
#[error("empty packet")]
|
||||||
|
Empty,
|
||||||
|
#[error("invalid format: {0}")]
|
||||||
|
InvalidFormat(String),
|
||||||
|
#[error("json error: {0}")]
|
||||||
|
Json(#[from] serde_json::Error),
|
||||||
|
#[error("missing namespace")]
|
||||||
|
MissingNamespace,
|
||||||
|
#[error("invalid attachment count")]
|
||||||
|
InvalidAttachmentCount,
|
||||||
|
}
|
||||||
@@ -0,0 +1,392 @@
|
|||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::socket::packet::{Packet, PacketError, PacketType};
|
||||||
|
|
||||||
|
pub fn encode(packet: &Packet) -> String {
|
||||||
|
let type_char = packet.packet_type as u8 + b'0';
|
||||||
|
let mut result = String::new();
|
||||||
|
|
||||||
|
result.push(type_char as char);
|
||||||
|
|
||||||
|
if packet.has_binary() {
|
||||||
|
result.push_str(&packet.attachment_count().to_string());
|
||||||
|
result.push('-');
|
||||||
|
}
|
||||||
|
|
||||||
|
if packet.namespace != "/" {
|
||||||
|
result.push_str(&packet.namespace);
|
||||||
|
result.push(',');
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(id) = packet.id {
|
||||||
|
result.push_str(&id.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref data) = packet.data {
|
||||||
|
if packet.has_binary() {
|
||||||
|
let data_with_placeholders = replace_binary_with_placeholders(data, packet.attachment_count());
|
||||||
|
let encoded_data = serde_json::to_string(&data_with_placeholders)
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
tracing::error!("Failed to serialize socket packet data: {}", e);
|
||||||
|
"null".to_string()
|
||||||
|
});
|
||||||
|
result.push_str(&encoded_data);
|
||||||
|
} else {
|
||||||
|
let encoded_data = serde_json::to_string(data)
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
tracing::error!("Failed to serialize socket packet data: {}", e);
|
||||||
|
"null".to_string()
|
||||||
|
});
|
||||||
|
result.push_str(&encoded_data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode_with_attachments(packet: &Packet) -> Vec<Vec<u8>> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
|
||||||
|
let encoded = encode(packet);
|
||||||
|
result.push(encoded.into_bytes());
|
||||||
|
|
||||||
|
for attachment in &packet.attachments {
|
||||||
|
result.push(attachment.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode(input: &str) -> Result<Packet, PacketError> {
|
||||||
|
if input.is_empty() {
|
||||||
|
return Err(PacketError::Empty);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut chars = input.chars().peekable();
|
||||||
|
|
||||||
|
let type_char = chars.next().ok_or(PacketError::Empty)?;
|
||||||
|
let packet_type = PacketType::try_from(type_char)?;
|
||||||
|
|
||||||
|
let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck) {
|
||||||
|
let mut count_str = String::new();
|
||||||
|
while let Some(&c) = chars.peek() {
|
||||||
|
if c == '-' {
|
||||||
|
chars.next();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if c.is_ascii_digit() {
|
||||||
|
count_str.push(c);
|
||||||
|
chars.next();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
count_str.parse::<usize>().unwrap_or(0)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
let remaining: String = chars.collect();
|
||||||
|
|
||||||
|
let (namespace, rest) = if let Some(after_slash) = remaining.strip_prefix('/') {
|
||||||
|
// Check if this is a custom namespace (has a comma separating namespace from data/id)
|
||||||
|
// or if '/' is just the root namespace prefix followed immediately by data
|
||||||
|
if let Some(comma_pos) = after_slash.find(',') {
|
||||||
|
let ns = format!("/{}", &after_slash[..comma_pos]);
|
||||||
|
let rest = after_slash[comma_pos + 1..].to_string();
|
||||||
|
(ns, rest)
|
||||||
|
} else if after_slash.starts_with('[')
|
||||||
|
|| after_slash.starts_with(|c: char| c.is_ascii_digit())
|
||||||
|
|| after_slash.is_empty()
|
||||||
|
{
|
||||||
|
// '/[' means '/' is the root namespace and '[' starts the data
|
||||||
|
// '/<digits>' means root namespace followed by ack id
|
||||||
|
// '/' alone means disconnect on root namespace
|
||||||
|
("/".to_string(), after_slash.to_string())
|
||||||
|
} else {
|
||||||
|
// Non-root namespace without data (e.g., disconnect on custom namespace)
|
||||||
|
(remaining, String::new())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
("/".to_string(), remaining)
|
||||||
|
};
|
||||||
|
|
||||||
|
let (id, data_str) = parse_id_and_data(&rest);
|
||||||
|
|
||||||
|
let data = if data_str.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(serde_json::from_str(&data_str)?)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Packet {
|
||||||
|
packet_type,
|
||||||
|
namespace,
|
||||||
|
data,
|
||||||
|
id,
|
||||||
|
attachments: Vec::new(),
|
||||||
|
// Store attachment_count for binary packets; actual attachments come via decode_with_attachments
|
||||||
|
expected_attachments: if attachment_count > 0 { Some(attachment_count) } else { None },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode_with_attachments(
|
||||||
|
main_packet: &str,
|
||||||
|
attachments: Vec<Vec<u8>>,
|
||||||
|
) -> Result<Packet, PacketError> {
|
||||||
|
let mut packet = decode(main_packet)?;
|
||||||
|
|
||||||
|
let expected = packet.expected_attachments.unwrap_or(0);
|
||||||
|
if expected != attachments.len() {
|
||||||
|
return Err(PacketError::InvalidAttachmentCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
packet.attachments = attachments;
|
||||||
|
packet.expected_attachments = None;
|
||||||
|
|
||||||
|
if packet.has_binary() {
|
||||||
|
if let Some(ref data) = packet.data {
|
||||||
|
packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_id_and_data(input: &str) -> (Option<u64>, String) {
|
||||||
|
let mut id_str = String::new();
|
||||||
|
let mut chars = input.chars().peekable();
|
||||||
|
|
||||||
|
while let Some(&c) = chars.peek() {
|
||||||
|
if c.is_ascii_digit() {
|
||||||
|
id_str.push(c);
|
||||||
|
chars.next();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = if id_str.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
id_str.parse::<u64>().ok()
|
||||||
|
};
|
||||||
|
|
||||||
|
let data: String = chars.collect();
|
||||||
|
|
||||||
|
(id, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Replace binary values in the data with { "_placeholder": true, "num": N } placeholders.
|
||||||
|
/// This is used when encoding binary events/acks for transmission over text-based transports.
|
||||||
|
fn replace_binary_with_placeholders(value: &Value, total_attachments: usize) -> Value {
|
||||||
|
match value {
|
||||||
|
Value::Array(arr) => {
|
||||||
|
let mut placeholder_idx = total_attachments; // Start from known count
|
||||||
|
let new_arr: Vec<Value> = arr
|
||||||
|
.iter()
|
||||||
|
.map(|v| replace_binary_with_placeholders_inner(v, &mut placeholder_idx))
|
||||||
|
.collect();
|
||||||
|
Value::Array(new_arr)
|
||||||
|
}
|
||||||
|
Value::Object(map) => {
|
||||||
|
let mut placeholder_idx = total_attachments;
|
||||||
|
let mut new_map = serde_json::Map::new();
|
||||||
|
for (k, v) in map {
|
||||||
|
new_map.insert(
|
||||||
|
k.clone(),
|
||||||
|
replace_binary_with_placeholders_inner(v, &mut placeholder_idx),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Value::Object(new_map)
|
||||||
|
}
|
||||||
|
_ => value.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut usize) -> Value {
|
||||||
|
match value {
|
||||||
|
Value::Array(arr) => {
|
||||||
|
let new_arr: Vec<Value> = arr
|
||||||
|
.iter()
|
||||||
|
.map(|v| replace_binary_with_placeholders_inner(v, placeholder_idx))
|
||||||
|
.collect();
|
||||||
|
Value::Array(new_arr)
|
||||||
|
}
|
||||||
|
Value::Object(map) => {
|
||||||
|
let mut new_map = serde_json::Map::new();
|
||||||
|
for (k, v) in map {
|
||||||
|
new_map.insert(
|
||||||
|
k.clone(),
|
||||||
|
replace_binary_with_placeholders_inner(v, placeholder_idx),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Value::Object(new_map)
|
||||||
|
}
|
||||||
|
// Binary data would be represented as base64 strings in the initial data;
|
||||||
|
// in the Socket.IO protocol, binary attachments are separate and referenced by placeholder.
|
||||||
|
// This function handles the case where the data structure itself contains placeholder markers.
|
||||||
|
_ => value.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replace_placeholders_with_binary(value: &Value, attachments: &[Vec<u8>]) -> Value {
|
||||||
|
match value {
|
||||||
|
Value::Object(map) => {
|
||||||
|
// Check if this is a placeholder object: { "_placeholder": true, "num": N }
|
||||||
|
if let (Some(Value::Bool(true)), Some(Value::Number(num))) =
|
||||||
|
(map.get("_placeholder"), map.get("num"))
|
||||||
|
{
|
||||||
|
if let Some(idx) = num.as_u64() {
|
||||||
|
if let Some(attachment) = attachments.get(idx as usize) {
|
||||||
|
return Value::String(base64::Engine::encode(
|
||||||
|
&base64::engine::general_purpose::STANDARD,
|
||||||
|
attachment,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut new_map = serde_json::Map::new();
|
||||||
|
for (k, v) in map {
|
||||||
|
new_map.insert(k.clone(), replace_placeholders_with_binary(v, attachments));
|
||||||
|
}
|
||||||
|
Value::Object(new_map)
|
||||||
|
}
|
||||||
|
Value::Array(arr) => Value::Array(
|
||||||
|
arr.iter()
|
||||||
|
.map(|v| replace_placeholders_with_binary(v, attachments))
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
_ => value.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_connect() {
|
||||||
|
let packet = Packet::connect("/", None);
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "0");
|
||||||
|
|
||||||
|
let packet = Packet::connect("/admin", Some(json!({"sid": "abc"})));
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "0/admin,{\"sid\":\"abc\"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_event() {
|
||||||
|
let packet = Packet::event("/", json!(["foo"]), None);
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "2[\"foo\"]");
|
||||||
|
|
||||||
|
let packet = Packet::event("/admin", json!(["bar"]), None);
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "2/admin,[\"bar\"]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_event_with_ack() {
|
||||||
|
let packet = Packet::event("/", json!(["foo"]), Some(12));
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "212[\"foo\"]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_ack() {
|
||||||
|
let packet = Packet::ack("/", json!([]), 12);
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "312[]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_disconnect() {
|
||||||
|
let packet = Packet::disconnect("/");
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "1");
|
||||||
|
|
||||||
|
let packet = Packet::disconnect("/admin");
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "1/admin,");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_connect_error() {
|
||||||
|
let packet = Packet::connect_error("/", "Not authorized");
|
||||||
|
let encoded = encode(&packet);
|
||||||
|
assert_eq!(encoded, "4{\"message\":\"Not authorized\"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_connect() {
|
||||||
|
let packet = decode("0").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Connect);
|
||||||
|
assert_eq!(packet.namespace, "/");
|
||||||
|
assert!(packet.data.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_connect_with_namespace() {
|
||||||
|
let packet = decode("0/admin,{\"sid\":\"abc\"}").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Connect);
|
||||||
|
assert_eq!(packet.namespace, "/admin");
|
||||||
|
assert!(packet.data.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_event() {
|
||||||
|
let packet = decode("2[\"foo\"]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Event);
|
||||||
|
assert_eq!(packet.namespace, "/");
|
||||||
|
assert_eq!(packet.data, Some(json!(["foo"])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_event_with_namespace() {
|
||||||
|
let packet = decode("2/admin,[\"bar\"]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Event);
|
||||||
|
assert_eq!(packet.namespace, "/admin");
|
||||||
|
assert_eq!(packet.data, Some(json!(["bar"])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_event_with_ack() {
|
||||||
|
let packet = decode("212[\"foo\"]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Event);
|
||||||
|
assert_eq!(packet.id, Some(12));
|
||||||
|
assert_eq!(packet.data, Some(json!(["foo"])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_ack() {
|
||||||
|
let packet = decode("312[]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Ack);
|
||||||
|
assert_eq!(packet.id, Some(12));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_disconnect() {
|
||||||
|
let packet = decode("1").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Disconnect);
|
||||||
|
assert_eq!(packet.namespace, "/");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_disconnect_with_namespace() {
|
||||||
|
let packet = decode("1/admin,").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Disconnect);
|
||||||
|
assert_eq!(packet.namespace, "/admin");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_binary_event_attachment_count() {
|
||||||
|
let packet = decode("51-[\"baz\",{\"_placeholder\":true,\"num\":0}]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::BinaryEvent);
|
||||||
|
assert_eq!(packet.expected_attachments, Some(1));
|
||||||
|
assert_eq!(packet.namespace, "/");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,301 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use crate::engine::packet::Packet as EnginePacket;
|
||||||
|
use crate::engine::packet::PacketData as EnginePacketData;
|
||||||
|
use crate::engine::server::{EngineConfig, EngineServer};
|
||||||
|
use crate::engine::session::SessionStore;
|
||||||
|
use crate::socket::adapter::{Adapter, LocalAdapter};
|
||||||
|
use crate::socket::namespace::NamespaceManager;
|
||||||
|
use crate::socket::packet::{Packet, PacketType};
|
||||||
|
use crate::socket::parser;
|
||||||
|
use crate::socket::socket::Socket;
|
||||||
|
|
||||||
|
pub struct SocketServer {
|
||||||
|
pub engine: Arc<EngineServer>,
|
||||||
|
pub namespaces: Arc<NamespaceManager>,
|
||||||
|
pub adapter: Arc<dyn Adapter>,
|
||||||
|
socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SocketServer {
|
||||||
|
pub fn new(config: EngineConfig) -> Self {
|
||||||
|
SocketServerBuilder::new(config).build()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn builder(config: EngineConfig) -> SocketServerBuilder {
|
||||||
|
SocketServerBuilder::new(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn of(&self, path: impl Into<String>) -> Arc<crate::socket::namespace::Namespace> {
|
||||||
|
self.namespaces.get_or_create_namespace(&path.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_http(self: Arc<Self>, addr: &str) -> std::io::Result<()> {
|
||||||
|
self.engine.clone().run_http(addr).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register_socket(&self, sid: String, tx: mpsc::Sender<Packet>) {
|
||||||
|
self.socket_txs.insert(sid, tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unregister_socket(&self, sid: &str) {
|
||||||
|
self.socket_txs.remove(sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SocketServerBuilder {
|
||||||
|
config: EngineConfig,
|
||||||
|
adapter: Option<Arc<dyn Adapter>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SocketServerBuilder {
|
||||||
|
pub fn new(config: EngineConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
adapter: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn adapter(mut self, adapter: Arc<dyn Adapter>) -> Self {
|
||||||
|
self.adapter = Some(adapter);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> SocketServer {
|
||||||
|
let namespaces = Arc::new(NamespaceManager::new());
|
||||||
|
let socket_txs: Arc<DashMap<String, mpsc::Sender<Packet>>> = Arc::new(DashMap::new());
|
||||||
|
let engine_store = SessionStore::new();
|
||||||
|
|
||||||
|
let namespaces_clone = namespaces.clone();
|
||||||
|
let socket_txs_clone = socket_txs.clone();
|
||||||
|
let engine_store_clone = engine_store.clone();
|
||||||
|
|
||||||
|
let adapter: Arc<dyn Adapter> = self.adapter.unwrap_or_else(|| {
|
||||||
|
let ns_clone = namespaces.clone();
|
||||||
|
let send_fn = move |engine_sid: &str, packet: &Packet| {
|
||||||
|
if let Some(ns) = ns_clone.get_namespace(&packet.namespace) {
|
||||||
|
if let Some(socket) = ns.get_socket_by_engine_sid(engine_sid) {
|
||||||
|
socket.send_packet(packet).map_err(|e| e.to_string())
|
||||||
|
} else {
|
||||||
|
Err(format!(
|
||||||
|
"Socket with engine_sid {} not found in namespace {}",
|
||||||
|
engine_sid, packet.namespace
|
||||||
|
))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(format!("Namespace {} not found", packet.namespace))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Arc::new(LocalAdapter::new(send_fn))
|
||||||
|
});
|
||||||
|
|
||||||
|
let adapter_clone = adapter.clone();
|
||||||
|
let engine = Arc::new(EngineServer::with_store(
|
||||||
|
self.config,
|
||||||
|
engine_store,
|
||||||
|
move |sid, engine_packet| {
|
||||||
|
let namespaces = namespaces_clone.clone();
|
||||||
|
let socket_txs = socket_txs_clone.clone();
|
||||||
|
let engine_store = engine_store_clone.clone();
|
||||||
|
let adapter = adapter_clone.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
handle_engine_message(
|
||||||
|
sid, engine_packet, &namespaces, &socket_txs, &engine_store, &adapter,
|
||||||
|
).await;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
|
let server = SocketServer {
|
||||||
|
engine,
|
||||||
|
namespaces,
|
||||||
|
adapter,
|
||||||
|
socket_txs,
|
||||||
|
};
|
||||||
|
|
||||||
|
for ns in server.namespaces.all_namespaces() {
|
||||||
|
let adapter_ref = server.adapter.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
ns.set_adapter(adapter_ref).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
server
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_engine_message(
|
||||||
|
engine_sid: String,
|
||||||
|
engine_packet: EnginePacket,
|
||||||
|
namespaces: &Arc<NamespaceManager>,
|
||||||
|
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||||
|
engine_store: &SessionStore,
|
||||||
|
adapter: &Arc<dyn Adapter>,
|
||||||
|
) {
|
||||||
|
if let EnginePacketData::Text(ref text) = engine_packet.data {
|
||||||
|
if let Ok(socket_packet) = parser::decode(text) {
|
||||||
|
match socket_packet.packet_type {
|
||||||
|
PacketType::Connect => {
|
||||||
|
handle_connect(&engine_sid, &socket_packet, namespaces, socket_txs, engine_store, adapter).await;
|
||||||
|
}
|
||||||
|
PacketType::Disconnect => {
|
||||||
|
handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs);
|
||||||
|
}
|
||||||
|
PacketType::Event => {
|
||||||
|
handle_event(&engine_sid, &socket_packet, namespaces);
|
||||||
|
}
|
||||||
|
PacketType::Ack => {
|
||||||
|
handle_ack(&engine_sid, &socket_packet);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_connect(
|
||||||
|
engine_sid: &str,
|
||||||
|
packet: &Packet,
|
||||||
|
namespaces: &Arc<NamespaceManager>,
|
||||||
|
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||||
|
engine_store: &SessionStore,
|
||||||
|
adapter: &Arc<dyn Adapter>,
|
||||||
|
) {
|
||||||
|
// Validate namespace path to prevent DoS via arbitrary namespace creation
|
||||||
|
if !crate::socket::namespace::is_valid_namespace(&packet.namespace) {
|
||||||
|
tracing::warn!("Rejected connect with invalid namespace: {}", packet.namespace);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let namespace = namespaces.get_or_create_namespace(&packet.namespace);
|
||||||
|
|
||||||
|
// Ensure newly created namespaces get the shared adapter
|
||||||
|
{
|
||||||
|
let ns_adapter = namespace.adapter.read().await;
|
||||||
|
if ns_adapter.is_none() {
|
||||||
|
drop(ns_adapter);
|
||||||
|
let adapter_ref = adapter.clone();
|
||||||
|
let ns_clone = namespace.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
ns_clone.set_adapter(adapter_ref).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let socket_sid = crate::engine::session::generate_sid();
|
||||||
|
let (tx, mut rx) = mpsc::channel::<Packet>(256);
|
||||||
|
socket_txs.insert(socket_sid.clone(), tx.clone());
|
||||||
|
|
||||||
|
let socket = Arc::new(Socket::new(
|
||||||
|
socket_sid.clone(),
|
||||||
|
packet.namespace.clone(),
|
||||||
|
engine_sid.to_string(),
|
||||||
|
tx,
|
||||||
|
));
|
||||||
|
|
||||||
|
// Run connect handler and add to namespace.
|
||||||
|
// If the handler rejects, clean up and do NOT send a Connect response.
|
||||||
|
if let Err(msg) = namespace.add_socket(socket.clone()).await {
|
||||||
|
tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg);
|
||||||
|
socket_txs.remove(&socket_sid);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect handler passed — spawn forwarding task
|
||||||
|
let engine_store_clone = engine_store.clone();
|
||||||
|
let engine_sid_clone = engine_sid.to_string();
|
||||||
|
let socket_sid_clone = socket_sid.clone();
|
||||||
|
let socket_txs_clone = socket_txs.clone();
|
||||||
|
let namespace_clone = namespace.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(socket_packet) = rx.recv().await {
|
||||||
|
let encoded = parser::encode(&socket_packet);
|
||||||
|
let engine_packet = EnginePacket::message_text(encoded);
|
||||||
|
|
||||||
|
if let Some(session) = engine_store_clone.get(&engine_sid_clone) {
|
||||||
|
let mut s = session.write().await;
|
||||||
|
if s.state == crate::engine::session::SessionState::Closed {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
s.push_packet(engine_packet);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Forwarding task ended — ensure socket is cleaned up from namespace
|
||||||
|
socket_txs_clone.remove(&socket_sid_clone);
|
||||||
|
namespace_clone.remove_socket_by_sid(&socket_sid_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send Connect response (only after handler passed)
|
||||||
|
let response = Packet::connect(
|
||||||
|
&socket.namespace,
|
||||||
|
Some(serde_json::json!({ "sid": &socket.sid })),
|
||||||
|
);
|
||||||
|
|
||||||
|
if socket.send_packet(&response).is_err() {
|
||||||
|
tracing::warn!("Failed to send connect response to socket {}", socket.sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_disconnect(
|
||||||
|
engine_sid: &str,
|
||||||
|
packet: &Packet,
|
||||||
|
namespaces: &Arc<NamespaceManager>,
|
||||||
|
socket_txs: &Arc<DashMap<String, mpsc::Sender<Packet>>>,
|
||||||
|
) {
|
||||||
|
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
|
||||||
|
// Look up socket by engine_sid, then remove by socket_sid
|
||||||
|
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
|
||||||
|
socket_txs.remove(&socket.sid);
|
||||||
|
let socket_sid = socket.sid.clone();
|
||||||
|
let ns_clone = namespace.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
ns_clone.remove_socket_by_sid(&socket_sid).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_event(
|
||||||
|
engine_sid: &str,
|
||||||
|
packet: &Packet,
|
||||||
|
namespaces: &Arc<NamespaceManager>,
|
||||||
|
) {
|
||||||
|
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) {
|
||||||
|
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) {
|
||||||
|
if let Some(ref data) = packet.data {
|
||||||
|
if let Some(arr) = data.as_array() {
|
||||||
|
if let Some(event) = arr.first().and_then(|v| v.as_str()) {
|
||||||
|
let event_data = if arr.len() > 1 {
|
||||||
|
serde_json::Value::Array(arr[1..].to_vec())
|
||||||
|
} else {
|
||||||
|
serde_json::Value::Null
|
||||||
|
};
|
||||||
|
|
||||||
|
let namespace_clone = namespace.clone();
|
||||||
|
let event = event.to_string();
|
||||||
|
let socket_clone = socket.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
namespace_clone
|
||||||
|
.handle_event(&socket_clone, &event, &event_data)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_ack(engine_sid: &str, packet: &Packet) {
|
||||||
|
tracing::debug!(
|
||||||
|
"Received ACK from {} for namespace {} with id {:?}",
|
||||||
|
engine_sid,
|
||||||
|
packet.namespace,
|
||||||
|
packet.id
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
|
||||||
|
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
|
||||||
|
|
||||||
|
pub struct InMemorySessionStore {
|
||||||
|
sessions: Arc<DashMap<String, SessionInfo>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InMemorySessionStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sessions: Arc::new(DashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for InMemorySessionStore {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn now_millis() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as u64
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl SessionStoreTrait for InMemorySessionStore {
|
||||||
|
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
|
||||||
|
let info = SessionInfo {
|
||||||
|
sid: sid.to_string(),
|
||||||
|
transport: transport.to_string(),
|
||||||
|
state: "connecting".to_string(),
|
||||||
|
server_id: server_id.to_string(),
|
||||||
|
created_at: now_millis(),
|
||||||
|
last_ping: now_millis(),
|
||||||
|
};
|
||||||
|
self.sessions.insert(sid.to_string(), info);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
|
||||||
|
Ok(self.sessions.get(sid).map(|r| r.value().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
|
||||||
|
if let Some(mut entry) = self.sessions.get_mut(sid) {
|
||||||
|
entry.value_mut().state = state.to_string();
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(SessionError::NotFound(sid.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
|
||||||
|
if let Some(mut entry) = self.sessions.get_mut(sid) {
|
||||||
|
entry.value_mut().transport = transport.to_string();
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(SessionError::NotFound(sid.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
|
||||||
|
if let Some(mut entry) = self.sessions.get_mut(sid) {
|
||||||
|
entry.value_mut().last_ping = now_millis();
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(SessionError::NotFound(sid.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
|
||||||
|
self.sessions.remove(sid);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
|
||||||
|
Ok(self.sessions.contains_key(sid))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
pub mod memory;
|
||||||
|
pub mod redis;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum SessionError {
|
||||||
|
#[error("Redis error: {0}")]
|
||||||
|
Redis(String),
|
||||||
|
#[error("Session not found: {0}")]
|
||||||
|
NotFound(String),
|
||||||
|
#[error("Serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
#[error("Session expired: {0}")]
|
||||||
|
Expired(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
pub struct SessionInfo {
|
||||||
|
pub sid: String,
|
||||||
|
pub transport: String,
|
||||||
|
pub state: String,
|
||||||
|
pub server_id: String,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub last_ping: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait SessionStoreTrait: Send + Sync + 'static {
|
||||||
|
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError>;
|
||||||
|
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError>;
|
||||||
|
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError>;
|
||||||
|
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError>;
|
||||||
|
async fn update_ping(&self, sid: &str) -> Result<(), SessionError>;
|
||||||
|
async fn remove(&self, sid: &str) -> Result<(), SessionError>;
|
||||||
|
async fn exists(&self, sid: &str) -> Result<bool, SessionError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub use memory::InMemorySessionStore;
|
||||||
|
pub use redis::RedisSessionStore;
|
||||||
@@ -0,0 +1,164 @@
|
|||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use fred::prelude::*;
|
||||||
|
|
||||||
|
use crate::socket::message_bus::redis::RedisMessageBus;
|
||||||
|
use crate::socket::session_store::{SessionError, SessionInfo, SessionStoreTrait};
|
||||||
|
|
||||||
|
fn now_millis() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as u64
|
||||||
|
}
|
||||||
|
|
||||||
|
const DEFAULT_TTL_SECS: u64 = 60;
|
||||||
|
const KEY_PREFIX: &str = "socket.io:session";
|
||||||
|
|
||||||
|
pub struct RedisSessionStore {
|
||||||
|
client: Client,
|
||||||
|
ttl_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RedisSessionStore {
|
||||||
|
pub fn new(bus: &RedisMessageBus, ttl_secs: Option<u64>) -> Self {
|
||||||
|
Self {
|
||||||
|
client: bus.client().clone(),
|
||||||
|
ttl_secs: ttl_secs.unwrap_or(DEFAULT_TTL_SECS),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn key(&self, sid: &str) -> String {
|
||||||
|
format!("{}:{}", KEY_PREFIX, sid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl SessionStoreTrait for RedisSessionStore {
|
||||||
|
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> {
|
||||||
|
let key = self.key(sid);
|
||||||
|
let now = now_millis();
|
||||||
|
|
||||||
|
// Batch all fields in a single HSET call for efficiency
|
||||||
|
let fields: Vec<(&str, String)> = vec![
|
||||||
|
("sid", sid.to_string()),
|
||||||
|
("transport", transport.to_string()),
|
||||||
|
("state", "connecting".to_string()),
|
||||||
|
("server_id", server_id.to_string()),
|
||||||
|
("created_at", now.to_string()),
|
||||||
|
("last_ping", now.to_string()),
|
||||||
|
];
|
||||||
|
self.client
|
||||||
|
.hset::<(), _, _>(&key, fields)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError> {
|
||||||
|
let key = self.key(sid);
|
||||||
|
|
||||||
|
// Use hgetall directly — if the key doesn't exist Redis returns an empty map.
|
||||||
|
// This avoids the TOCTOU race between EXISTS and HGETALL.
|
||||||
|
let values: std::collections::HashMap<String, String> = self.client
|
||||||
|
.hgetall::<std::collections::HashMap<String, String>, _>(&key)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
if values.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let info = SessionInfo {
|
||||||
|
sid: values.get("sid").cloned().unwrap_or_default(),
|
||||||
|
transport: values.get("transport").cloned().unwrap_or_default(),
|
||||||
|
state: values.get("state").cloned().unwrap_or_default(),
|
||||||
|
server_id: values.get("server_id").cloned().unwrap_or_default(),
|
||||||
|
created_at: values.get("created_at").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
|
||||||
|
last_ping: values.get("last_ping").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Some(info))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError> {
|
||||||
|
let key = self.key(sid);
|
||||||
|
|
||||||
|
// Use HSET (not HSETNX) to overwrite existing fields
|
||||||
|
self.client
|
||||||
|
.hset::<(), _, _>(&key, ("state", state))
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError> {
|
||||||
|
let key = self.key(sid);
|
||||||
|
|
||||||
|
// Use HSET (not HSETNX) to overwrite existing fields
|
||||||
|
self.client
|
||||||
|
.hset::<(), _, _>(&key, ("transport", transport))
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_ping(&self, sid: &str) -> Result<(), SessionError> {
|
||||||
|
let key = self.key(sid);
|
||||||
|
let now = now_millis();
|
||||||
|
|
||||||
|
// Use HSET (not HSETNX) to overwrite existing fields
|
||||||
|
self.client
|
||||||
|
.hset::<(), _, _>(&key, ("last_ping", now.to_string()))
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.expire::<(), _>(&key, self.ttl_secs as i64, None)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn remove(&self, sid: &str) -> Result<(), SessionError> {
|
||||||
|
let key = self.key(sid);
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.del::<(), _>(&key)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
|
||||||
|
let key = self.key(sid);
|
||||||
|
|
||||||
|
let exists: bool = self.client
|
||||||
|
.exists::<bool, _>(&key)
|
||||||
|
.await
|
||||||
|
.map_err(|e| SessionError::Redis(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(exists)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use crate::socket::packet::Packet;
|
||||||
|
|
||||||
|
pub struct Socket {
|
||||||
|
pub sid: String,
|
||||||
|
pub namespace: String,
|
||||||
|
pub engine_sid: String,
|
||||||
|
ack_id: AtomicU64,
|
||||||
|
tx: mpsc::Sender<Packet>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Socket {
|
||||||
|
pub fn new(
|
||||||
|
sid: String,
|
||||||
|
namespace: String,
|
||||||
|
engine_sid: String,
|
||||||
|
tx: mpsc::Sender<Packet>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
sid,
|
||||||
|
namespace,
|
||||||
|
engine_sid,
|
||||||
|
ack_id: AtomicU64::new(0),
|
||||||
|
tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next_ack_id(&self) -> u64 {
|
||||||
|
self.ack_id.fetch_add(1, Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_packet(&self, packet: &Packet) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||||
|
self.tx.try_send(packet.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn emit(&self, event: impl Into<String>, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||||
|
let packet = Packet::event(
|
||||||
|
&self.namespace,
|
||||||
|
serde_json::json!([event.into(), data]),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
self.send_packet(&packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn emit_with_ack(
|
||||||
|
&self,
|
||||||
|
event: impl Into<String>,
|
||||||
|
data: serde_json::Value,
|
||||||
|
) -> Result<u64, mpsc::error::TrySendError<Packet>> {
|
||||||
|
let ack_id = self.next_ack_id();
|
||||||
|
let packet = Packet::event(
|
||||||
|
&self.namespace,
|
||||||
|
serde_json::json!([event.into(), data]),
|
||||||
|
Some(ack_id),
|
||||||
|
);
|
||||||
|
self.send_packet(&packet)?;
|
||||||
|
Ok(ack_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn disconnect(&self) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||||
|
let packet = Packet::disconnect(&self.namespace);
|
||||||
|
self.send_packet(&packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_ack(&self, id: u64, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> {
|
||||||
|
let packet = Packet::ack(&self.namespace, data, id);
|
||||||
|
self.send_packet(&packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,344 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use imks::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, SocketInfo};
|
||||||
|
use imks::socket::packet::Packet;
|
||||||
|
use imks::socket::session_store::{InMemorySessionStore, SessionInfo, SessionStoreTrait};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_local_adapter_add_and_del() {
|
||||||
|
let sent_packets: dashmap::DashMap<String, Vec<Packet>> = dashmap::DashMap::new();
|
||||||
|
let sent_packets_clone = sent_packets.clone();
|
||||||
|
let send_fn = move |engine_sid: &str, packet: &Packet| {
|
||||||
|
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
let adapter = LocalAdapter::new(send_fn);
|
||||||
|
|
||||||
|
adapter.add("sid1", "room1", "/").await.unwrap();
|
||||||
|
adapter.add("sid1", "room2", "/").await.unwrap();
|
||||||
|
adapter.add("sid2", "room1", "/").await.unwrap();
|
||||||
|
|
||||||
|
let rooms = adapter.socket_rooms("sid1").await.unwrap();
|
||||||
|
assert!(rooms.contains("room1"));
|
||||||
|
assert!(rooms.contains("room2"));
|
||||||
|
|
||||||
|
adapter.del("sid1", "room1", "/").await.unwrap();
|
||||||
|
let rooms = adapter.socket_rooms("sid1").await.unwrap();
|
||||||
|
assert!(!rooms.contains("room1"));
|
||||||
|
assert!(rooms.contains("room2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_local_adapter_del_all() {
|
||||||
|
let send_fn = move |_engine_sid: &str, _packet: &Packet| Ok(());
|
||||||
|
|
||||||
|
let adapter = LocalAdapter::new(send_fn);
|
||||||
|
|
||||||
|
adapter.add("sid1", "room1", "/").await.unwrap();
|
||||||
|
adapter.add("sid1", "room2", "/").await.unwrap();
|
||||||
|
|
||||||
|
adapter.del_all("sid1", "/").await.unwrap();
|
||||||
|
|
||||||
|
let rooms = adapter.socket_rooms("sid1").await.unwrap();
|
||||||
|
assert!(rooms.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_local_adapter_register_and_broadcast() {
|
||||||
|
let sent_packets: Arc<dashmap::DashMap<String, Vec<Packet>>> = Arc::new(dashmap::DashMap::new());
|
||||||
|
let sent_packets_clone = sent_packets.clone();
|
||||||
|
let send_fn = move |engine_sid: &str, packet: &Packet| {
|
||||||
|
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
let adapter = LocalAdapter::new(send_fn);
|
||||||
|
|
||||||
|
// Register socket_sid → engine_sid mapping
|
||||||
|
adapter.register("sid1", "engine1", "/").await.unwrap();
|
||||||
|
adapter.register("sid2", "engine2", "/").await.unwrap();
|
||||||
|
|
||||||
|
let packet = Packet::event("/", serde_json::json!(["test", "hello"]), None);
|
||||||
|
let opts = BroadcastOptions::default();
|
||||||
|
adapter.broadcast(&packet, &opts).await.unwrap();
|
||||||
|
|
||||||
|
assert!(sent_packets.contains_key("engine1"));
|
||||||
|
assert!(sent_packets.contains_key("engine2"));
|
||||||
|
assert_eq!(sent_packets.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_local_adapter_broadcast_to_room() {
|
||||||
|
let sent_packets: Arc<dashmap::DashMap<String, Vec<Packet>>> = Arc::new(dashmap::DashMap::new());
|
||||||
|
let sent_packets_clone = sent_packets.clone();
|
||||||
|
let send_fn = move |engine_sid: &str, packet: &Packet| {
|
||||||
|
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
let adapter = LocalAdapter::new(send_fn);
|
||||||
|
|
||||||
|
adapter.register("sid1", "engine1", "/").await.unwrap();
|
||||||
|
adapter.register("sid2", "engine2", "/").await.unwrap();
|
||||||
|
adapter.add("sid1", "room1", "/").await.unwrap();
|
||||||
|
adapter.add("sid2", "room2", "/").await.unwrap();
|
||||||
|
|
||||||
|
let packet = Packet::event("/", serde_json::json!(["test", "hello"]), None);
|
||||||
|
let opts = BroadcastOptions {
|
||||||
|
rooms: HashSet::from(["room1".to_string()]),
|
||||||
|
except: HashSet::new(),
|
||||||
|
flags: BroadcastFlags::default(),
|
||||||
|
};
|
||||||
|
adapter.broadcast(&packet, &opts).await.unwrap();
|
||||||
|
|
||||||
|
assert!(sent_packets.contains_key("engine1"));
|
||||||
|
assert!(!sent_packets.contains_key("engine2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_local_adapter_broadcast_except() {
|
||||||
|
let sent_packets: Arc<dashmap::DashMap<String, Vec<Packet>>> = Arc::new(dashmap::DashMap::new());
|
||||||
|
let sent_packets_clone = sent_packets.clone();
|
||||||
|
let send_fn = move |engine_sid: &str, packet: &Packet| {
|
||||||
|
sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone());
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
let adapter = LocalAdapter::new(send_fn);
|
||||||
|
|
||||||
|
adapter.register("sid1", "engine1", "/").await.unwrap();
|
||||||
|
adapter.register("sid2", "engine2", "/").await.unwrap();
|
||||||
|
|
||||||
|
let packet = Packet::event("/", serde_json::json!(["test", "hello"]), None);
|
||||||
|
let opts = BroadcastOptions {
|
||||||
|
rooms: HashSet::new(),
|
||||||
|
except: HashSet::from(["sid1".to_string()]),
|
||||||
|
flags: BroadcastFlags::default(),
|
||||||
|
};
|
||||||
|
adapter.broadcast(&packet, &opts).await.unwrap();
|
||||||
|
|
||||||
|
assert!(!sent_packets.contains_key("engine1"));
|
||||||
|
assert!(sent_packets.contains_key("engine2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_local_adapter_fetch_sockets() {
|
||||||
|
let send_fn = move |_engine_sid: &str, _packet: &Packet| Ok(());
|
||||||
|
|
||||||
|
let adapter = LocalAdapter::new(send_fn);
|
||||||
|
|
||||||
|
adapter.register("sid1", "engine1", "/").await.unwrap();
|
||||||
|
adapter.register("sid2", "engine2", "/").await.unwrap();
|
||||||
|
adapter.add("sid1", "room1", "/").await.unwrap();
|
||||||
|
adapter.add("sid2", "room2", "/").await.unwrap();
|
||||||
|
|
||||||
|
let opts = BroadcastOptions::default();
|
||||||
|
let sockets = adapter.fetch_sockets(&opts).await.unwrap();
|
||||||
|
assert_eq!(sockets.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_local_adapter_server_id_unique() {
|
||||||
|
let send_fn1 = move |_engine_sid: &str, _packet: &Packet| Ok(());
|
||||||
|
let send_fn2 = move |_engine_sid: &str, _packet: &Packet| Ok(());
|
||||||
|
|
||||||
|
let adapter1 = LocalAdapter::new(send_fn1);
|
||||||
|
let adapter2 = LocalAdapter::new(send_fn2);
|
||||||
|
|
||||||
|
assert_ne!(adapter1.server_id(), adapter2.server_id());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_in_memory_session_store() {
|
||||||
|
let store = InMemorySessionStore::new();
|
||||||
|
|
||||||
|
store.create("sid1", "polling", "server1").await.unwrap();
|
||||||
|
assert!(store.exists("sid1").await.unwrap());
|
||||||
|
|
||||||
|
let info = store.get("sid1").await.unwrap().unwrap();
|
||||||
|
assert_eq!(info.sid, "sid1");
|
||||||
|
assert_eq!(info.transport, "polling");
|
||||||
|
assert_eq!(info.state, "connecting");
|
||||||
|
assert_eq!(info.server_id, "server1");
|
||||||
|
|
||||||
|
store.set_state("sid1", "open").await.unwrap();
|
||||||
|
let info = store.get("sid1").await.unwrap().unwrap();
|
||||||
|
assert_eq!(info.state, "open");
|
||||||
|
|
||||||
|
store.set_transport("sid1", "websocket").await.unwrap();
|
||||||
|
let info = store.get("sid1").await.unwrap().unwrap();
|
||||||
|
assert_eq!(info.transport, "websocket");
|
||||||
|
|
||||||
|
store.update_ping("sid1").await.unwrap();
|
||||||
|
let info = store.get("sid1").await.unwrap().unwrap();
|
||||||
|
assert!(info.last_ping > 0);
|
||||||
|
|
||||||
|
store.remove("sid1").await.unwrap();
|
||||||
|
assert!(!store.exists("sid1").await.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_in_memory_session_store_not_found() {
|
||||||
|
let store = InMemorySessionStore::new();
|
||||||
|
|
||||||
|
let result = store.get("nonexistent").await.unwrap();
|
||||||
|
assert!(result.is_none());
|
||||||
|
|
||||||
|
let result = store.set_state("nonexistent", "open").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bus_message_serialization() {
|
||||||
|
let msg = BusMessage::Broadcast {
|
||||||
|
namespace: "/".to_string(),
|
||||||
|
packet: "2[\"hello\"]".to_string(),
|
||||||
|
opts: BroadcastOptions::default(),
|
||||||
|
server_id: "server-1".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoded = serde_json::to_vec(&msg).unwrap();
|
||||||
|
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded, msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bus_message_socket_join() {
|
||||||
|
let msg = BusMessage::SocketJoin {
|
||||||
|
namespace: "/admin".to_string(),
|
||||||
|
sid: "sid-1".to_string(),
|
||||||
|
room: "room-1".to_string(),
|
||||||
|
server_id: "server-1".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoded = serde_json::to_vec(&msg).unwrap();
|
||||||
|
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded, msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bus_message_socket_leave() {
|
||||||
|
let msg = BusMessage::SocketLeave {
|
||||||
|
namespace: "/".to_string(),
|
||||||
|
sid: "sid-1".to_string(),
|
||||||
|
room: "room-1".to_string(),
|
||||||
|
server_id: "server-1".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoded = serde_json::to_vec(&msg).unwrap();
|
||||||
|
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded, msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bus_message_socket_disconnect() {
|
||||||
|
let msg = BusMessage::SocketDisconnect {
|
||||||
|
namespace: "/".to_string(),
|
||||||
|
sid: "sid-1".to_string(),
|
||||||
|
server_id: "server-1".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoded = serde_json::to_vec(&msg).unwrap();
|
||||||
|
let decoded: BusMessage = serde_json::from_slice(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded, msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_broadcast_options_serialization() {
|
||||||
|
let opts = BroadcastOptions {
|
||||||
|
rooms: HashSet::from(["room1".to_string(), "room2".to_string()]),
|
||||||
|
except: HashSet::from(["sid1".to_string()]),
|
||||||
|
flags: BroadcastFlags {
|
||||||
|
local_only: true,
|
||||||
|
broadcast: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoded = serde_json::to_vec(&opts).unwrap();
|
||||||
|
let decoded: BroadcastOptions = serde_json::from_slice(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded.rooms, opts.rooms);
|
||||||
|
assert_eq!(decoded.except, opts.except);
|
||||||
|
assert_eq!(decoded.flags.local_only, opts.flags.local_only);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_info_serialization() {
|
||||||
|
let info = SocketInfo {
|
||||||
|
sid: "sid-1".to_string(),
|
||||||
|
namespace: "/admin".to_string(),
|
||||||
|
rooms: HashSet::from(["room1".to_string()]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoded = serde_json::to_vec(&info).unwrap();
|
||||||
|
let decoded: SocketInfo = serde_json::from_slice(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded.sid, info.sid);
|
||||||
|
assert_eq!(decoded.namespace, info.namespace);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_info_serialization() {
|
||||||
|
let info = SessionInfo {
|
||||||
|
sid: "sid-1".to_string(),
|
||||||
|
transport: "websocket".to_string(),
|
||||||
|
state: "open".to_string(),
|
||||||
|
server_id: "server-1".to_string(),
|
||||||
|
created_at: 1234567890,
|
||||||
|
last_ping: 1234567900,
|
||||||
|
};
|
||||||
|
|
||||||
|
let encoded = serde_json::to_vec(&info).unwrap();
|
||||||
|
let decoded: SessionInfo = serde_json::from_slice(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded.sid, info.sid);
|
||||||
|
assert_eq!(decoded.transport, info.transport);
|
||||||
|
assert_eq!(decoded.state, info.state);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_bus_error_display() {
|
||||||
|
let err = imks::socket::message_bus::MessageBusError::Redis("connection refused".to_string());
|
||||||
|
assert_eq!(format!("{}", err), "Redis error: connection refused");
|
||||||
|
|
||||||
|
let err = imks::socket::message_bus::MessageBusError::Nats("timeout".to_string());
|
||||||
|
assert_eq!(format!("{}", err), "NATS error: timeout");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_adapter_error_display() {
|
||||||
|
let err = AdapterError::Redis("SADD failed".to_string());
|
||||||
|
assert_eq!(format!("{}", err), "Redis error: SADD failed");
|
||||||
|
|
||||||
|
let err = AdapterError::Nats("publish failed".to_string());
|
||||||
|
assert_eq!(format!("{}", err), "NATS error: publish failed");
|
||||||
|
|
||||||
|
let err = AdapterError::Serialization("json error".to_string());
|
||||||
|
assert_eq!(format!("{}", err), "Serialization error: json error");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_error_display() {
|
||||||
|
let err = imks::socket::session_store::SessionError::Redis("timeout".to_string());
|
||||||
|
assert_eq!(format!("{}", err), "Redis error: timeout");
|
||||||
|
|
||||||
|
let err = imks::socket::session_store::SessionError::NotFound("sid-1".to_string());
|
||||||
|
assert_eq!(format!("{}", err), "Session not found: sid-1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_is_valid_namespace() {
|
||||||
|
assert!(imks::socket::namespace::is_valid_namespace("/"));
|
||||||
|
assert!(imks::socket::namespace::is_valid_namespace("/admin"));
|
||||||
|
assert!(imks::socket::namespace::is_valid_namespace("/chat/room1"));
|
||||||
|
|
||||||
|
assert!(!imks::socket::namespace::is_valid_namespace(""));
|
||||||
|
assert!(!imks::socket::namespace::is_valid_namespace("admin"));
|
||||||
|
assert!(!imks::socket::namespace::is_valid_namespace(&"/".repeat(257)));
|
||||||
|
}
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
use imks::engine::codec;
|
||||||
|
use imks::engine::packet::{HandshakeData, Packet, PacketData, PacketType};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_handshake_encoding() {
|
||||||
|
let handshake = HandshakeData {
|
||||||
|
sid: "lv_VI97HAXpY6yYWAAAC".to_string(),
|
||||||
|
upgrades: vec!["websocket".to_string()],
|
||||||
|
ping_interval: 25000,
|
||||||
|
ping_timeout: 20000,
|
||||||
|
max_payload: 1000000,
|
||||||
|
};
|
||||||
|
|
||||||
|
let packet = Packet::open(&handshake);
|
||||||
|
let encoded = codec::encode_packet(&packet);
|
||||||
|
|
||||||
|
assert!(encoded.starts_with('0'));
|
||||||
|
assert!(encoded.contains("\"sid\":\"lv_VI97HAXpY6yYWAAAC\""));
|
||||||
|
assert!(encoded.contains("\"upgrades\":[\"websocket\"]"));
|
||||||
|
assert!(encoded.contains("\"pingInterval\":25000"));
|
||||||
|
assert!(encoded.contains("\"pingTimeout\":20000"));
|
||||||
|
assert!(encoded.contains("\"maxPayload\":1000000"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_packet_types() {
|
||||||
|
let open = Packet::open(&HandshakeData {
|
||||||
|
sid: "test".to_string(),
|
||||||
|
upgrades: vec![],
|
||||||
|
ping_interval: 25000,
|
||||||
|
ping_timeout: 20000,
|
||||||
|
max_payload: 1000000,
|
||||||
|
});
|
||||||
|
assert_eq!(open.packet_type, PacketType::Open);
|
||||||
|
|
||||||
|
let close = Packet::close();
|
||||||
|
assert_eq!(close.packet_type, PacketType::Close);
|
||||||
|
|
||||||
|
let ping = Packet::ping("test");
|
||||||
|
assert_eq!(ping.packet_type, PacketType::Ping);
|
||||||
|
|
||||||
|
let pong = Packet::pong("test");
|
||||||
|
assert_eq!(pong.packet_type, PacketType::Pong);
|
||||||
|
|
||||||
|
let msg = Packet::message_text("hello");
|
||||||
|
assert_eq!(msg.packet_type, PacketType::Message);
|
||||||
|
|
||||||
|
let upgrade = Packet::upgrade();
|
||||||
|
assert_eq!(upgrade.packet_type, PacketType::Upgrade);
|
||||||
|
|
||||||
|
let noop = Packet::noop();
|
||||||
|
assert_eq!(noop.packet_type, PacketType::Noop);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_polling_payload() {
|
||||||
|
let packets = vec![
|
||||||
|
Packet::message_text("hello"),
|
||||||
|
Packet::ping(""),
|
||||||
|
Packet::message_text("world"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let encoded = codec::encode_payload(&packets);
|
||||||
|
assert_eq!(encoded, "4hello\x1e2\x1e4world");
|
||||||
|
|
||||||
|
let decoded = codec::decode_payload(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.len(), 3);
|
||||||
|
assert_eq!(decoded[0].packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded[1].packet_type, PacketType::Ping);
|
||||||
|
assert_eq!(decoded[2].packet_type, PacketType::Message);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_binary_encoding() {
|
||||||
|
let packet = Packet::message_binary(vec![0x01, 0x02, 0x03, 0x04]);
|
||||||
|
let encoded = codec::encode_packet(&packet);
|
||||||
|
assert_eq!(encoded, "4bAQIDBA==");
|
||||||
|
|
||||||
|
let decoded = codec::decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Message);
|
||||||
|
assert_eq!(decoded.data, PacketData::Binary(vec![0x01, 0x02, 0x03, 0x04]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_webtransport_header_encoding() {
|
||||||
|
let header = codec::encode_webtransport_header(6, false);
|
||||||
|
assert_eq!(header, vec![0x06]);
|
||||||
|
|
||||||
|
let header = codec::encode_webtransport_header(200, true);
|
||||||
|
assert_eq!(header.len(), 3);
|
||||||
|
assert_eq!(header[0], 0x80 | 126);
|
||||||
|
|
||||||
|
let header = codec::encode_webtransport_header(70000, false);
|
||||||
|
assert_eq!(header.len(), 9);
|
||||||
|
assert_eq!(header[0], 127);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_webtransport_header_decoding() {
|
||||||
|
let header = vec![0x06];
|
||||||
|
let (len, is_binary) = codec::decode_webtransport_header(&header).unwrap();
|
||||||
|
assert_eq!(len, 6);
|
||||||
|
assert!(!is_binary);
|
||||||
|
|
||||||
|
let header = vec![0x80 | 126, 0x00, 0xC8];
|
||||||
|
let (len, is_binary) = codec::decode_webtransport_header(&header).unwrap();
|
||||||
|
assert_eq!(len, 200);
|
||||||
|
assert!(is_binary);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_probe_ping_pong() {
|
||||||
|
let ping = Packet::ping("probe");
|
||||||
|
let encoded = codec::encode_packet(&ping);
|
||||||
|
assert_eq!(encoded, "2probe");
|
||||||
|
|
||||||
|
let decoded = codec::decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Ping);
|
||||||
|
assert_eq!(decoded.data, PacketData::Text("probe".to_string()));
|
||||||
|
|
||||||
|
let pong = Packet::pong("probe");
|
||||||
|
let encoded = codec::encode_packet(&pong);
|
||||||
|
assert_eq!(encoded, "3probe");
|
||||||
|
|
||||||
|
let decoded = codec::decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Pong);
|
||||||
|
assert_eq!(decoded.data, PacketData::Text("probe".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_upgrade_packet() {
|
||||||
|
let upgrade = Packet::upgrade();
|
||||||
|
let encoded = codec::encode_packet(&upgrade);
|
||||||
|
assert_eq!(encoded, "5");
|
||||||
|
|
||||||
|
let decoded = codec::decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Upgrade);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_noop_packet() {
|
||||||
|
let noop = Packet::noop();
|
||||||
|
let encoded = codec::encode_packet(&noop);
|
||||||
|
assert_eq!(encoded, "6");
|
||||||
|
|
||||||
|
let decoded = codec::decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Noop);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_io_close_packet() {
|
||||||
|
let close = Packet::close();
|
||||||
|
let encoded = codec::encode_packet(&close);
|
||||||
|
assert_eq!(encoded, "1");
|
||||||
|
|
||||||
|
let decoded = codec::decode_packet(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.packet_type, PacketType::Close);
|
||||||
|
}
|
||||||
@@ -0,0 +1,129 @@
|
|||||||
|
use imks::engine::session::{generate_sid, SessionState, SessionStore, TransportType};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_store_create_and_get() {
|
||||||
|
let store = SessionStore::new();
|
||||||
|
let sid = generate_sid();
|
||||||
|
|
||||||
|
let _rx = store.create(sid.clone(), TransportType::Polling);
|
||||||
|
|
||||||
|
assert!(store.exists(&sid));
|
||||||
|
assert!(store.get(&sid).is_some());
|
||||||
|
assert_eq!(store.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_store_remove() {
|
||||||
|
let store = SessionStore::new();
|
||||||
|
let sid = generate_sid();
|
||||||
|
|
||||||
|
let _rx = store.create(sid.clone(), TransportType::Polling);
|
||||||
|
assert!(store.exists(&sid));
|
||||||
|
|
||||||
|
store.remove(&sid);
|
||||||
|
assert!(!store.exists(&sid));
|
||||||
|
assert!(store.get(&sid).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_session_store_multiple_sessions() {
|
||||||
|
let store = SessionStore::new();
|
||||||
|
|
||||||
|
let sid1 = generate_sid();
|
||||||
|
let sid2 = generate_sid();
|
||||||
|
let sid3 = generate_sid();
|
||||||
|
|
||||||
|
let _rx1 = store.create(sid1.clone(), TransportType::Polling);
|
||||||
|
let _rx2 = store.create(sid2.clone(), TransportType::WebSocket);
|
||||||
|
let _rx3 = store.create(sid3.clone(), TransportType::WebTransport);
|
||||||
|
|
||||||
|
assert_eq!(store.len(), 3);
|
||||||
|
assert!(store.exists(&sid1));
|
||||||
|
assert!(store.exists(&sid2));
|
||||||
|
assert!(store.exists(&sid3));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_sid_uniqueness() {
|
||||||
|
let sids: Vec<String> = (0..100).map(|_| generate_sid()).collect();
|
||||||
|
|
||||||
|
let unique_sids: std::collections::HashSet<String> = sids.into_iter().collect();
|
||||||
|
assert_eq!(unique_sids.len(), 100);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_sid_format() {
|
||||||
|
let sid = generate_sid();
|
||||||
|
|
||||||
|
assert_eq!(sid.len(), 20);
|
||||||
|
assert!(sid.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_session_state_transitions() {
|
||||||
|
let store = SessionStore::new();
|
||||||
|
let sid = generate_sid();
|
||||||
|
|
||||||
|
let _rx = store.create(sid.clone(), TransportType::Polling);
|
||||||
|
|
||||||
|
if let Some(session) = store.get(&sid) {
|
||||||
|
let mut session = session.write().await;
|
||||||
|
assert_eq!(session.state, SessionState::Connecting);
|
||||||
|
|
||||||
|
session.set_state(SessionState::Open);
|
||||||
|
assert_eq!(session.state, SessionState::Open);
|
||||||
|
|
||||||
|
session.set_state(SessionState::Upgrading);
|
||||||
|
assert_eq!(session.state, SessionState::Upgrading);
|
||||||
|
|
||||||
|
session.set_state(SessionState::Open);
|
||||||
|
assert_eq!(session.state, SessionState::Open);
|
||||||
|
|
||||||
|
session.set_state(SessionState::Closing);
|
||||||
|
assert_eq!(session.state, SessionState::Closing);
|
||||||
|
|
||||||
|
session.set_state(SessionState::Closed);
|
||||||
|
assert_eq!(session.state, SessionState::Closed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_session_transport_change() {
|
||||||
|
let store = SessionStore::new();
|
||||||
|
let sid = generate_sid();
|
||||||
|
|
||||||
|
let _rx = store.create(sid.clone(), TransportType::Polling);
|
||||||
|
|
||||||
|
if let Some(session) = store.get(&sid) {
|
||||||
|
let mut session = session.write().await;
|
||||||
|
assert_eq!(session.transport, TransportType::Polling);
|
||||||
|
|
||||||
|
session.set_transport(TransportType::WebSocket);
|
||||||
|
assert_eq!(session.transport, TransportType::WebSocket);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_session_ping_update() {
|
||||||
|
let store = SessionStore::new();
|
||||||
|
let sid = generate_sid();
|
||||||
|
|
||||||
|
let _rx = store.create(sid.clone(), TransportType::Polling);
|
||||||
|
|
||||||
|
if let Some(session) = store.get(&sid) {
|
||||||
|
let mut session = session.write().await;
|
||||||
|
let initial_ping = session.last_ping;
|
||||||
|
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||||
|
session.update_ping();
|
||||||
|
|
||||||
|
assert!(session.last_ping > initial_ping);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_transport_type_as_str() {
|
||||||
|
assert_eq!(TransportType::Polling.as_str(), "polling");
|
||||||
|
assert_eq!(TransportType::WebSocket.as_str(), "websocket");
|
||||||
|
assert_eq!(TransportType::WebTransport.as_str(), "webtransport");
|
||||||
|
}
|
||||||
@@ -0,0 +1,203 @@
|
|||||||
|
use imks::socket::packet::{Packet, PacketType};
|
||||||
|
use imks::socket::parser;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_connect_encoding() {
|
||||||
|
let packet = Packet::connect("/", None);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "0");
|
||||||
|
|
||||||
|
let packet = Packet::connect("/admin", Some(json!({"sid": "abc123"})));
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "0/admin,{\"sid\":\"abc123\"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_disconnect_encoding() {
|
||||||
|
let packet = Packet::disconnect("/");
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "1");
|
||||||
|
|
||||||
|
let packet = Packet::disconnect("/admin");
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "1/admin,");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_event_encoding() {
|
||||||
|
let packet = Packet::event("/", json!(["foo"]), None);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "2[\"foo\"]");
|
||||||
|
|
||||||
|
let packet = Packet::event("/admin", json!(["bar"]), None);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "2/admin,[\"bar\"]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_event_with_ack_encoding() {
|
||||||
|
let packet = Packet::event("/", json!(["foo"]), Some(12));
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "212[\"foo\"]");
|
||||||
|
|
||||||
|
let packet = Packet::event("/admin", json!(["bar"]), Some(13));
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "2/admin,13[\"bar\"]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_ack_encoding() {
|
||||||
|
let packet = Packet::ack("/", json!([]), 12);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "312[]");
|
||||||
|
|
||||||
|
let packet = Packet::ack("/admin", json!(["bar"]), 13);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "3/admin,13[\"bar\"]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_connect_error_encoding() {
|
||||||
|
let packet = Packet::connect_error("/", "Not authorized");
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert_eq!(encoded, "4{\"message\":\"Not authorized\"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_connect_decoding() {
|
||||||
|
let packet = parser::decode("0").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Connect);
|
||||||
|
assert_eq!(packet.namespace, "/");
|
||||||
|
assert!(packet.data.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_connect_with_namespace_decoding() {
|
||||||
|
let packet = parser::decode("0/admin,{\"sid\":\"abc\"}").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Connect);
|
||||||
|
assert_eq!(packet.namespace, "/admin");
|
||||||
|
assert!(packet.data.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_disconnect_decoding() {
|
||||||
|
let packet = parser::decode("1").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Disconnect);
|
||||||
|
assert_eq!(packet.namespace, "/");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_disconnect_with_namespace_decoding() {
|
||||||
|
let packet = parser::decode("1/admin,").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Disconnect);
|
||||||
|
assert_eq!(packet.namespace, "/admin");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_event_decoding() {
|
||||||
|
let packet = parser::decode("2[\"foo\"]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Event);
|
||||||
|
assert_eq!(packet.namespace, "/");
|
||||||
|
assert_eq!(packet.data, Some(json!(["foo"])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_event_with_namespace_decoding() {
|
||||||
|
let packet = parser::decode("2/admin,[\"bar\"]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Event);
|
||||||
|
assert_eq!(packet.namespace, "/admin");
|
||||||
|
assert_eq!(packet.data, Some(json!(["bar"])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_event_with_ack_decoding() {
|
||||||
|
let packet = parser::decode("212[\"foo\"]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Event);
|
||||||
|
assert_eq!(packet.id, Some(12));
|
||||||
|
assert_eq!(packet.data, Some(json!(["foo"])));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_ack_decoding() {
|
||||||
|
let packet = parser::decode("312[]").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::Ack);
|
||||||
|
assert_eq!(packet.id, Some(12));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_connect_error_decoding() {
|
||||||
|
let packet = parser::decode("4{\"message\":\"Not authorized\"}").unwrap();
|
||||||
|
assert_eq!(packet.packet_type, PacketType::ConnectError);
|
||||||
|
assert_eq!(packet.data, Some(json!({"message": "Not authorized"})));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_binary_event_encoding() {
|
||||||
|
let packet = Packet::binary_event(
|
||||||
|
"/",
|
||||||
|
json!(["baz", {"_placeholder": true, "num": 0}]),
|
||||||
|
None,
|
||||||
|
vec![vec![0x01, 0x02, 0x03, 0x04]],
|
||||||
|
);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert!(encoded.starts_with("51-"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_binary_ack_encoding() {
|
||||||
|
let packet = Packet::binary_ack(
|
||||||
|
"/",
|
||||||
|
json!(["bar", {"_placeholder": true, "num": 0}]),
|
||||||
|
15,
|
||||||
|
vec![vec![0x01, 0x02, 0x03, 0x04]],
|
||||||
|
);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
assert!(encoded.starts_with("61-"));
|
||||||
|
assert!(encoded.contains("15"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_roundtrip() {
|
||||||
|
let original = Packet::event("/admin", json!(["hello", "world"]), Some(42));
|
||||||
|
let encoded = parser::encode(&original);
|
||||||
|
let decoded = parser::decode(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded.packet_type, original.packet_type);
|
||||||
|
assert_eq!(decoded.namespace, original.namespace);
|
||||||
|
assert_eq!(decoded.id, original.id);
|
||||||
|
assert_eq!(decoded.data, original.data);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_namespace_handling() {
|
||||||
|
let namespaces = vec!["/", "/admin", "/chat", "/api/v1"];
|
||||||
|
|
||||||
|
for ns in namespaces {
|
||||||
|
let packet = Packet::event(ns, json!(["test"]), None);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
let decoded = parser::decode(&encoded).unwrap();
|
||||||
|
assert_eq!(decoded.namespace, ns);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_socket_io_complex_data() {
|
||||||
|
let complex_data = json!([
|
||||||
|
"event_name",
|
||||||
|
{
|
||||||
|
"user": {
|
||||||
|
"id": 123,
|
||||||
|
"name": "test",
|
||||||
|
"roles": ["admin", "user"]
|
||||||
|
},
|
||||||
|
"timestamp": 1234567890
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
|
||||||
|
let packet = Packet::event("/", complex_data.clone(), None);
|
||||||
|
let encoded = parser::encode(&packet);
|
||||||
|
let decoded = parser::decode(&encoded).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(decoded.data, Some(complex_data));
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user