Files
imks/main.rs
T
zhenyi 821537186e refactor(tests): reformat code and update dependency management
- Reorganized import statements in adapter tests for better readability
- Replaced or_insert_with(Vec::new) with or_default() in test closures
- Updated Cargo.lock with new dependency versions and checksums
- Added TLS features to tonic dependency configuration
- Included sqlx, chrono, and uuid dependencies with specific features
- Added jsonwebtoken and arc-swap as project dependencies
- Reformatted assertion statements to comply with line length limits
- Adjusted base64 import order in engine codec module
- Updated protobuf include statement formatting
2026-06-11 12:11:05 +08:00

258 lines
11 KiB
Rust

use std::sync::{Arc, OnceLock};
use imks::database::{Database, DatabaseConfig};
use imks::engine::server::EngineConfig;
use imks::repo::MessageRepo;
use imks::rpc::{AppksClients, RpcConfig};
use imks::socket::adapter::{LocalBroadcastFn, NatsAdapter, RedisAdapter};
use imks::socket::message_bus::{NatsMessageBus, RedisMessageBus};
use imks::socket::server::SocketServerBuilder;
use imks::svc::{DeployConfig, MessageService};
fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let deploy = DeployConfig::from_env();
tracing::info!(
adapter = %deploy.adapter_mode,
server_id = %deploy.server_id,
wt_enabled = deploy.webtransport_enabled,
"Starting imks server"
);
let addr = "0.0.0.0:3000";
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async {
let engine_config = EngineConfig::default();
let mut builder = SocketServerBuilder::new(engine_config);
let namespace_holder: Arc<OnceLock<Arc<imks::socket::namespace::NamespaceManager>>> =
Arc::new(OnceLock::new());
// Pre-configure adapter for Redis/NATS mode.
// The callback resolves namespaces after SocketServer is built.
match deploy.adapter_mode.as_str() {
"redis" => {
let message_bus = Arc::new(
RedisMessageBus::new(&deploy.redis_url)
.await
.map_err(|e| format!("Failed to connect to Redis: {e}"))?,
);
let redis_client = message_bus.client().clone();
let server_id = deploy.server_id.clone();
let adapter = Arc::new(RedisAdapter::new(
message_bus.clone() as Arc<_>,
redis_client,
server_id,
"/".into(),
make_local_broadcast_fn(namespace_holder.clone()),
));
adapter
.init()
.await
.map_err(|e| format!("Failed to initialize Redis adapter: {e}"))?;
builder = builder.adapter(adapter);
tracing::info!("Redis adapter configured for multi-node");
}
"nats" => {
let message_bus = Arc::new(
NatsMessageBus::new(&deploy.nats_url)
.await
.map_err(|e| format!("Failed to connect to NATS: {e}"))?,
);
let server_id = deploy.server_id.clone();
let adapter = Arc::new(NatsAdapter::new(
message_bus.clone() as Arc<_>,
server_id,
"/".into(),
make_local_broadcast_fn(namespace_holder.clone()),
));
adapter
.init()
.await
.map_err(|e| format!("Failed to initialize NATS adapter: {e}"))?;
builder = builder.adapter(adapter);
tracing::info!("NATS adapter configured for multi-node");
}
_ => {
tracing::info!("Local adapter (single-node mode)");
}
};
let socket_server = Arc::new(builder.build());
let _ = namespace_holder.set(socket_server.namespaces.clone());
// Initialize database + gRPC + service
let service: Option<Arc<MessageService>> = {
let rpc_config = RpcConfig::from_env();
let db_config = DatabaseConfig::from_env();
match AppksClients::connect(&rpc_config).await {
Ok(clients) => {
let db = Database::connect(&db_config)
.await
.map_err(|e| format!("Database connection failed: {e}"))?;
imks::database::run_migrations(db.pool())
.await
.map_err(|e| format!("Database migration failed: {e}"))?;
let repo = MessageRepo::new(db.pool().clone());
let svc = MessageService::new(repo, clients, socket_server.namespaces.clone())
.await
.map_err(|e| format!("Failed to initialize message service: {e}"))?;
tracing::info!("Message service initialized with gRPC permission checks");
Some(Arc::new(svc))
}
Err(e) => {
tracing::warn!("gRPC unavailable: {e}. Running without permission checks.");
None
}
}
};
// Register connect handler
let namespace = socket_server.of("/");
let svc_connect = service.clone();
namespace
.on_connect(move |socket, auth_data| {
if let Some(ref svc) = svc_connect {
svc.authenticate_socket(socket, auth_data)
.map_err(|e| e.to_string())?;
}
tracing::info!(
"Socket {} connected (engine: {})",
socket.sid,
socket.engine_sid
);
Ok(())
})
.await;
// Register Socket.IO event handlers
if let Some(ref svc) = service {
macro_rules! register_event {
($svc:expr, $ns:expr, $event:expr, $method:ident) => {
let s = $svc.clone();
$ns.on_event($event, Arc::new(move |socket, data| {
let s = s.clone();
let data = data.clone();
tokio::spawn(async move {
if let Err(e) = s.$method(socket, &data).await {
tracing::error!(event = $event, error = %e, "Event handler failed");
}
});
})).await;
};
}
register_event!(svc, namespace, "channel:join", join_channel);
register_event!(svc, namespace, "channel:leave", leave_channel);
register_event!(svc, namespace, "message:send", send_message);
register_event!(svc, namespace, "message:edit", edit_message);
register_event!(svc, namespace, "message:delete", delete_message);
register_event!(svc, namespace, "reaction:add", toggle_reaction);
register_event!(svc, namespace, "pin:add", pin_message);
register_event!(svc, namespace, "pin:remove", unpin_message);
register_event!(svc, namespace, "poll:vote", poll_vote);
register_event!(svc, namespace, "poll:vote:remove", poll_remove_vote);
register_event!(svc, namespace, "typing:start", typing_start);
register_event!(svc, namespace, "typing:stop", typing_stop);
register_event!(svc, namespace, "presence:update", presence_update);
register_event!(svc, namespace, "draft:save", save_draft);
register_event!(svc, namespace, "draft:get", get_draft);
register_event!(svc, namespace, "draft:delete", delete_draft);
register_event!(svc, namespace, "read_state:mark", mark_read);
register_event!(svc, namespace, "read_state:get", get_read_state);
register_event!(svc, namespace, "notification:list", list_notifications);
register_event!(
svc,
namespace,
"notification:mark_read",
mark_notification_read
);
register_event!(
svc,
namespace,
"notification:mark_all_read",
mark_all_notifications_read
);
register_event!(svc, namespace, "bookmark:add", add_bookmark);
register_event!(svc, namespace, "bookmark:remove", remove_bookmark);
register_event!(svc, namespace, "bookmark:list", list_bookmarks);
register_event!(svc, namespace, "thread:create", create_thread);
register_event!(svc, namespace, "thread:resolve", resolve_thread);
register_event!(svc, namespace, "thread:join", join_thread);
register_event!(svc, namespace, "thread:leave", leave_thread);
register_event!(svc, namespace, "thread:list", list_threads);
register_event!(svc, namespace, "article:create", create_article);
register_event!(svc, namespace, "article:update", update_article);
register_event!(svc, namespace, "article:list", list_articles);
register_event!(svc, namespace, "article:delete", delete_article);
register_event!(svc, namespace, "component:interact", interact_component);
// Start scheduled message dispatcher (background task)
svc.clone().start_scheduled_dispatcher();
tracing::info!("Registered Socket.IO event handlers");
}
// Start servers
if deploy.webtransport_enabled && !deploy.cert_path.is_empty() {
let engine = socket_server.engine.clone();
let wt_port = deploy.webtransport_port;
let cert_path = deploy.cert_path.clone();
let key_path = deploy.key_path.clone();
let server = socket_server.clone();
tracing::info!("Starting HTTP on {} + WebTransport on {}", addr, wt_port);
tokio::select! {
result = server.run_http(addr) => {
result?;
}
result = engine.run_webtransport(wt_port, &cert_path, &key_path) => {
result?;
}
}
} else {
tracing::info!("Socket.IO HTTP server listening on {}", addr);
socket_server.run_http(addr).await?;
}
Ok::<(), Box<dyn std::error::Error>>(())
})?;
Ok(())
}
/// Build the [`LocalBroadcastFn`] handed to the Redis/NATS adapters.
///
/// The adapters invoke this callback both for packets emitted on this node and
/// for packets that arrive from other nodes over the message bus. Each packet is
/// delivered to sockets attached to this process via `emit_local_filtered`.
///
/// `namespaces` starts out empty, because the adapter exists before the socket
/// server does. It is populated once the server is built. A packet seen before
/// that point, or one addressed to an unknown namespace, is dropped with a warning.
fn make_local_broadcast_fn(
    namespaces: Arc<OnceLock<Arc<imks::socket::namespace::NamespaceManager>>>,
) -> LocalBroadcastFn {
    Arc::new(move |packet, opts| {
        let manager = match namespaces.get() {
            Some(manager) => manager,
            None => {
                tracing::warn!(namespace = %packet.namespace, "Namespace manager not initialized");
                return;
            }
        };
        match manager.get_namespace(&packet.namespace) {
            Some(ns) => {
                ns.emit_local_filtered(packet, opts);
            }
            None => {
                tracing::warn!(
                    namespace = %packet.namespace,
                    "Namespace not found for local broadcast"
                );
            }
        }
    })
}