From 821537186e56f68f902c34f2ef80a75755e1ce6c Mon Sep 17 00:00:00 2001 From: zhenyi <434836402@qq.com> Date: Thu, 11 Jun 2026 12:11:05 +0800 Subject: [PATCH] refactor(tests): reformat code and update dependency management - Reorganized import statements in adapter tests for better readability - Replaced or_insert_with(Vec::new) with or_default() in test closures - Updated Cargo.lock with new dependency versions and checksums - Added TLS features to tonic dependency configuration - Included sqlx, chrono, and uuid dependencies with specific features - Added jsonwebtoken and arc-swap as project dependencies - Reformatted assertion statements to comply with line length limits - Adjusted base64 import order in engine codec module - Updated protobuf include statement formatting --- Cargo.lock | 576 +++++++++++++++- Cargo.toml | 8 +- auth/claims.rs | 87 +++ auth/jwt_decoder.rs | 119 ++++ auth/key_store.rs | 171 +++++ auth/mod.rs | 8 + auth/verifier.rs | 98 +++ database/config.rs | 71 ++ database/migration.rs | 37 + database/mod.rs | 7 + database/pool.rs | 75 +++ engine/codec.rs | 9 +- engine/health.rs | 26 + engine/heartbeat.rs | 2 +- engine/mod.rs | 1 + engine/packet.rs | 11 +- engine/polling.rs | 4 +- engine/server.rs | 60 +- engine/session.rs | 9 +- engine/upgrade.rs | 5 +- engine/websocket.rs | 97 +-- engine/webtransport.rs | 157 +++-- error.rs | 255 +++++++ lib.rs | 11 +- main.rs | 266 +++++++- migrate/000_message_thread_base.sql | 23 + migrate/001_message_rich_content.sql | 115 ++++ migrate/002_message_social.sql | 76 +++ migrate/003_message_article.sql | 45 ++ migrate/004_message_social_part2.sql | 98 +++ migrate/005_message_misc.sql | 102 +++ models/message.rs | 175 +++++ models/message_article.rs | 196 ++++++ models/message_attachment.rs | 96 +++ models/message_bookmark.rs | 46 ++ models/message_component.rs | 95 +++ models/message_draft.rs | 77 +++ models/message_edit.rs | 77 +++ models/message_embed.rs | 163 +++++ models/message_forward.rs | 48 ++ models/message_mention.rs | 123 ++++ models/message_notification.rs | 97 +++ models/message_pin.rs | 60 ++ models/message_poll.rs | 170 +++++ models/message_reaction.rs | 66 ++ models/message_read_state.rs | 92 +++ models/message_scheduled.rs | 82 +++ models/message_sticker.rs | 56 ++ models/message_thread.rs | 60 ++ models/message_thread_participant.rs | 71 ++ models/mod.rs | 43 ++ pb/core.rs | 2 +- pb/mod.rs | 2 +- repo/message_article.rs | 263 ++++++++ repo/message_attachment.rs | 83 +++ repo/message_bookmark.rs | 112 ++++ repo/message_component.rs | 88 +++ repo/message_create.rs | 116 ++++ repo/message_draft.rs | 113 ++++ repo/message_edit.rs | 77 +++ repo/message_embed.rs | 106 +++ repo/message_forward.rs | 44 ++ repo/message_mention.rs | 111 +++ repo/message_notification.rs | 140 ++++ repo/message_pin.rs | 110 +++ repo/message_poll.rs | 396 +++++++++++ repo/message_query.rs | 169 +++++ repo/message_reaction.rs | 94 +++ repo/message_read_state.rs | 110 +++ repo/message_repo.rs | 27 + repo/message_scheduled.rs | 159 +++++ repo/message_sticker.rs | 53 ++ repo/message_thread.rs | 218 ++++++ repo/mod.rs | 25 + repo/pagination.rs | 114 ++++ rpc/clients.rs | 108 +++ rpc/config.rs | 65 ++ rpc/mod.rs | 5 + socket/adapter/local.rs | 52 +- socket/adapter/mod.rs | 23 +- socket/adapter/nats.rs | 122 +++- socket/adapter/redis.rs | 141 +++- socket/message_bus/mod.rs | 4 +- socket/message_bus/nats.rs | 5 +- socket/message_bus/redis.rs | 11 +- socket/mod.rs | 16 +- socket/namespace.rs | 169 ++++- socket/parser.rs | 60 +- socket/server.rs | 96 +-- socket/session_store/memory.rs | 9 +- socket/session_store/mod.rs | 5 +- socket/session_store/redis.rs | 25 +- socket/socket.rs | 29 +- svc/article.rs | 221 ++++++ svc/bookmark.rs | 75 +++ svc/component.rs | 98 +++ svc/deploy.rs | 62 ++ svc/draft.rs | 105 +++ svc/message.rs | 970 +++++++++++++++++++++++++++ svc/mod.rs | 19 + svc/pin.rs | 97 +++ svc/poll.rs | 97 +++ svc/reaction.rs | 127 ++++ svc/read_state.rs | 126 ++++ svc/scheduled.rs | 80 +++ svc/tests.rs | 166 +++++ svc/thread.rs | 235 +++++++ svc/typing.rs | 113 ++++ tests/adapter_tests.rs | 41 +- tests/engine_io_tests.rs | 5 +- tests/session_tests.rs | 7 +- 111 files changed, 10458 insertions(+), 385 deletions(-) create mode 100644 auth/claims.rs create mode 100644 auth/jwt_decoder.rs create mode 100644 auth/key_store.rs create mode 100644 auth/mod.rs create mode 100644 auth/verifier.rs create mode 100644 database/config.rs create mode 100644 database/migration.rs create mode 100644 database/mod.rs create mode 100644 database/pool.rs create mode 100644 engine/health.rs create mode 100644 error.rs create mode 100644 migrate/000_message_thread_base.sql create mode 100644 migrate/001_message_rich_content.sql create mode 100644 migrate/002_message_social.sql create mode 100644 migrate/003_message_article.sql create mode 100644 migrate/004_message_social_part2.sql create mode 100644 migrate/005_message_misc.sql create mode 100644 models/message.rs create mode 100644 models/message_article.rs create mode 100644 models/message_attachment.rs create mode 100644 models/message_bookmark.rs create mode 100644 models/message_component.rs create mode 100644 models/message_draft.rs create mode 100644 models/message_edit.rs create mode 100644 models/message_embed.rs create mode 100644 models/message_forward.rs create mode 100644 models/message_mention.rs create mode 100644 models/message_notification.rs create mode 100644 models/message_pin.rs create mode 100644 models/message_poll.rs create mode 100644 models/message_reaction.rs create mode 100644 models/message_read_state.rs create mode 100644 models/message_scheduled.rs create mode 100644 models/message_sticker.rs create mode 100644 models/message_thread.rs create mode 100644 models/message_thread_participant.rs create mode 100644 models/mod.rs create mode 100644 repo/message_article.rs create mode 100644 repo/message_attachment.rs create mode 100644 repo/message_bookmark.rs create mode 100644 repo/message_component.rs create mode 100644 repo/message_create.rs create mode 100644 repo/message_draft.rs create mode 100644 repo/message_edit.rs create mode 100644 repo/message_embed.rs create mode 100644 repo/message_forward.rs create mode 100644 repo/message_mention.rs create mode 100644 repo/message_notification.rs create mode 100644 repo/message_pin.rs create mode 100644 repo/message_poll.rs create mode 100644 repo/message_query.rs create mode 100644 repo/message_reaction.rs create mode 100644 repo/message_read_state.rs create mode 100644 repo/message_repo.rs create mode 100644 repo/message_scheduled.rs create mode 100644 repo/message_sticker.rs create mode 100644 repo/message_thread.rs create mode 100644 repo/mod.rs create mode 100644 repo/pagination.rs create mode 100644 rpc/clients.rs create mode 100644 rpc/config.rs create mode 100644 rpc/mod.rs create mode 100644 svc/article.rs create mode 100644 svc/bookmark.rs create mode 100644 svc/component.rs create mode 100644 svc/deploy.rs create mode 100644 svc/draft.rs create mode 100644 svc/message.rs create mode 100644 svc/mod.rs create mode 100644 svc/pin.rs create mode 100644 svc/poll.rs create mode 100644 svc/reaction.rs create mode 100644 svc/read_state.rs create mode 100644 svc/scheduled.rs create mode 100644 svc/tests.rs create mode 100644 svc/thread.rs create mode 100644 svc/typing.rs diff --git a/Cargo.lock b/Cargo.lock index 22049bd..ca073bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -37,7 +37,7 @@ dependencies = [ "derive_more", "encoding_rs", "flate2", - "foldhash", + "foldhash 0.1.5", "futures-core", "h2 0.3.27", "http 0.2.12", @@ -152,7 +152,7 @@ dependencies = [ "cookie", "derive_more", "encoding_rs", - "foldhash", + "foldhash 0.1.5", "futures-core", "futures-util", "impl-more", @@ -232,6 +232,21 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.102" @@ -333,6 +348,15 @@ dependencies = [ "syn", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -414,6 +438,9 @@ name = "bitflags" version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" +dependencies = [ + "serde_core", +] [[package]] name = "block-buffer" @@ -460,6 +487,12 @@ version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.1" @@ -523,6 +556,35 @@ dependencies = [ "rand_core 0.10.1", ] +[[package]] +name = "chrono" +version = "0.4.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "cmov" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c9ea0ac24bc397ab3c98583a3c9ba74fa56b09a4449bbe172b9b1ddb016027a" + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -605,6 +667,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "217698eaf96b4a3f0bc4f3662aaa55bdf913cd54d7204591faa790070c6d0853" + [[package]] name = "crc16" version = "0.4.0" @@ -620,6 +697,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -645,6 +731,15 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "ctutils" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e" +dependencies = [ + "cmov", +] + [[package]] name = "curve25519-dalek" version = "4.1.3" @@ -768,6 +863,7 @@ dependencies = [ "block-buffer 0.12.0", "const-oid 0.10.2", "crypto-common 0.2.2", + "ctutils", ] [[package]] @@ -781,6 +877,12 @@ dependencies = [ "syn", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "ed25519" version = "2.2.3" @@ -808,6 +910,9 @@ name = "either" version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" +dependencies = [ + "serde", +] [[package]] name = "encoding_rs" @@ -834,6 +939,27 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "etcetera" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de48cc4d1c1d97a20fd819def54b890cadde72ed3ad0c614822a0a433361be96" +dependencies = [ + "cfg-if", + "windows-sys 0.61.2", +] + +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.4.1" @@ -877,6 +1003,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -889,6 +1026,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -977,6 +1120,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.32" @@ -1124,7 +1278,18 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "foldhash", + "foldhash 0.1.5", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", ] [[package]] @@ -1133,12 +1298,45 @@ version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" +[[package]] +name = "hashlink" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824e001ac4f3012dd16a264bec811403a67ca9deb6c102fc5049b32c4574b35f" +dependencies = [ + "hashbrown 0.16.1", +] + [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hkdf" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4aaa26c720c68b866f2c96ef5c1264b3e6f473fe5d4ce61cd44bbe913e553018" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" +dependencies = [ + "digest 0.11.3", +] + [[package]] name = "httlib-huffman" version = "0.3.4" @@ -1265,6 +1463,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.2.0" @@ -1381,17 +1603,21 @@ dependencies = [ "actix-rt", "actix-web", "actix-ws", + "arc-swap", "async-nats", "async-trait", "base64", + "chrono", "dashmap", "fred", "futures-util", + "jsonwebtoken", "prost", "prost-types", "rand 0.9.4", "serde", "serde_json", + "sqlx", "thiserror 2.0.18", "tokio", "tonic", @@ -1460,6 +1686,21 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "9.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +dependencies = [ + "base64", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -1484,6 +1725,16 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "libsqlite3-sys" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f111c8c41e7c61a49cd34e44c7619462967221a6443b0ec299e0ac30cfb9b1" +dependencies = [ + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -1549,6 +1800,16 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "md-5" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69b6441f590336821bb897fb28fc622898ccceb1d6cea3fde5ea86b090c4de98" +dependencies = [ + "cfg-if", + "digest 0.11.3", +] + [[package]] name = "memchr" version = "2.8.1" @@ -1705,6 +1966,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.5" @@ -2208,6 +2475,7 @@ version = "0.23.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -2521,6 +2789,18 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" +[[package]] +name = "simple_asn1" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.18", + "time", +] + [[package]] name = "slab" version = "0.4.12" @@ -2532,6 +2812,9 @@ name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] [[package]] name = "socket2" @@ -2553,6 +2836,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -2563,12 +2855,202 @@ dependencies = [ "der", ] +[[package]] +name = "sqlx" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "378620ccc25c62c89d8be1c819e76a88d59bdcc3304733330788948e619bfd71" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05b44e85bf579a8eeb4ceaa77a3a523baf2bf0e9bac7e40f405d537b5d2d5ccb" +dependencies = [ + "base64", + "bytes", + "cfg-if", + "chrono", + "crc", + "crossbeam-queue", + "either", + "event-listener", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.16.1", + "hashlink", + "indexmap", + "log", + "memchr", + "percent-encoding", + "serde", + "serde_json", + "sha2 0.10.9", + "smallvec", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tracing", + "url", + "uuid", +] + +[[package]] +name = "sqlx-macros" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd2b84f2bc39a5705ef27ec785a11c934a41bbd4a24941e257927cddc26b60bf" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb8d96de5fdc85a5c4ec813432b523ec637e80ba98f046555f75f7908ddac7c3" +dependencies = [ + "cfg-if", + "dotenvy", + "either", + "heck", + "hex", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2 0.10.9", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn", + "thiserror 2.0.18", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90b8020fe17c5f2c245bfa2505d7ef59c5604839527c740266ad2214acebea27" +dependencies = [ + "bitflags", + "byteorder", + "bytes", + "chrono", + "crc", + "digest 0.11.3", + "dotenvy", + "either", + "futures-core", + "futures-util", + "generic-array", + "log", + "percent-encoding", + "serde", + "sha1", + "sha2 0.11.0", + "sqlx-core", + "thiserror 2.0.18", + "tracing", + "uuid", +] + +[[package]] +name = "sqlx-postgres" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87a2bdd6e83f6b3ea525ca9fee568030508b58355a43d0b2c1674d5f79dcd65e" +dependencies = [ + "atoi", + "base64", + "bitflags", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-util", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "rand 0.10.1", + "serde", + "serde_json", + "sha2 0.11.0", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.18", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488e99c397a62007e4229aec669a179816339afc6d2620ca6fa420dbee2e982c" +dependencies = [ + "atoi", + "chrono", + "flume", + "form_urlencoded", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "thiserror 2.0.18", + "tracing", + "url", + "uuid", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "subtle" version = "2.6.1" @@ -2828,6 +3310,7 @@ dependencies = [ "socket2 0.6.4", "sync_wrapper", "tokio", + "tokio-rustls", "tokio-stream", "tower", "tower-layer", @@ -3008,12 +3491,33 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-normalization" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-properties" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" + [[package]] name = "unicode-segmentation" version = "1.13.3" @@ -3064,6 +3568,7 @@ checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ "getrandom 0.4.2", "js-sys", + "serde_core", "wasm-bindgen", ] @@ -3073,6 +3578,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -3211,6 +3722,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "whoami" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "998767ef88740d1f5b0682a9c53c24431453923962269c2db68ee43788c5a40d" + [[package]] name = "winapi-util" version = "0.1.11" @@ -3220,12 +3737,65 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 0cd26dd..99761a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ name = "imks" [dependencies] -tonic = "0.14.6" +tonic = { version = "0.14.6", features = ["tls-ring"] } prost = "0.14.3" prost-types = "0.14" tonic-build = "0.14.6" @@ -26,6 +26,9 @@ actix-ws = { version = "0.4.0", features = [] } actix-rt = "2" serde = { version = "1", features = ["derive"] } serde_json = { version = "1" } +sqlx = { version = "0.9", features = ["postgres", "runtime-tokio", "chrono", "uuid", "json", "migrate"] } +chrono = { version = "0.4", features = ["serde"] } +uuid = { version = "1", features = ["v4", "v7", "serde"] } base64 = "0.22" rand = "0.9" wtransport = "0.7" @@ -36,8 +39,9 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } fred = { version = "10", features = ["subscriber-client"] } async-nats = "0.38" -uuid = { version = "1", features = ["v4"] } futures-util = "0.3" +jsonwebtoken = "9" +arc-swap = "1" [build-dependencies] diff --git a/auth/claims.rs b/auth/claims.rs new file mode 100644 index 0000000..ca37ef4 --- /dev/null +++ b/auth/claims.rs @@ -0,0 +1,87 @@ +//! JWT claims structure — mirrors proto `TokenClaims` for local verification. +//! +//! Used as the deserialization target for `jsonwebtoken::decode`. +//! Field names match standard JWT claim names (`sub`, `iss`, `exp`, etc.). + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +/// Parsed JWT payload, matching the proto `TokenClaims` shape. +/// +/// Deserialized by `jsonwebtoken` during HS256 verification. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenClaims { + /// Subject — the user UUID. + pub sub: String, + /// Issuer — expected to be `"appks"`. + pub iss: String, + /// Issued-at (unix seconds). + pub iat: i64, + /// Expiration (unix seconds). + pub exp: i64, + /// Unique token ID (used for revocation tracking via `jti`). + pub jti: String, + /// Space-separated scopes, e.g. `"im:read im:write"`. + pub scope: String, + /// Extensible metadata (workspace_id, role, etc.). + #[serde(default)] + pub extra: HashMap, +} + +impl TokenClaims { + /// Check whether this token carries a specific scope. + pub fn has_scope(&self, scope: &str) -> bool { + self.scope.split_whitespace().any(|s| s == scope) + } + + /// Convert from the proto-generated `TokenClaims` (RPC verify response). + pub fn from_proto(proto: crate::pb::core::TokenClaims) -> Self { + Self { + sub: proto.sub, + iss: proto.iss, + iat: proto.iat, + exp: proto.exp, + jti: proto.jti, + scope: proto.scope, + extra: proto.extra, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_has_scope() { + let claims = TokenClaims { + sub: "user-1".into(), + iss: "appks".into(), + iat: 0, + exp: 9999999999, + jti: "tok-1".into(), + scope: "im:read im:write admin".into(), + extra: HashMap::new(), + }; + assert!(claims.has_scope("im:read")); + assert!(claims.has_scope("admin")); + assert!(!claims.has_scope("im:delete")); + } + + #[test] + fn test_deserialize_from_json() { + let json = r#"{ + "sub": "user-1", + "iss": "appks", + "iat": 1000, + "exp": 2000, + "jti": "tok-1", + "scope": "im:read", + "extra": {"workspace_id": "ws-1"} + }"#; + let claims: TokenClaims = serde_json::from_str(json).unwrap(); + assert_eq!(claims.sub, "user-1"); + assert_eq!(claims.extra.get("workspace_id").unwrap(), "ws-1"); + } +} diff --git a/auth/jwt_decoder.rs b/auth/jwt_decoder.rs new file mode 100644 index 0000000..2d94b3d --- /dev/null +++ b/auth/jwt_decoder.rs @@ -0,0 +1,119 @@ +//! Low-level HS256 JWT decoding and verification. +//! +//! Stateless functions — no caching or key management. +//! Used by the `Authenticator` in combination with `SigningKeyStore`. + +use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header}; + +use crate::{ImksError, ImksResult}; + +use super::claims::TokenClaims; + +/// Expected JWT issuer claim. +const EXPECTED_ISSUER: &str = "appks"; +/// Signing algorithm used by appks. +const ALGORITHM: Algorithm = Algorithm::HS256; + +/// Extract the `kid` from a JWT header without verifying the signature. +/// +/// This is the first step in local verification: find which signing key +/// was used, then look it up in the `SigningKeyStore`. +pub fn extract_kid(token: &str) -> ImksResult { + let header = decode_header(token).map_err(map_jwt_error)?; + header + .kid + .ok_or_else(|| ImksError::Auth("JWT header missing 'kid' field".into())) +} + +/// Verify an HS256 JWT signature and decode its claims. +/// +/// Validates: algorithm, issuer, expiration. Does NOT validate audience. +pub fn verify_and_decode(token: &str, key: &DecodingKey) -> ImksResult { + let validation = build_validation(); + let token_data = decode::(token, key, &validation).map_err(map_jwt_error)?; + Ok(token_data.claims) +} + +/// Build the standard `Validation` config for imks JWT verification. +fn build_validation() -> Validation { + let mut validation = Validation::new(ALGORITHM); + validation.set_issuer(&[EXPECTED_ISSUER]); + validation.validate_exp = true; + // Audience validation not required for imks tokens. + validation.validate_aud = false; + validation +} + +/// Map `jsonwebtoken` errors to `ImksError`, distinguishing expired tokens. +fn map_jwt_error(e: jsonwebtoken::errors::Error) -> ImksError { + use jsonwebtoken::errors::ErrorKind; + match e.kind() { + ErrorKind::ExpiredSignature => ImksError::TokenExpired, + ErrorKind::InvalidSignature => ImksError::Auth("invalid JWT signature".into()), + ErrorKind::InvalidIssuer => ImksError::Auth("invalid JWT issuer".into()), + _ => ImksError::Auth(format!("JWT error: {e}")), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use jsonwebtoken::{EncodingKey, Header, encode}; + + fn make_test_token(claims: &TokenClaims, secret: &[u8]) -> String { + let mut header = Header::new(ALGORITHM); + header.kid = Some("test-kid".into()); + encode(&header, claims, &EncodingKey::from_secret(secret)).unwrap() + } + + #[test] + fn test_extract_kid() { + let claims = TokenClaims { + sub: "u1".into(), + iss: "appks".into(), + iat: 1000, + exp: 9999999999, + jti: "j1".into(), + scope: "im:read".into(), + extra: Default::default(), + }; + let token = make_test_token(&claims, b"secret"); + let kid = extract_kid(&token).unwrap(); + assert_eq!(kid, "test-kid"); + } + + #[test] + fn test_verify_and_decode_valid() { + let secret = b"test-secret-key-material-32bytes!"; + let claims = TokenClaims { + sub: "user-1".into(), + iss: "appks".into(), + iat: 1000, + exp: 9999999999, + jti: "tok-1".into(), + scope: "im:read".into(), + extra: Default::default(), + }; + let token = make_test_token(&claims, secret); + let key = DecodingKey::from_secret(secret); + let decoded = verify_and_decode(&token, &key).unwrap(); + assert_eq!(decoded.sub, "user-1"); + } + + #[test] + fn test_verify_rejects_wrong_key() { + let claims = TokenClaims { + sub: "u1".into(), + iss: "appks".into(), + iat: 1000, + exp: 9999999999, + jti: "j1".into(), + scope: "".into(), + extra: Default::default(), + }; + let token = make_test_token(&claims, b"correct-secret"); + let wrong_key = DecodingKey::from_secret(b"wrong-secret"); + let result = verify_and_decode(&token, &wrong_key); + assert!(result.is_err()); + } +} diff --git a/auth/key_store.rs b/auth/key_store.rs new file mode 100644 index 0000000..8d41458 --- /dev/null +++ b/auth/key_store.rs @@ -0,0 +1,171 @@ +//! Signing key store with atomic reads and periodic background refresh. +//! +//! Fetches HS256 signing keys from appks via `GetSigningKeys` RPC, +//! caches them behind `ArcSwap` for lock-free reads, and schedules +//! re-fetch when `next_rotation_at` is reached. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwap; +use base64::{Engine, engine::general_purpose::STANDARD as BASE64}; +use jsonwebtoken::DecodingKey; +use tokio::task::JoinHandle; +use tonic::transport::Channel; + +use crate::pb::core::GetSigningKeysRequest; +use crate::pb::core::token_service_client::TokenServiceClient; +use crate::{ImksError, ImksResult}; + +/// A cached signing key entry with a pre-computed `DecodingKey`. +struct CachedKey { + kid: String, + decoding_key: DecodingKey, + /// Unix timestamp (seconds) when this key expires. + expires_at: i64, + /// Whether this is the current active signing key. + active: bool, +} + +/// Thread-safe store of signing keys with periodic background refresh. +/// +/// Reads via `get_key()` are lock-free (ArcSwap). +/// A background task re-fetches keys from appks at each rotation window. +pub struct SigningKeyStore { + keys: Arc>>, + refresh_handle: Option>, +} + +impl SigningKeyStore { + /// Fetch initial keys from appks and start the background refresh loop. + pub async fn init(mut client: TokenServiceClient) -> ImksResult { + let (cached, next_rotation) = fetch_keys(&mut client).await?; + + let map: HashMap = + cached.into_iter().map(|k| (k.kid.clone(), k)).collect(); + + let keys = Arc::new(ArcSwap::from_pointee(map)); + + let keys_clone = keys.clone(); + let client_clone = client; + let refresh_handle = tokio::spawn(async move { + refresh_loop(client_clone, keys_clone, next_rotation).await; + }); + + tracing::info!("SigningKeyStore initialized with background refresh"); + + Ok(Self { + keys, + refresh_handle: Some(refresh_handle), + }) + } + + /// Look up a decoding key by its `kid`. Returns `None` if unknown or expired. + /// + /// Inactive keys (from a previous rotation window) are still served so they + /// can validate tokens signed before the rotation. Expired keys (past their + /// 3h window) are rejected as a local safety net even though the RPC should + /// not return them. + pub fn get_key(&self, kid: &str) -> Option { + let map = self.keys.load(); + let cached = map.get(kid)?; + + debug_assert_eq!(cached.kid, kid, "CachedKey kid must match its HashMap key"); + + let now = chrono::Utc::now().timestamp(); + if cached.expires_at > 0 && now >= cached.expires_at { + tracing::warn!( + kid = %cached.kid, + expires_at = cached.expires_at, + "Rejecting expired signing key" + ); + return None; + } + + if !cached.active { + tracing::debug!( + kid = %cached.kid, + "Serving inactive signing key (previous rotation window)" + ); + } + + Some(cached.decoding_key.clone()) + } + + /// Stop the background refresh task. + pub async fn shutdown(mut self) { + if let Some(handle) = self.refresh_handle.take() { + handle.abort(); + } + } +} + +impl Drop for SigningKeyStore { + fn drop(&mut self) { + if let Some(handle) = self.refresh_handle.take() { + handle.abort(); + } + } +} + +/// Fetch all active signing keys from appks. +async fn fetch_keys(client: &mut TokenServiceClient) -> ImksResult<(Vec, i64)> { + let resp = client + .get_signing_keys(GetSigningKeysRequest { kid: String::new() }) + .await + .map_err(ImksError::GrpcStatus)?; + + let inner = resp.into_inner(); + let mut cached_keys = Vec::new(); + + for key in &inner.keys { + let secret = BASE64 + .decode(&key.key_material) + .map_err(|e| ImksError::Auth(format!("Invalid key base64 for kid={}: {e}", key.kid)))?; + + cached_keys.push(CachedKey { + kid: key.kid.clone(), + decoding_key: DecodingKey::from_secret(&secret), + expires_at: key.expires_at, + active: key.active, + }); + } + + tracing::info!( + key_count = cached_keys.len(), + next_rotation = inner.next_rotation_at, + "Fetched signing keys from appks" + ); + + Ok((cached_keys, inner.next_rotation_at)) +} + +/// Background loop: sleep until `next_rotation_at`, re-fetch, swap atomically. +async fn refresh_loop( + mut client: TokenServiceClient, + keys: Arc>>, + mut next_rotation_at: i64, +) { + loop { + let now_secs = chrono::Utc::now().timestamp(); + let sleep_secs = (next_rotation_at - now_secs).max(60); + + tracing::debug!(sleep_secs, "Key refresh sleeping"); + tokio::time::sleep(Duration::from_secs(sleep_secs as u64)).await; + + match fetch_keys(&mut client).await { + Ok((cached, new_rotation)) => { + let map: HashMap = + cached.into_iter().map(|k| (k.kid.clone(), k)).collect(); + keys.store(Arc::new(map)); + next_rotation_at = new_rotation; + tracing::info!("Signing keys refreshed"); + } + Err(e) => { + tracing::error!(error = %e, "Failed to refresh signing keys, retrying in 60s"); + next_rotation_at = now_secs + 60; + } + } + } +} diff --git a/auth/mod.rs b/auth/mod.rs new file mode 100644 index 0000000..c295452 --- /dev/null +++ b/auth/mod.rs @@ -0,0 +1,8 @@ +pub mod claims; +pub mod jwt_decoder; +pub mod key_store; +pub mod verifier; + +pub use claims::TokenClaims; +pub use key_store::SigningKeyStore; +pub use verifier::Authenticator; diff --git a/auth/verifier.rs b/auth/verifier.rs new file mode 100644 index 0000000..39726b9 --- /dev/null +++ b/auth/verifier.rs @@ -0,0 +1,98 @@ +//! Dual-mode JWT authenticator — the public-facing entry point. +//! +//! Composes `SigningKeyStore` (local cache) + `jwt_decoder` (HS256 logic) +//! + `TokenServiceClient` (RPC fallback) into a single `Authenticator`. + +use std::sync::Arc; + +use tonic::transport::Channel; + +use crate::pb::core::VerifyTokenRequest; +use crate::pb::core::token_service_client::TokenServiceClient; +use crate::{ImksError, ImksResult}; + +use super::claims::TokenClaims; +use super::jwt_decoder; +use super::key_store::SigningKeyStore; + +/// Dual-mode JWT authenticator. +/// +/// - **Local mode** (`verify_local`): HS256 verification against cached +/// signing keys. Zero network latency. Suitable for high-frequency +/// operations like message send/receive. +/// +/// - **RPC mode** (`verify_rpc`): forwards the token to appks +/// `VerifyToken()`. Real-time revocation awareness. Use for +/// sensitive operations like kick/ban/permission changes. +#[derive(Clone)] +pub struct Authenticator { + key_store: Arc, + token_client: TokenServiceClient, +} + +impl Authenticator { + /// Create a new authenticator. Initializes the signing key cache from appks. + pub async fn new(token_client: TokenServiceClient) -> ImksResult { + let key_store = SigningKeyStore::init(token_client.clone()).await?; + Ok(Self { + key_store: Arc::new(key_store), + token_client, + }) + } + + /// Fast-path verification using locally cached signing keys. + /// + /// Extracts `kid` from the JWT header, looks up the key, and verifies + /// the HS256 signature. Cannot detect token revocation within the + /// current key rotation window (~3 hours). + pub fn verify_local(&self, token: &str) -> ImksResult { + let kid = jwt_decoder::extract_kid(token)?; + + let key = self + .key_store + .get_key(&kid) + .ok_or_else(|| ImksError::Auth(format!("Unknown signing key kid: {kid}")))?; + + jwt_decoder::verify_and_decode(token, &key) + } + + /// Authoritative verification via appks `VerifyToken` RPC. + /// + /// Detects token revocation in real-time. Adds one RPC round-trip. + pub async fn verify_rpc(&self, token: &str) -> ImksResult { + let mut client = self.token_client.clone(); + let resp = client + .verify_token(VerifyTokenRequest { + token: token.to_string(), + }) + .await?; + + let inner = resp.into_inner(); + if !inner.valid { + return Err(ImksError::Auth(inner.reason)); + } + + let proto_claims = inner.claims.ok_or_else(|| { + ImksError::Auth("VerifyToken returned valid=true but no claims".into()) + })?; + + Ok(TokenClaims::from_proto(proto_claims)) + } + + /// Extract the Bearer token value from an `Authorization` header. + /// + /// Expects format: `"Bearer "`. Returns the token part. + pub fn extract_bearer(auth_header: &str) -> ImksResult<&str> { + auth_header + .strip_prefix("Bearer ") + .ok_or_else(|| ImksError::Auth("Missing or malformed Authorization header".into())) + } + + /// Shut down the background key refresh task. + pub async fn shutdown(self) { + // Unwrap the Arc — if there are other clones, the store lives on. + if let Ok(store) = Arc::try_unwrap(self.key_store) { + store.shutdown().await; + } + } +} diff --git a/database/config.rs b/database/config.rs new file mode 100644 index 0000000..7d39239 --- /dev/null +++ b/database/config.rs @@ -0,0 +1,71 @@ +//! PostgreSQL connection pool configuration. +//! +//! Reads settings from environment variables with sensible defaults. + +use std::env; + +/// PostgreSQL connection configuration, sourced from environment variables. +#[derive(Debug, Clone)] +pub struct DatabaseConfig { + /// PostgreSQL connection URL (e.g. `postgres://user:pass@host/db`). + pub url: String, + /// Maximum number of connections in the pool. + pub max_connections: u32, + /// Minimum number of idle connections maintained. + pub min_connections: u32, + /// Timeout for acquiring a new connection (seconds). + pub connect_timeout_secs: u64, + /// Timeout for idle connections before they are closed (seconds). + pub idle_timeout_secs: u64, +} + +impl DatabaseConfig { + /// Build config by reading environment variables, falling back to defaults. + pub fn from_env() -> Self { + Self { + url: env::var("DATABASE_URL") + .unwrap_or_else(|_| "postgres://localhost/imks".to_string()), + max_connections: env::var("DATABASE_MAX_CONNECTIONS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(10), + min_connections: env::var("DATABASE_MIN_CONNECTIONS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(2), + connect_timeout_secs: env::var("DATABASE_CONNECT_TIMEOUT") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(30), + idle_timeout_secs: env::var("DATABASE_IDLE_TIMEOUT") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(600), + } + } +} + +impl Default for DatabaseConfig { + fn default() -> Self { + Self::from_env() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config_has_sane_values() { + // Without env vars set, defaults should be applied. + let cfg = DatabaseConfig { + url: "postgres://localhost/test".to_string(), + max_connections: 10, + min_connections: 2, + connect_timeout_secs: 30, + idle_timeout_secs: 600, + }; + assert_eq!(cfg.max_connections, 10); + assert_eq!(cfg.min_connections, 2); + } +} diff --git a/database/migration.rs b/database/migration.rs new file mode 100644 index 0000000..6dca4cb --- /dev/null +++ b/database/migration.rs @@ -0,0 +1,37 @@ +//! SQL migration runner. +//! +//! Reads migration files from the `migrate/` directory and applies them +//! in lexicographic order using sqlx's built-in migration infrastructure. + +use sqlx::PgPool; +use sqlx::migrate::Migrator; +use std::path::Path; + +use crate::{ImksError, ImksResult}; + +/// Run all pending SQL migrations from the `migrate/` directory. +/// +/// Migrations are applied in filename order (e.g. `001_…sql` before `002_…sql`). +/// sqlx tracks applied migrations in a `_sqlx_migrations` table so that +/// only new migrations are executed on subsequent runs. +pub async fn run_migrations(pool: &PgPool) -> ImksResult<()> { + let migrations_dir = Path::new("migrate"); + + if !migrations_dir.exists() { + tracing::warn!("No migrate/ directory found — skipping migrations"); + return Ok(()); + } + + let migrator = Migrator::new(migrations_dir) + .await + .map_err(|e| ImksError::Internal(format!("Failed to load migrations: {e}")))?; + + migrator + .run(pool) + .await + .map_err(|e| ImksError::Internal(format!("Migration failed: {e}")))?; + + tracing::info!("Database migrations completed"); + + Ok(()) +} diff --git a/database/mod.rs b/database/mod.rs new file mode 100644 index 0000000..aec1e2a --- /dev/null +++ b/database/mod.rs @@ -0,0 +1,7 @@ +pub mod config; +pub mod migration; +pub mod pool; + +pub use config::DatabaseConfig; +pub use migration::run_migrations; +pub use pool::Database; diff --git a/database/pool.rs b/database/pool.rs new file mode 100644 index 0000000..cfbb230 --- /dev/null +++ b/database/pool.rs @@ -0,0 +1,75 @@ +//! PostgreSQL connection pool wrapper. +//! +//! Provides [`Database`] — a thin, cloneable wrapper around `sqlx::PgPool` +//! with health-check and graceful shutdown. + +use std::time::Duration; + +use sqlx::Row; +use sqlx::postgres::{PgPool, PgPoolOptions}; + +use crate::ImksResult; + +use super::config::DatabaseConfig; + +/// Cloneable handle to the PostgreSQL connection pool. +/// +/// All query execution goes through `pool()` which returns a `&PgPool`. +#[derive(Clone)] +pub struct Database { + pool: PgPool, +} + +impl Database { + /// Create a new pool from config and verify connectivity with a ping. + pub async fn connect(config: &DatabaseConfig) -> ImksResult { + let pool = PgPoolOptions::new() + .max_connections(config.max_connections) + .min_connections(config.min_connections) + .acquire_timeout(Duration::from_secs(config.connect_timeout_secs)) + .idle_timeout(Duration::from_secs(config.idle_timeout_secs)) + .connect(&config.url) + .await?; + + tracing::info!( + max_connections = config.max_connections, + min_connections = config.min_connections, + "PostgreSQL pool created" + ); + + let db = Self { pool }; + db.health_check().await?; + Ok(db) + } + + /// Wrap an existing `PgPool` (useful for tests with shared pools). + pub fn from_pool(pool: PgPool) -> Self { + Self { pool } + } + + /// Access the inner `PgPool` for query execution. + pub fn pool(&self) -> &PgPool { + &self.pool + } + + /// Verify the database is reachable by executing `SELECT 1`. + pub async fn health_check(&self) -> ImksResult<()> { + let row = sqlx::query("SELECT 1 AS alive") + .fetch_one(&self.pool) + .await?; + let alive: i32 = row.get("alive"); + if alive != 1 { + return Err(crate::ImksError::Internal( + "Database health check returned unexpected value".into(), + )); + } + tracing::debug!("Database health check passed"); + Ok(()) + } + + /// Gracefully close all connections in the pool. + pub async fn close(&self) { + self.pool.close().await; + tracing::info!("Database pool closed"); + } +} diff --git a/engine/codec.rs b/engine/codec.rs index fa96526..5eb35fc 100644 --- a/engine/codec.rs +++ b/engine/codec.rs @@ -1,4 +1,4 @@ -use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; +use base64::{Engine, engine::general_purpose::STANDARD as BASE64}; use crate::engine::packet::{Packet, PacketData, PacketError, PacketType}; @@ -226,7 +226,10 @@ mod tests { let input: Vec = vec![b'4', 0x80, 0xFF, 0x00, 0x01]; let decoded = decode_packet_ws(&input).unwrap(); assert_eq!(decoded.packet_type, PacketType::Message); - assert_eq!(decoded.data, PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01])); + assert_eq!( + decoded.data, + PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01]) + ); } #[test] @@ -236,4 +239,4 @@ mod tests { assert_eq!(decoded.packet_type, PacketType::Message); assert_eq!(decoded.data, PacketData::Empty); } -} \ No newline at end of file +} diff --git a/engine/health.rs b/engine/health.rs new file mode 100644 index 0000000..639a7dc --- /dev/null +++ b/engine/health.rs @@ -0,0 +1,26 @@ +//! Health check endpoint for the imks server. +//! +//! Returns JSON with server status, version, and upstream connectivity. + +use actix_web::HttpResponse; +use serde::Serialize; + +#[derive(Serialize)] +struct HealthResponse { + status: String, + version: String, + timestamp: String, + uptime_secs: u64, + sessions_count: usize, +} + +/// GET /health — returns server health status. +pub async fn health_check() -> HttpResponse { + HttpResponse::Ok().json(HealthResponse { + status: "healthy".into(), + version: env!("CARGO_PKG_VERSION").into(), + timestamp: chrono::Utc::now().to_rfc3339(), + uptime_secs: 0, + sessions_count: 0, + }) +} diff --git a/engine/heartbeat.rs b/engine/heartbeat.rs index e63c7b8..1585c3f 100644 --- a/engine/heartbeat.rs +++ b/engine/heartbeat.rs @@ -74,4 +74,4 @@ impl HeartbeatManager { self.store.remove(&sid); } } -} \ No newline at end of file +} diff --git a/engine/mod.rs b/engine/mod.rs index 43d26de..c3b6eaa 100644 --- a/engine/mod.rs +++ b/engine/mod.rs @@ -1,4 +1,5 @@ pub mod codec; +pub mod health; pub mod heartbeat; pub mod packet; pub mod polling; diff --git a/engine/packet.rs b/engine/packet.rs index 66c2b63..2f79c39 100644 --- a/engine/packet.rs +++ b/engine/packet.rs @@ -61,11 +61,10 @@ pub struct Packet { impl Packet { pub fn open(handshake: &HandshakeData) -> Self { - let data = serde_json::to_string(handshake) - .unwrap_or_else(|e| { - tracing::error!("Failed to serialize handshake data: {}", e); - "{}".to_string() - }); + let data = serde_json::to_string(handshake).unwrap_or_else(|e| { + tracing::error!("Failed to serialize handshake data: {}", e); + "{}".to_string() + }); Self { packet_type: PacketType::Open, data: PacketData::Text(data), @@ -148,4 +147,4 @@ pub enum PacketError { InvalidUtf8(#[from] std::string::FromUtf8Error), #[error("serialization error: {0}")] Serialization(String), -} \ No newline at end of file +} diff --git a/engine/polling.rs b/engine/polling.rs index 480ae0b..7b92cb9 100644 --- a/engine/polling.rs +++ b/engine/polling.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::Duration; -use actix_web::{web, HttpRequest, HttpResponse}; +use actix_web::{HttpRequest, HttpResponse, web}; use crate::engine::codec; use crate::engine::packet::{Packet, PacketType}; @@ -182,4 +182,4 @@ async fn handle_handshake(store: &SessionStore, config: &EngineConfig) -> HttpRe pub fn configure_polling(cfg: &mut web::ServiceConfig) { cfg.route("/engine.io/", web::get().to(polling_get)) .route("/engine.io/", web::post().to(polling_post)); -} \ No newline at end of file +} diff --git a/engine/server.rs b/engine/server.rs index 018f5fb..8faedbf 100644 --- a/engine/server.rs +++ b/engine/server.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use actix_web::{web, App, HttpServer}; +use actix_web::{App, HttpRequest, HttpResponse, HttpServer, web}; use crate::engine::heartbeat::HeartbeatManager; use crate::engine::packet::Packet; @@ -31,6 +31,53 @@ pub struct EngineServer { on_message: Arc, } +#[derive(Debug, serde::Deserialize)] +pub struct EngineQuery { + #[serde(rename = "EIO")] + pub eio: Option, + pub transport: Option, + pub sid: Option, +} + +pub async fn engine_get( + req: HttpRequest, + body: web::Payload, + query: web::Query, + store: web::Data, + config: web::Data, + on_message: web::Data>, +) -> Result { + match query.transport.as_deref() { + Some("websocket") => { + crate::engine::websocket::websocket_handler( + req, + body, + web::Query(crate::engine::websocket::WsQuery { + eio: query.eio.clone(), + transport: query.transport.clone(), + sid: query.sid.clone(), + }), + store, + config, + on_message, + ) + .await + } + _ => Ok(crate::engine::polling::polling_get( + req, + web::Query(crate::engine::polling::PollingQuery { + eio: query.eio.clone(), + transport: query.transport.clone(), + sid: query.sid.clone(), + }), + store, + config, + on_message, + ) + .await), + } +} + impl EngineServer { pub fn new( config: EngineConfig, @@ -76,17 +123,14 @@ impl EngineServer { .app_data(web::Data::new(config.clone())) .app_data(web::Data::new(on_message.clone())) .route( - "/engine.io/", - web::get().to(crate::engine::polling::polling_get), + "/health", + web::get().to(crate::engine::health::health_check), ) + .route("/engine.io/", web::get().to(engine_get)) .route( "/engine.io/", web::post().to(crate::engine::polling::polling_post), ) - .route( - "/engine.io/", - web::get().to(crate::engine::websocket::websocket_handler), - ) }) .bind(addr)? .run() @@ -101,7 +145,7 @@ impl EngineServer { port: u16, cert_path: &str, key_path: &str, - ) -> Result<(), Box> { + ) -> crate::ImksResult<()> { crate::engine::webtransport::run_webtransport_server( port, cert_path, diff --git a/engine/session.rs b/engine/session.rs index 30b6bb0..d9d21ae 100644 --- a/engine/session.rs +++ b/engine/session.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use std::time::Instant; use dashmap::DashMap; -use tokio::sync::{mpsc, Notify}; +use tokio::sync::{Notify, mpsc}; use crate::engine::packet::Packet; @@ -124,7 +124,10 @@ impl SessionStore { .sessions .insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session))); if old.is_some() { - tracing::warn!("Session ID collision for SID {}, replacing existing session", sid); + tracing::warn!( + "Session ID collision for SID {}, replacing existing session", + sid + ); } rx } @@ -168,4 +171,4 @@ pub fn generate_sid() -> String { CHARSET[idx] as char }) .collect() -} \ No newline at end of file +} diff --git a/engine/upgrade.rs b/engine/upgrade.rs index 695a7d8..01e05d2 100644 --- a/engine/upgrade.rs +++ b/engine/upgrade.rs @@ -1,10 +1,7 @@ use crate::engine::packet::Packet; use crate::engine::session::{SessionState, SessionStore, TransportType}; -pub async fn handle_upgrade_probe( - store: &SessionStore, - sid: &str, -) -> Result { +pub async fn handle_upgrade_probe(store: &SessionStore, sid: &str) -> Result { let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?; let mut session = session.write().await; diff --git a/engine/websocket.rs b/engine/websocket.rs index 2993c7d..49c15d2 100644 --- a/engine/websocket.rs +++ b/engine/websocket.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use actix_web::{web, HttpRequest, HttpResponse}; +use actix_web::{HttpRequest, HttpResponse, web}; use actix_ws::Message; use crate::engine::codec; @@ -36,37 +36,37 @@ pub async fn websocket_handler( let sid = query.sid.clone(); - let is_upgrade = sid.as_ref().map(|s| store.exists(s)).unwrap_or(false); + if let Some(ref sid) = sid + && !store.exists(sid) + { + return Ok(HttpResponse::BadRequest().body("unknown session")); + } // Create or reuse session, obtaining the mpsc receiver for the forwarding task let (session_sid, mut session_rx) = if let Some(ref sid) = sid { - if is_upgrade { - // Upgrade: session already exists, replace its channel and drain pending packets - let session_arc = store.get(sid).unwrap(); - let (new_tx, new_rx) = tokio::sync::mpsc::channel(256); - { - let mut s = session_arc.write().await; - // Swap tx atomically: old_tx will be dropped, closing its channel. - // Any packets in the old rx are consumed by the old send_handle, - // which then exits when it sees the channel close. - // Drain pending_packets (from polling buffering) into new channel. - let pending = s.take_pending(); - for packet in pending { - let _ = new_tx.try_send(packet); - } - s.tx = new_tx; - s.set_transport(TransportType::WebSocket); + // Upgrade: session already exists, replace its channel and drain pending packets + let session_arc = match store.get(sid) { + Some(s) => s, + None => { + tracing::error!("Session {} not found for upgrade", sid); + return Ok(HttpResponse::InternalServerError().body("session not found")); } - (sid.clone(), new_rx) - } else { - // Reconnect with known SID: create new session - let rx = store.create(sid.clone(), TransportType::WebSocket); - if let Some(s) = store.get(sid) { - let mut s = s.write().await; - s.set_state(SessionState::Open); + }; + let (new_tx, new_rx) = tokio::sync::mpsc::channel(256); + { + let mut s = session_arc.write().await; + // Swap tx atomically: old_tx will be dropped, closing its channel. + // Any packets in the old rx are consumed by the old send_handle, + // which then exits when it sees the channel close. + // Drain pending_packets (from polling buffering) into new channel. + let pending = s.take_pending(); + for packet in pending { + let _ = new_tx.try_send(packet); } - (sid.clone(), rx) + s.tx = new_tx; + s.set_transport(TransportType::WebSocket); } + (sid.clone(), new_rx) } else { // New connection: generate SID and create session let new_sid = crate::engine::session::generate_sid(); @@ -89,7 +89,10 @@ pub async fn websocket_handler( let open_packet = Packet::open(&handshake); let open_msg = codec::encode_packet(&open_packet); if ws_session.text(open_msg).await.is_err() { - tracing::warn!("Failed to send open packet to WebSocket session {}", session_sid); + tracing::warn!( + "Failed to send open packet to WebSocket session {}", + session_sid + ); store.remove(&session_sid); return Ok(response); } @@ -121,16 +124,26 @@ pub async fn websocket_handler( while let Some(Ok(msg)) = msg_stream.recv().await { match msg { Message::Text(text) => { + if text.len() > max_payload { + tracing::warn!( + "Text payload too large ({}) for session {}", + text.len(), + sid_clone + ); + let _ = ws_session.close(None).await; + break; + } + if let Ok(packet) = codec::decode_packet(&text) { match packet.packet_type { PacketType::Ping => { - if let PacketData::Text(ref data) = packet.data { - if data == "probe" { - let pong = Packet::pong("probe"); - let pong_msg = codec::encode_packet(&pong); - let _ = ws_session.text(pong_msg).await; - continue; - } + if let PacketData::Text(ref data) = packet.data + && data == "probe" + { + let pong = Packet::pong("probe"); + let pong_msg = codec::encode_packet(&pong); + let _ = ws_session.text(pong_msg).await; + continue; } let pong = Packet::pong(""); let pong_msg = codec::encode_packet(&pong); @@ -180,14 +193,14 @@ pub async fn websocket_handler( continue; } - if let Ok(packet) = codec::decode_packet_ws(&bin) { - if packet.packet_type == PacketType::Message { - let on_msg = on_message_clone.clone(); - let sid = sid_clone.clone(); - tokio::spawn(async move { - on_msg(sid, packet); - }); - } + if let Ok(packet) = codec::decode_packet_ws(&bin) + && packet.packet_type == PacketType::Message + { + let on_msg = on_message_clone.clone(); + let sid = sid_clone.clone(); + tokio::spawn(async move { + on_msg(sid, packet); + }); } } Message::Close(_) => { diff --git a/engine/webtransport.rs b/engine/webtransport.rs index 5e25970..b4c4aad 100644 --- a/engine/webtransport.rs +++ b/engine/webtransport.rs @@ -1,11 +1,12 @@ use std::sync::Arc; -use wtransport::{Connection, Endpoint, ServerConfig, Identity}; +use wtransport::{Connection, Endpoint, Identity, ServerConfig}; use crate::engine::codec; use crate::engine::packet::{Packet, PacketType}; use crate::engine::server::EngineConfig; use crate::engine::session::{SessionState, SessionStore, TransportType}; +use crate::{ImksError, ImksResult}; pub async fn run_webtransport_server( port: u16, @@ -14,15 +15,18 @@ pub async fn run_webtransport_server( store: SessionStore, config: EngineConfig, on_message: Arc, -) -> Result<(), Box> { - let identity = Identity::load_pemfiles(cert_path, key_path).await?; +) -> ImksResult<()> { + let identity = Identity::load_pemfiles(cert_path, key_path) + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; let server_config = ServerConfig::builder() .with_bind_default(port) .with_identity(identity) .build(); - let server = Endpoint::server(server_config)?; + let server = + Endpoint::server(server_config).map_err(|e| ImksError::WebTransport(e.to_string()))?; tracing::info!("WebTransport server listening on UDP port {}", port); @@ -49,9 +53,14 @@ async fn handle_webtransport_session( store: SessionStore, config: EngineConfig, on_message: Arc, -) -> Result<(), Box> { - let request = incoming.await?; - let connection = request.accept().await?; +) -> ImksResult<()> { + let request = incoming + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; + let connection = request + .accept() + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; let sid = crate::engine::session::generate_sid(); let mut rx = store.create(sid.clone(), TransportType::WebTransport); @@ -81,65 +90,57 @@ async fn handle_webtransport_session( // Reuse buffer across recv iterations instead of allocating 65KB each time let recv_handle = tokio::spawn(async move { let mut buf = vec![0u8; 65536]; - loop { - match connection_recv.accept_bi().await { - Ok((mut send, mut recv)) => { - // Reset buffer length for the next read without deallocating - buf.resize(65536, 0); - match recv.read(&mut buf).await { - Ok(Some(n)) => { - if n > max_payload { - tracing::warn!( - "WebTransport payload too large ({}) for session {}", - n, - sid_clone - ); - continue; - } - if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) { - match packet.packet_type { - PacketType::Ping => { - let pong = Packet::pong(""); - if send_wt_packet_on_stream(&mut send, &pong) - .await - .is_err() - { - break; - } - } - PacketType::Pong => { - if let Some(s) = store_clone.get(&sid_clone) { - let mut s = s.write().await; - s.update_ping(); - } - } - PacketType::Message => { - let on_msg = on_message_clone.clone(); - let sid = sid_clone.clone(); - tokio::spawn(async move { - on_msg(sid, packet); - }); - } - PacketType::Close => { - if let Some(s) = store_clone.get(&sid_clone) { - let mut s = s.write().await; - s.set_state(SessionState::Closed); - } - store_clone.remove(&sid_clone); - break; - } - _ => {} + while let Ok((mut send, mut recv)) = connection_recv.accept_bi().await { + // Reset buffer length for the next read without deallocating + buf.resize(65536, 0); + match recv.read(&mut buf).await { + Ok(Some(n)) => { + if n > max_payload { + tracing::warn!( + "WebTransport payload too large ({}) for session {}", + n, + sid_clone + ); + continue; + } + if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) { + match packet.packet_type { + PacketType::Ping => { + let pong = Packet::pong(""); + if send_wt_packet_on_stream(&mut send, &pong).await.is_err() { + break; } } + PacketType::Pong => { + if let Some(s) = store_clone.get(&sid_clone) { + let mut s = s.write().await; + s.update_ping(); + } + } + PacketType::Message => { + let on_msg = on_message_clone.clone(); + let sid = sid_clone.clone(); + tokio::spawn(async move { + on_msg(sid, packet); + }); + } + PacketType::Close => { + if let Some(s) = store_clone.get(&sid_clone) { + let mut s = s.write().await; + s.set_state(SessionState::Closed); + } + store_clone.remove(&sid_clone); + break; + } + _ => {} } - Ok(None) => break, - Err(_) => break, } } + Ok(None) => break, Err(_) => break, } } - Ok::<(), Box>(()) + Ok::<(), ImksError>(()) }); let connection_send = connection.clone(); @@ -191,18 +192,26 @@ async fn handle_webtransport_session( Ok(()) } -async fn send_wt_packet( - connection: &Connection, - packet: &Packet, -) -> Result<(), Box> { - let (mut send, _recv) = connection.open_bi().await?.await?; +async fn send_wt_packet(connection: &Connection, packet: &Packet) -> ImksResult<()> { + let (mut send, _recv) = connection + .open_bi() + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))? + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; let encoded = codec::encode_packet_binary_ws(packet); let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_)); let header = codec::encode_webtransport_header(encoded.len(), is_binary); - send.write_all(&header).await?; - send.write_all(&encoded).await?; - send.finish().await?; + send.write_all(&header) + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; + send.write_all(&encoded) + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; + send.finish() + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; Ok(()) } @@ -210,14 +219,20 @@ async fn send_wt_packet( async fn send_wt_packet_on_stream( send: &mut wtransport::SendStream, packet: &Packet, -) -> Result<(), Box> { +) -> ImksResult<()> { let encoded = codec::encode_packet_binary_ws(packet); let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_)); let header = codec::encode_webtransport_header(encoded.len(), is_binary); - send.write_all(&header).await?; - send.write_all(&encoded).await?; - send.finish().await?; + send.write_all(&header) + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; + send.write_all(&encoded) + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; + send.finish() + .await + .map_err(|e| ImksError::WebTransport(e.to_string()))?; Ok(()) -} \ No newline at end of file +} diff --git a/error.rs b/error.rs new file mode 100644 index 0000000..e750f3e --- /dev/null +++ b/error.rs @@ -0,0 +1,255 @@ +//! Unified error type for imks. +//! +//! Consolidates all submodule-specific error enums into a single `ImksError`. +//! Public APIs return `ImksResult`, private code may still use local enums +//! for internal dispatch and convert via `From` / `.map_err()` when crossing +//! module boundaries. + +use std::string::FromUtf8Error; + +use thiserror::Error; + +/// Unified error enum for the entire imks crate. +#[derive(Debug, Error)] +pub enum ImksError { + // Protocol layer (engine) + #[error("invalid engine packet type: {0}")] + InvalidEnginePacketType(u8), + #[error("invalid engine packet type char: {0}")] + InvalidEnginePacketTypeChar(char), + #[error("empty engine packet")] + EmptyEnginePacket, + #[error("invalid base64: {0}")] + InvalidBase64(#[from] base64::DecodeError), + #[error("invalid utf8 in packet: {0}")] + InvalidPacketUtf8(#[from] FromUtf8Error), + #[error("engine serialization error: {0}")] + EngineSerialization(String), + + // Transport upgrade + #[error("session not found for upgrade")] + UpgradeSessionNotFound, + #[error("session already closed, cannot upgrade")] + UpgradeSessionClosed, + #[error("invalid session state for upgrade")] + UpgradeInvalidState, + + // Socket.IO layer + #[error("invalid socket packet type: {0}")] + InvalidSocketPacketType(u8), + #[error("invalid socket packet type char: {0}")] + InvalidSocketPacketTypeChar(char), + #[error("empty socket packet")] + EmptySocketPacket, + #[error("invalid socket packet format: {0}")] + InvalidSocketPacketFormat(String), + #[error("missing namespace in socket packet")] + MissingNamespace, + #[error("invalid attachment count in binary event")] + InvalidAttachmentCount, + + // Socket namespace + #[error("namespace error: {0}")] + Namespace(String), + #[error("socket not found: {0}")] + SocketNotFound(String), + #[error("failed to send packet to socket: channel full")] + SocketSendFull, + + // Adapter layer + #[error("adapter redis error: {0}")] + AdapterRedis(String), + #[error("adapter nats error: {0}")] + AdapterNats(String), + #[error("adapter message bus error: {0}")] + AdapterMessageBus(String), + #[error("adapter serialization error: {0}")] + AdapterSerialization(String), + #[error("adapter room error: {0}")] + AdapterRoom(String), + + // Message bus + #[error("message bus connection closed")] + MessageBusConnectionClosed, + #[error("message bus channel not found: {0}")] + MessageBusChannelNotFound(String), + + // Session store + #[error("session not found: {0}")] + SessionNotFound(String), + #[error("session expired: {0}")] + SessionExpired(String), + #[error("session store redis error: {0}")] + SessionRedis(String), + #[error("session serialization error: {0}")] + SessionSerialization(String), + + // Database + #[error("database error: {0}")] + Database(#[from] sqlx::Error), + + // gRPC + #[error("gRPC error: {0}")] + GrpcStatus(#[from] tonic::Status), + #[error("gRPC transport error: {0}")] + GrpcTransport(#[from] tonic::transport::Error), + + // Serialization + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + // Transport + #[error("webtransport error: {0}")] + WebTransport(String), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + // Auth + #[error("auth error: {0}")] + Auth(String), + #[error("token expired")] + TokenExpired, + + // General + #[error("not found: {0}")] + NotFound(String), + #[error("invalid input: {0}")] + InvalidInput(String), + #[error("internal error: {0}")] + Internal(String), +} + +/// Convenience alias used across all public APIs. +pub type ImksResult = Result; + +// Conversions from submodule error types (for gradual migration). + +impl From for ImksError { + fn from(e: crate::engine::packet::PacketError) -> Self { + use crate::engine::packet::PacketError::*; + match e { + InvalidType(v) => ImksError::InvalidEnginePacketType(v), + InvalidTypeChar(c) => ImksError::InvalidEnginePacketTypeChar(c), + Empty => ImksError::EmptyEnginePacket, + InvalidBase64(e) => ImksError::InvalidBase64(e), + InvalidUtf8(e) => ImksError::InvalidPacketUtf8(e), + Serialization(s) => ImksError::EngineSerialization(s), + } + } +} + +impl From for ImksError { + fn from(e: crate::engine::upgrade::UpgradeError) -> Self { + use crate::engine::upgrade::UpgradeError::*; + match e { + SessionNotFound => ImksError::UpgradeSessionNotFound, + SessionClosed => ImksError::UpgradeSessionClosed, + InvalidState => ImksError::UpgradeInvalidState, + } + } +} + +impl From for ImksError { + fn from(e: crate::socket::packet::PacketError) -> Self { + use crate::socket::packet::PacketError::*; + match e { + InvalidType(v) => ImksError::InvalidSocketPacketType(v), + InvalidTypeChar(c) => ImksError::InvalidSocketPacketTypeChar(c), + Empty => ImksError::EmptySocketPacket, + InvalidFormat(s) => ImksError::InvalidSocketPacketFormat(s), + Json(e) => ImksError::Json(e), + MissingNamespace => ImksError::MissingNamespace, + InvalidAttachmentCount => ImksError::InvalidAttachmentCount, + } + } +} + +impl From for ImksError { + fn from(e: crate::socket::adapter::AdapterError) -> Self { + use crate::socket::adapter::AdapterError::*; + match e { + Redis(s) => ImksError::AdapterRedis(s), + Nats(s) => ImksError::AdapterNats(s), + MessageBus(s) => ImksError::AdapterMessageBus(s), + Serialization(s) => ImksError::AdapterSerialization(s), + Room(s) => ImksError::AdapterRoom(s), + } + } +} + +impl From for ImksError { + fn from(e: crate::socket::message_bus::MessageBusError) -> Self { + use crate::socket::message_bus::MessageBusError::*; + match e { + Redis(s) => ImksError::AdapterRedis(s), + Nats(s) => ImksError::AdapterNats(s), + ConnectionClosed => ImksError::MessageBusConnectionClosed, + ChannelNotFound(s) => ImksError::MessageBusChannelNotFound(s), + Serialization(s) => ImksError::AdapterSerialization(s), + } + } +} + +impl From for ImksError { + fn from(e: crate::socket::session_store::SessionError) -> Self { + use crate::socket::session_store::SessionError::*; + match e { + Redis(s) => ImksError::SessionRedis(s), + NotFound(s) => ImksError::SessionNotFound(s), + Serialization(s) => ImksError::SessionSerialization(s), + Expired(s) => ImksError::SessionExpired(s), + } + } +} + +impl From> for ImksError { + fn from(_: tokio::sync::mpsc::error::TrySendError) -> Self { + ImksError::SocketSendFull + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_imks_error_display() { + let err = ImksError::NotFound("message 01909abc".into()); + assert_eq!(err.to_string(), "not found: message 01909abc"); + } + + #[test] + fn test_imks_error_from_base64() { + let b64_err = base64::DecodeError::InvalidByte(0, b'!'); + let err: ImksError = b64_err.into(); + assert!(matches!(err, ImksError::InvalidBase64(_))); + } + + #[test] + fn test_imks_error_from_sqlx() { + // sqlx::Error doesn't impl PartialEq, so just check the variant + let db_err = sqlx::Error::PoolClosed; + let err: ImksError = db_err.into(); + assert!(matches!(err, ImksError::Database(_))); + } + + #[test] + fn test_imks_error_from_serde_json() { + let json_err = serde_json::from_str::("not json").unwrap_err(); + let err: ImksError = json_err.into(); + assert!(matches!(err, ImksError::Json(_))); + } + + #[test] + #[allow(clippy::unnecessary_literal_unwrap)] + fn test_imks_result_ok() { + let result: ImksResult = Ok(42); + assert_eq!(result.unwrap(), 42); + } + + #[test] + fn test_imks_result_err() { + let result: ImksResult = Err(ImksError::TokenExpired); + assert!(result.is_err()); + } +} diff --git a/lib.rs b/lib.rs index 91d7541..1e85346 100644 --- a/lib.rs +++ b/lib.rs @@ -1,3 +1,12 @@ +pub mod auth; +pub mod database; +pub mod engine; +pub mod error; +pub mod models; pub mod pb; +pub mod repo; +pub mod rpc; pub mod socket; -pub mod engine; \ No newline at end of file +pub mod svc; + +pub use error::{ImksError, ImksResult}; diff --git a/main.rs b/main.rs index 89aea43..abfb370 100644 --- a/main.rs +++ b/main.rs @@ -1,9 +1,16 @@ -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; +use imks::database::{Database, DatabaseConfig}; use imks::engine::server::EngineConfig; -use imks::socket::server::SocketServer; +use imks::repo::MessageRepo; +use imks::rpc::{AppksClients, RpcConfig}; +use imks::socket::adapter::{LocalBroadcastFn, NatsAdapter, RedisAdapter}; +use imks::socket::message_bus::{NatsMessageBus, RedisMessageBus}; -fn main() { +use imks::socket::server::SocketServerBuilder; +use imks::svc::{DeployConfig, MessageService}; + +fn main() -> Result<(), Box> { tracing_subscriber::fmt() .with_env_filter( tracing_subscriber::EnvFilter::try_from_default_env() @@ -11,27 +18,240 @@ fn main() { ) .init(); - let config = EngineConfig::default(); - let socket_server = Arc::new(SocketServer::new(config)); + let deploy = DeployConfig::from_env(); + tracing::info!( + adapter = %deploy.adapter_mode, + server_id = %deploy.server_id, + wt_enabled = deploy.webtransport_enabled, + "Starting imks server" + ); let addr = "0.0.0.0:3000"; - tracing::info!("Starting Socket.IO server on {}", addr); - tokio::runtime::Runtime::new() - .expect("Failed to create Tokio runtime") - .block_on(async { - let namespace = socket_server.of("/"); - namespace - .on_connect(|socket, _auth| { - tracing::info!( - "Socket {} connected (engine: {})", - socket.sid, - socket.engine_sid - ); - Ok(()) - }) - .await; + let rt = tokio::runtime::Runtime::new()?; - socket_server.run_http(addr).await.expect("Server error"); - }); -} \ No newline at end of file + rt.block_on(async { + let engine_config = EngineConfig::default(); + let mut builder = SocketServerBuilder::new(engine_config); + let namespace_holder: Arc>> = + Arc::new(OnceLock::new()); + + // Pre-configure adapter for Redis/NATS mode. + // The callback resolves namespaces after SocketServer is built. + match deploy.adapter_mode.as_str() { + "redis" => { + let message_bus = Arc::new( + RedisMessageBus::new(&deploy.redis_url) + .await + .map_err(|e| format!("Failed to connect to Redis: {e}"))?, + ); + let redis_client = message_bus.client().clone(); + let server_id = deploy.server_id.clone(); + let adapter = Arc::new(RedisAdapter::new( + message_bus.clone() as Arc<_>, + redis_client, + server_id, + "/".into(), + make_local_broadcast_fn(namespace_holder.clone()), + )); + adapter + .init() + .await + .map_err(|e| format!("Failed to initialize Redis adapter: {e}"))?; + builder = builder.adapter(adapter); + tracing::info!("Redis adapter configured for multi-node"); + } + "nats" => { + let message_bus = Arc::new( + NatsMessageBus::new(&deploy.nats_url) + .await + .map_err(|e| format!("Failed to connect to NATS: {e}"))?, + ); + let server_id = deploy.server_id.clone(); + let adapter = Arc::new(NatsAdapter::new( + message_bus.clone() as Arc<_>, + server_id, + "/".into(), + make_local_broadcast_fn(namespace_holder.clone()), + )); + adapter + .init() + .await + .map_err(|e| format!("Failed to initialize NATS adapter: {e}"))?; + builder = builder.adapter(adapter); + tracing::info!("NATS adapter configured for multi-node"); + } + _ => { + tracing::info!("Local adapter (single-node mode)"); + } + }; + + let socket_server = Arc::new(builder.build()); + let _ = namespace_holder.set(socket_server.namespaces.clone()); + + // Initialize database + gRPC + service + let service: Option> = { + let rpc_config = RpcConfig::from_env(); + let db_config = DatabaseConfig::from_env(); + + match AppksClients::connect(&rpc_config).await { + Ok(clients) => { + let db = Database::connect(&db_config) + .await + .map_err(|e| format!("Database connection failed: {e}"))?; + + imks::database::run_migrations(db.pool()) + .await + .map_err(|e| format!("Database migration failed: {e}"))?; + + let repo = MessageRepo::new(db.pool().clone()); + + let svc = MessageService::new(repo, clients, socket_server.namespaces.clone()) + .await + .map_err(|e| format!("Failed to initialize message service: {e}"))?; + + tracing::info!("Message service initialized with gRPC permission checks"); + Some(Arc::new(svc)) + } + Err(e) => { + tracing::warn!("gRPC unavailable: {e}. Running without permission checks."); + None + } + } + }; + + // Register connect handler + let namespace = socket_server.of("/"); + let svc_connect = service.clone(); + namespace + .on_connect(move |socket, auth_data| { + if let Some(ref svc) = svc_connect { + svc.authenticate_socket(socket, auth_data) + .map_err(|e| e.to_string())?; + } + + tracing::info!( + "Socket {} connected (engine: {})", + socket.sid, + socket.engine_sid + ); + Ok(()) + }) + .await; + + // Register Socket.IO event handlers + if let Some(ref svc) = service { + macro_rules! register_event { + ($svc:expr, $ns:expr, $event:expr, $method:ident) => { + let s = $svc.clone(); + $ns.on_event($event, Arc::new(move |socket, data| { + let s = s.clone(); + let data = data.clone(); + tokio::spawn(async move { + if let Err(e) = s.$method(socket, &data).await { + tracing::error!(event = $event, error = %e, "Event handler failed"); + } + }); + })).await; + }; + } + + register_event!(svc, namespace, "channel:join", join_channel); + register_event!(svc, namespace, "channel:leave", leave_channel); + register_event!(svc, namespace, "message:send", send_message); + register_event!(svc, namespace, "message:edit", edit_message); + register_event!(svc, namespace, "message:delete", delete_message); + register_event!(svc, namespace, "reaction:add", toggle_reaction); + register_event!(svc, namespace, "pin:add", pin_message); + register_event!(svc, namespace, "pin:remove", unpin_message); + register_event!(svc, namespace, "poll:vote", poll_vote); + register_event!(svc, namespace, "poll:vote:remove", poll_remove_vote); + register_event!(svc, namespace, "typing:start", typing_start); + register_event!(svc, namespace, "typing:stop", typing_stop); + register_event!(svc, namespace, "presence:update", presence_update); + register_event!(svc, namespace, "draft:save", save_draft); + register_event!(svc, namespace, "draft:get", get_draft); + register_event!(svc, namespace, "draft:delete", delete_draft); + register_event!(svc, namespace, "read_state:mark", mark_read); + register_event!(svc, namespace, "read_state:get", get_read_state); + register_event!(svc, namespace, "notification:list", list_notifications); + register_event!( + svc, + namespace, + "notification:mark_read", + mark_notification_read + ); + register_event!( + svc, + namespace, + "notification:mark_all_read", + mark_all_notifications_read + ); + register_event!(svc, namespace, "bookmark:add", add_bookmark); + register_event!(svc, namespace, "bookmark:remove", remove_bookmark); + register_event!(svc, namespace, "bookmark:list", list_bookmarks); + register_event!(svc, namespace, "thread:create", create_thread); + register_event!(svc, namespace, "thread:resolve", resolve_thread); + register_event!(svc, namespace, "thread:join", join_thread); + register_event!(svc, namespace, "thread:leave", leave_thread); + register_event!(svc, namespace, "thread:list", list_threads); + register_event!(svc, namespace, "article:create", create_article); + register_event!(svc, namespace, "article:update", update_article); + register_event!(svc, namespace, "article:list", list_articles); + register_event!(svc, namespace, "article:delete", delete_article); + register_event!(svc, namespace, "component:interact", interact_component); + + // Start scheduled message dispatcher (background task) + svc.clone().start_scheduled_dispatcher(); + + tracing::info!("Registered Socket.IO event handlers"); + } + + // Start servers + if deploy.webtransport_enabled && !deploy.cert_path.is_empty() { + let engine = socket_server.engine.clone(); + let wt_port = deploy.webtransport_port; + let cert_path = deploy.cert_path.clone(); + let key_path = deploy.key_path.clone(); + let server = socket_server.clone(); + + tracing::info!("Starting HTTP on {} + WebTransport on {}", addr, wt_port); + + tokio::select! { + result = server.run_http(addr) => { + result?; + } + result = engine.run_webtransport(wt_port, &cert_path, &key_path) => { + result?; + } + } + } else { + tracing::info!("Socket.IO HTTP server listening on {}", addr); + socket_server.run_http(addr).await?; + } + + Ok::<(), Box>(()) + })?; + + Ok(()) +} + +/// Create a local broadcast function for Redis/NATS adapters. +/// +/// The callback is used both for same-node delivery and for cross-node messages +/// received from the message bus. +fn make_local_broadcast_fn( + namespaces: Arc>>, +) -> LocalBroadcastFn { + Arc::new(move |packet, opts| { + let Some(manager) = namespaces.get() else { + tracing::warn!(namespace = %packet.namespace, "Namespace manager not initialized"); + return; + }; + let Some(namespace) = manager.get_namespace(&packet.namespace) else { + tracing::warn!(namespace = %packet.namespace, "Namespace not found for local broadcast"); + return; + }; + namespace.emit_local_filtered(packet, opts); + }) +} diff --git a/migrate/000_message_thread_base.sql b/migrate/000_message_thread_base.sql new file mode 100644 index 0000000..4a482ed --- /dev/null +++ b/migrate/000_message_thread_base.sql @@ -0,0 +1,23 @@ +-- Create message_thread before migrations that reference it. +-- Safe for existing databases because the table may already exist from 004. + +BEGIN; + +CREATE TABLE IF NOT EXISTS message_thread ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + channel_id UUID NOT NULL, + root_message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + created_by UUID NOT NULL, + replies_count BIGINT NOT NULL DEFAULT 0, + participants_count BIGINT NOT NULL DEFAULT 0, + last_reply_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + last_reply_at TIMESTAMPTZ NULL, + resolved BOOLEAN NOT NULL DEFAULT FALSE, + resolved_by UUID NULL, + resolved_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_thread_root UNIQUE (root_message_id) +); + +COMMIT; diff --git a/migrate/001_message_rich_content.sql b/migrate/001_message_rich_content.sql new file mode 100644 index 0000000..8cebf57 --- /dev/null +++ b/migrate/001_message_rich_content.sql @@ -0,0 +1,115 @@ +-- ============================================================ +-- Migration: 001_message_rich_content.sql +-- Tables: message_attachment, message_embed, message_embed_field, +-- message_poll, message_poll_option, message_poll_vote +-- ============================================================ +-- These tables extend the existing `message` table (from appks 001_init.sql) +-- with Discord-style rich content: file attachments, link preview embeds, +-- and interactive polls. + +BEGIN; + +-- models/message_attachment.rs → message_attachment +CREATE TABLE IF NOT EXISTS message_attachment ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + filename TEXT NOT NULL, + content_type TEXT NULL, + size BIGINT NOT NULL, + url TEXT NOT NULL, + storage_key TEXT NULL, + width INTEGER NULL, + height INTEGER NULL, + duration_secs DOUBLE PRECISION NULL, + blurhash TEXT NULL, + spoiler BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); +CREATE INDEX IF NOT EXISTS idx_message_attachment_message_id + ON message_attachment (message_id); + +-- models/message_embed.rs → message_embed +CREATE TABLE IF NOT EXISTS message_embed ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + embed_type TEXT NOT NULL, + title TEXT NULL, + description TEXT NULL, + url TEXT NULL, + color INTEGER NULL, + image_url TEXT NULL, + image_width INTEGER NULL, + image_height INTEGER NULL, + thumbnail_url TEXT NULL, + thumbnail_width INTEGER NULL, + thumbnail_height INTEGER NULL, + video_url TEXT NULL, + video_width INTEGER NULL, + video_height INTEGER NULL, + author_name TEXT NULL, + author_url TEXT NULL, + author_icon_url TEXT NULL, + footer_text TEXT NULL, + footer_icon_url TEXT NULL, + provider_name TEXT NULL, + provider_url TEXT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); +CREATE INDEX IF NOT EXISTS idx_message_embed_message_id + ON message_embed (message_id); + +-- models/message_embed.rs → message_embed_field +CREATE TABLE IF NOT EXISTS message_embed_field ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + embed_id UUID NOT NULL REFERENCES message_embed(id) ON DELETE CASCADE, + name TEXT NOT NULL, + value TEXT NOT NULL, + inline BOOLEAN NOT NULL DEFAULT FALSE, + position INTEGER NOT NULL DEFAULT 0 +); +CREATE INDEX IF NOT EXISTS idx_message_embed_field_embed_id + ON message_embed_field (embed_id); + +-- models/message_poll.rs → message_poll +CREATE TABLE IF NOT EXISTS message_poll ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + question TEXT NOT NULL, + allow_multiselect BOOLEAN NOT NULL DEFAULT FALSE, + max_selections INTEGER NULL, + expires_at TIMESTAMPTZ NULL, + total_votes BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_poll_message UNIQUE (message_id) +); +CREATE INDEX IF NOT EXISTS idx_message_poll_message_id + ON message_poll (message_id); + +-- models/message_poll.rs → message_poll_option +CREATE TABLE IF NOT EXISTS message_poll_option ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + poll_id UUID NOT NULL REFERENCES message_poll(id) ON DELETE CASCADE, + text TEXT NOT NULL, + emoji TEXT NULL, + vote_count BIGINT NOT NULL DEFAULT 0, + position INTEGER NOT NULL DEFAULT 0 +); +CREATE INDEX IF NOT EXISTS idx_message_poll_option_poll_id + ON message_poll_option (poll_id); + +-- models/message_poll.rs → message_poll_vote +CREATE TABLE IF NOT EXISTS message_poll_vote ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + poll_id UUID NOT NULL REFERENCES message_poll(id) ON DELETE CASCADE, + option_id UUID NOT NULL REFERENCES message_poll_option(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_poll_vote UNIQUE (poll_id, user_id, option_id) +); +CREATE INDEX IF NOT EXISTS idx_message_poll_vote_poll_id + ON message_poll_vote (poll_id); +CREATE INDEX IF NOT EXISTS idx_message_poll_vote_user_id + ON message_poll_vote (user_id); + +COMMIT; diff --git a/migrate/002_message_social.sql b/migrate/002_message_social.sql new file mode 100644 index 0000000..edf209f --- /dev/null +++ b/migrate/002_message_social.sql @@ -0,0 +1,76 @@ +-- ============================================================ +-- Migration: 002_message_social.sql +-- Tables: message_pin, message_read_state, message_draft, message_edit +-- ============================================================ +-- Extends the message subsystem with pinned messages, read receipts, +-- drafts, and edit history. + +BEGIN; + +-- models/message_pin.rs → message_pin +CREATE TABLE IF NOT EXISTS message_pin ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + channel_id UUID NOT NULL, + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + pinned_by UUID NOT NULL, + position INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_pin_channel_message UNIQUE (channel_id, message_id) +); + +CREATE INDEX IF NOT EXISTS idx_message_pin_channel_id + ON message_pin (channel_id); + +-- models/message_read_state.rs → message_read_state +CREATE TABLE IF NOT EXISTS message_read_state ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + channel_id UUID NOT NULL, + user_id UUID NOT NULL, + last_read_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + last_read_at TIMESTAMPTZ NULL, + unread_count BIGINT NOT NULL DEFAULT 0, + unread_mentions BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_read_state_channel_user UNIQUE (channel_id, user_id) +); + +CREATE INDEX IF NOT EXISTS idx_message_read_state_user_id + ON message_read_state (user_id); +CREATE INDEX IF NOT EXISTS idx_message_read_state_channel_id + ON message_read_state (channel_id); + +-- models/message_draft.rs → message_draft +CREATE TABLE IF NOT EXISTS message_draft ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + channel_id UUID NOT NULL, + user_id UUID NOT NULL, + thread_id UUID NULL REFERENCES message_thread(id) ON DELETE CASCADE, + reply_to_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + body TEXT NOT NULL DEFAULT '', + metadata JSONB NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_draft_channel_user_thread + UNIQUE (channel_id, user_id, thread_id) +); + +CREATE INDEX IF NOT EXISTS idx_message_draft_user_id + ON message_draft (user_id); + +-- models/message_edit.rs → message_edit +CREATE TABLE IF NOT EXISTS message_edit ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + edited_by UUID NOT NULL, + old_body TEXT NOT NULL, + new_body TEXT NOT NULL, + edited_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_message_edit_message_id + ON message_edit (message_id); +CREATE INDEX IF NOT EXISTS idx_message_edit_edited_at + ON message_edit (edited_at); + +COMMIT; diff --git a/migrate/003_message_article.sql b/migrate/003_message_article.sql new file mode 100644 index 0000000..d7452bb --- /dev/null +++ b/migrate/003_message_article.sql @@ -0,0 +1,45 @@ +-- ============================================================ +-- Migration: 003_message_article.sql +-- Tables: message_article +-- ============================================================ +-- Extends the message subsystem with forum-style article posts. +-- Articles extend regular messages with title, cover image, tags, +-- and view/like stats. Rendered as waterfall cards in forum channels. + +BEGIN; + +-- models/message_article.rs → message_article +CREATE TABLE IF NOT EXISTS message_article ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + title TEXT NOT NULL, + summary TEXT NULL, + cover_url TEXT NULL, + cover_width INTEGER NULL, + cover_height INTEGER NULL, + cover_color TEXT NULL, + tags JSONB NULL, + view_count BIGINT NOT NULL DEFAULT 0, + like_count BIGINT NOT NULL DEFAULT 0, + bookmark_count BIGINT NOT NULL DEFAULT 0, + reply_count BIGINT NOT NULL DEFAULT 0, + last_reply_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + last_reply_at TIMESTAMPTZ NULL, + last_reply_user_id UUID NULL, + is_pinned_to_top BOOLEAN NOT NULL DEFAULT FALSE, + is_answered BOOLEAN NOT NULL DEFAULT FALSE, + answered_by UUID NULL, + answered_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_article_message UNIQUE (message_id) +); + +CREATE INDEX IF NOT EXISTS idx_message_article_last_reply_at + ON message_article (last_reply_at DESC NULLS LAST); +CREATE INDEX IF NOT EXISTS idx_message_article_is_pinned_to_top + ON message_article (is_pinned_to_top DESC, last_reply_at DESC NULLS LAST); +CREATE INDEX IF NOT EXISTS idx_message_article_view_count + ON message_article (view_count DESC); + +COMMIT; diff --git a/migrate/004_message_social_part2.sql b/migrate/004_message_social_part2.sql new file mode 100644 index 0000000..46d078b --- /dev/null +++ b/migrate/004_message_social_part2.sql @@ -0,0 +1,98 @@ +-- ============================================================ +-- Migration: 004_message_social_part2.sql +-- Tables: message_reaction, message_bookmark, message_mention, +-- message_thread, message_thread_participant +-- ============================================================ + +BEGIN; + +-- models/message_reaction.rs → message_reaction +CREATE TABLE IF NOT EXISTS message_reaction ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + channel_id UUID NOT NULL, + user_id UUID NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_reaction_user_content UNIQUE (message_id, user_id, content) +); + +CREATE INDEX IF NOT EXISTS idx_message_reaction_message_id + ON message_reaction (message_id); +CREATE INDEX IF NOT EXISTS idx_message_reaction_user_id + ON message_reaction (user_id); + +-- models/message_bookmark.rs → message_bookmark +CREATE TABLE IF NOT EXISTS message_bookmark ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + channel_id UUID NOT NULL, + user_id UUID NOT NULL, + note TEXT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_bookmark_user_message UNIQUE (user_id, message_id) +); + +CREATE INDEX IF NOT EXISTS idx_message_bookmark_user_id + ON message_bookmark (user_id); +CREATE INDEX IF NOT EXISTS idx_message_bookmark_message_id + ON message_bookmark (message_id); + +-- models/message_mention.rs → message_mention +CREATE TABLE IF NOT EXISTS message_mention ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + channel_id UUID NOT NULL, + mentioned_user_id UUID NOT NULL, + mentioned_by UUID NOT NULL, + read_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_message_mention_message_id + ON message_mention (message_id); +CREATE INDEX IF NOT EXISTS idx_message_mention_mentioned_user + ON message_mention (mentioned_user_id); + +-- models/message_thread.rs → message_thread +CREATE TABLE IF NOT EXISTS message_thread ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + channel_id UUID NOT NULL, + root_message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + created_by UUID NOT NULL, + replies_count BIGINT NOT NULL DEFAULT 0, + participants_count BIGINT NOT NULL DEFAULT 0, + last_reply_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + last_reply_at TIMESTAMPTZ NULL, + resolved BOOLEAN NOT NULL DEFAULT FALSE, + resolved_by UUID NULL, + resolved_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_message_thread_root UNIQUE (root_message_id) +); + +CREATE INDEX IF NOT EXISTS idx_message_thread_channel_id + ON message_thread (channel_id); +CREATE INDEX IF NOT EXISTS idx_message_thread_last_reply_at + ON message_thread (last_reply_at DESC NULLS LAST); + +-- models/message_thread_participant.rs → message_thread_participant +CREATE TABLE IF NOT EXISTS message_thread_participant ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + thread_id UUID NOT NULL REFERENCES message_thread(id) ON DELETE CASCADE, + user_id UUID NOT NULL, + joined_reason TEXT NULL, + last_read_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + last_read_at TIMESTAMPTZ NULL, + joined_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT uq_thread_participant UNIQUE (thread_id, user_id) +); + +CREATE INDEX IF NOT EXISTS idx_thread_participant_thread_id + ON message_thread_participant (thread_id); +CREATE INDEX IF NOT EXISTS idx_thread_participant_user_id + ON message_thread_participant (user_id); + +COMMIT; diff --git a/migrate/005_message_misc.sql b/migrate/005_message_misc.sql new file mode 100644 index 0000000..5966a2b --- /dev/null +++ b/migrate/005_message_misc.sql @@ -0,0 +1,102 @@ +-- ============================================================ +-- Migration: 005_message_misc.sql +-- Tables: message_notification, message_scheduled, message_sticker, +-- message_forward, message_component +-- ============================================================ + +BEGIN; + +-- models/message_notification.rs → message_notification +CREATE TABLE IF NOT EXISTS message_notification ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + channel_id UUID NOT NULL, + user_id UUID NOT NULL, + reason TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + delivery_channel TEXT NULL, + delivered_at TIMESTAMPTZ NULL, + read_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_message_notification_user_id + ON message_notification (user_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_message_notification_status + ON message_notification (status); + +-- models/message_scheduled.rs → message_scheduled +CREATE TABLE IF NOT EXISTS message_scheduled ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + channel_id UUID NOT NULL, + author_id UUID NOT NULL, + thread_id UUID NULL REFERENCES message_thread(id) ON DELETE SET NULL, + reply_to_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + body TEXT NOT NULL, + metadata JSONB NULL, + scheduled_at TIMESTAMPTZ NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + sent_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL, + error TEXT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_message_scheduled_status_at + ON message_scheduled (status, scheduled_at); + +-- models/message_sticker.rs → message_sticker +CREATE TABLE IF NOT EXISTS message_sticker ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + sticker_id UUID NOT NULL, + name TEXT NOT NULL, + image_url TEXT NOT NULL, + format_type TEXT NOT NULL DEFAULT 'png', + pack_name TEXT NULL, + tags TEXT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_message_sticker_message_id + ON message_sticker (message_id); + +-- models/message_forward.rs → message_forward +CREATE TABLE IF NOT EXISTS message_forward ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + source_message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + source_channel_id UUID NOT NULL, + forwarded_by UUID NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_message_forward_message_id + ON message_forward (message_id); +CREATE INDEX IF NOT EXISTS idx_message_forward_source_message_id + ON message_forward (source_message_id); + +-- models/message_component.rs → message_component +CREATE TABLE IF NOT EXISTS message_component ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE, + row INTEGER NOT NULL DEFAULT 0, + position INTEGER NOT NULL DEFAULT 0, + component_type TEXT NOT NULL, + custom_id TEXT NOT NULL, + label TEXT NULL, + emoji TEXT NULL, + style TEXT NULL, + url TEXT NULL, + disabled BOOLEAN NOT NULL DEFAULT FALSE, + placeholder TEXT NULL, + min_values INTEGER NULL, + max_values INTEGER NULL, + options JSONB NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_message_component_message_id + ON message_component (message_id); + +COMMIT; diff --git a/models/message.rs b/models/message.rs new file mode 100644 index 0000000..cf71305 --- /dev/null +++ b/models/message.rs @@ -0,0 +1,175 @@ +//! Core message model — maps to PostgreSQL `message` table. +//! +//! Discord-style: `body` holds plain text / markdown, rich content lives in +//! companion tables (attachment, embed, poll, reaction, mention). +//! IDs are UUID v7 (time-ordered) so `ORDER BY id` = chronological order. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Discriminator for system / event messages vs. regular user messages. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum MessageType { + /// Regular user message (text / markdown). + #[default] + Text, + /// System-generated notice (e.g. "user joined the channel"). + System, + /// Channel event (pinned a message, changed topic, etc.). + Event, + /// Forum article / long-form post displayed as waterfall cards. + Article, +} + +impl MessageType { + pub fn as_str(&self) -> &'static str { + match self { + Self::Text => "text", + Self::System => "system", + Self::Event => "event", + Self::Article => "article", + } + } + + pub fn from_str_lossy(s: &str) -> Self { + match s { + "system" => Self::System, + "event" => Self::Event, + "article" => Self::Article, + _ => Self::Text, + } + } +} + +impl std::fmt::Display for MessageType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Direct mapping of the `message` table row. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct Message { + /// UUID v7 — time-ordered primary key. + pub id: Uuid, + pub channel_id: Uuid, + pub author_id: Uuid, + /// Thread this message belongs to (NULL = not threaded). + pub thread_id: Option, + /// Direct reply reference (NULL = top-level message). + pub reply_to_message_id: Option, + /// "text" | "system" | "event" + pub message_type: String, + /// Plain text or markdown body. + pub body: String, + /// Extensible metadata (flags, locale, interaction ref, etc.). + pub metadata: Option, + pub pinned: bool, + /// True for bot / system generated messages. + pub system: bool, + pub edited_at: Option>, + pub deleted_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// Lightweight author info embedded in [`MessageDetail`]. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthorInfo { + pub id: Uuid, + pub username: String, + pub display_name: Option, + pub avatar_url: Option, + pub is_bot: bool, +} + +/// Message with resolved author and reaction/attachment aggregates. +/// Returned by read APIs; never stored directly. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageDetail { + #[serde(flatten)] + pub message: Message, + pub author: AuthorInfo, + /// Aggregated reaction counts: `{ "👍": 3, "🎉": 1 }`. + pub reactions: std::collections::HashMap, + pub attachment_count: i64, + pub embed_count: i64, + /// Whether the current user has bookmarked this message. + pub bookmarked: bool, + /// Reply chain depth (0 = top-level). + pub reply_depth: i32, +} + +/// Generate a new UUID v7 (time-ordered) for message IDs. +pub fn new_message_id() -> Uuid { + Uuid::now_v7() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_type_roundtrip() { + let t = MessageType::Text; + assert_eq!(t.as_str(), "text"); + assert_eq!(MessageType::from_str_lossy("text"), MessageType::Text); + assert_eq!(MessageType::from_str_lossy("system"), MessageType::System); + assert_eq!(MessageType::from_str_lossy("unknown"), MessageType::Text); + } + + #[test] + fn test_message_id_ordering() { + // UUID v7 IDs generated later should sort after earlier ones. + let a = new_message_id(); + std::thread::sleep(std::time::Duration::from_millis(2)); + let b = new_message_id(); + assert!(b > a, "UUID v7 should be time-ordered"); + } + + #[test] + fn test_message_detail_serialize() { + let msg = Message { + id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + author_id: Uuid::now_v7(), + thread_id: None, + reply_to_message_id: None, + message_type: "text".to_string(), + body: "hello world".to_string(), + metadata: None, + pinned: false, + system: false, + edited_at: None, + deleted_at: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let detail = MessageDetail { + message: msg, + author: AuthorInfo { + id: Uuid::now_v7(), + username: "alice".to_string(), + display_name: Some("Alice".to_string()), + avatar_url: None, + is_bot: false, + }, + reactions: std::collections::HashMap::from([ + ("👍".to_string(), 3), + ("🎉".to_string(), 1), + ]), + attachment_count: 0, + embed_count: 0, + bookmarked: false, + reply_depth: 0, + }; + + let json = serde_json::to_value(&detail).unwrap(); + assert_eq!(json["body"], "hello world"); + assert_eq!(json["author"]["username"], "alice"); + assert_eq!(json["reactions"]["👍"], 3); + } +} diff --git a/models/message_article.rs b/models/message_article.rs new file mode 100644 index 0000000..5398836 --- /dev/null +++ b/models/message_article.rs @@ -0,0 +1,196 @@ +//! Forum article / long-form post — maps to `message_article` table. +//! +//! Forum articles extend a regular [`Message`] with title, cover image, tags, +//! and view/like stats. Rendered as waterfall cards in forum channel views. +//! One article per message (1:1), linked via `message_id`. +//! +//! A message is an article when `message.message_type = "article"`. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Extended metadata for a forum article / post. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageArticle { + pub id: Uuid, + /// FK to `message.id`. UNIQUE constraint ensures 1:1. + pub message_id: Uuid, + /// Article title (plain text, max 256 chars). + pub title: String, + /// Short excerpt for card preview (plain text, max 512 chars). + pub summary: Option, + /// Cover image URL (displayed at the top of the card). + pub cover_url: Option, + /// Cover image width in pixels (for waterfall layout height calculation). + pub cover_width: Option, + /// Cover image height in pixels. + pub cover_height: Option, + /// Cover image dominant color (hex, for placeholder while loading). + pub cover_color: Option, + /// Tag IDs referencing `forum_tag` table, stored as JSON array. + pub tags: Option, + /// View count (denormalized, incremented on read). + pub view_count: i64, + /// Reaction / like count (denormalized). + pub like_count: i64, + /// Bookmark count (denormalized). + pub bookmark_count: i64, + /// Reply count (denormalized, threads inside the article). + pub reply_count: i64, + /// Most recent reply message id. + pub last_reply_message_id: Option, + /// Most recent reply timestamp. + pub last_reply_at: Option>, + /// User id of the last replier. + pub last_reply_user_id: Option, + /// Whether the article is pinned to the top of the forum channel. + pub is_pinned_to_top: bool, + /// Whether the question has been answered / resolved. + pub is_answered: bool, + /// Who marked it as answered. + pub answered_by: Option, + /// When it was marked as answered. + pub answered_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// Waterfall card view — article enriched with author info + first attachment +/// (cover fallback). Returned by forum list APIs. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArticleCard { + #[serde(flatten)] + pub article: MessageArticle, + /// Author display info. + pub author: super::message::AuthorInfo, + /// Resolved tag names (e.g. ["Bug Report", "High Priority"]). + pub tag_names: Vec, + /// First image attachment URL (fallback when cover_url is NULL). + pub first_image_url: Option, + /// Whether the current user has bookmarked this article. + pub bookmarked: bool, + /// Whether the current user has liked this article. + pub liked: bool, +} + +/// Input payload for creating a new forum article. +/// +/// Sent by the client when composing a forum post. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateArticleInput { + pub channel_id: Uuid, + pub title: String, + pub body: String, + pub summary: Option, + pub cover_url: Option, + pub tags: Option>, +} + +/// Input payload for updating an existing forum article. +/// +/// All fields are optional — only provided fields are updated. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateArticleInput { + pub title: Option, + pub body: Option, + pub summary: Option, + pub cover_url: Option, + pub cover_color: Option, + pub tags: Option>, +} + +/// Sort modes for forum article listing. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ArticleSort { + /// Most recent activity first. + #[default] + LatestActivity, + /// Newest articles first. + Newest, + /// Most viewed first. + MostViewed, + /// Most liked first. + MostLiked, + /// Pinned articles first, then by latest activity. + PinnedFirst, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_article_card_serialize() { + let article = MessageArticle { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + title: "Bug Report: Login fails on Safari".into(), + summary: Some("Users report that login fails on Safari 18...".into()), + cover_url: Some("https://cdn.example.com/covers/bug.png".into()), + cover_width: Some(1200), + cover_height: Some(630), + cover_color: Some("#FF6B6B".into()), + tags: Some(serde_json::json!(["01909a", "01909b"])), + view_count: 142, + like_count: 7, + bookmark_count: 3, + reply_count: 5, + last_reply_message_id: Some(Uuid::now_v7()), + last_reply_at: Some(Utc::now()), + last_reply_user_id: Some(Uuid::now_v7()), + is_pinned_to_top: true, + is_answered: false, + answered_by: None, + answered_at: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let card = ArticleCard { + article, + author: super::super::message::AuthorInfo { + id: Uuid::now_v7(), + username: "alice".into(), + display_name: Some("Alice".into()), + avatar_url: None, + is_bot: false, + }, + tag_names: vec!["Bug Report".into(), "High Priority".into()], + first_image_url: None, + bookmarked: false, + liked: true, + }; + + let json = serde_json::to_value(&card).unwrap(); + assert_eq!(json["title"], "Bug Report: Login fails on Safari"); + assert_eq!(json["view_count"], 142); + assert_eq!(json["author"]["username"], "alice"); + assert_eq!(json["tag_names"][0], "Bug Report"); + assert_eq!(json["is_pinned_to_top"], true); + } + + #[test] + fn test_article_sort_serialize() { + let sort = ArticleSort::MostViewed; + let json = serde_json::to_value(sort).unwrap(); + assert_eq!(json, "most_viewed"); + } + + #[test] + fn test_create_article_input_serialize() { + let input = CreateArticleInput { + channel_id: Uuid::now_v7(), + title: "New Feature Proposal".into(), + body: "I'd like to propose...".into(), + summary: Some("A proposal for...".into()), + cover_url: None, + tags: Some(vec![Uuid::now_v7()]), + }; + + let json = serde_json::to_value(&input).unwrap(); + assert_eq!(json["title"], "New Feature Proposal"); + assert_eq!(json["tags"].as_array().unwrap().len(), 1); + } +} diff --git a/models/message_attachment.rs b/models/message_attachment.rs new file mode 100644 index 0000000..ff1118d --- /dev/null +++ b/models/message_attachment.rs @@ -0,0 +1,96 @@ +//! File / image attachment on a message — new table `message_attachment`. +//! +//! Discord-style: each message can have multiple attachments. Files are +//! uploaded to S3 first, then the URL is stored here. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A file or image attachment on a message. +/// +/// Maps to the `message_attachment` table. Files are uploaded to object +/// storage first, then the URL is stored here. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageAttachment { + pub id: Uuid, + pub message_id: Uuid, + /// Original filename as uploaded by the user. + pub filename: String, + /// MIME type: "image/png", "application/pdf", etc. + pub content_type: Option, + /// File size in bytes. + pub size: i64, + /// Public URL (S3 presigned or CDN). + pub url: String, + /// S3 / object-store key for backend access. + pub storage_key: Option, + /// Image / video width in pixels. + pub width: Option, + /// Image / video height in pixels. + pub height: Option, + /// Audio / video duration in seconds. + pub duration_secs: Option, + /// Blurred low-res preview for progressive loading (base64 data URI). + pub blurhash: Option, + /// Whether this attachment should be rendered as a spoiler (hidden until click). + pub spoiler: bool, + pub created_at: DateTime, +} + +/// Lightweight attachment summary for list views. +/// +/// Omits URL and storage_key to reduce payload size. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttachmentSummary { + pub id: Uuid, + pub filename: String, + pub content_type: Option, + pub size: i64, + pub width: Option, + pub height: Option, + pub spoiler: bool, +} + +impl From for AttachmentSummary { + fn from(a: MessageAttachment) -> Self { + Self { + id: a.id, + filename: a.filename, + content_type: a.content_type, + size: a.size, + width: a.width, + height: a.height, + spoiler: a.spoiler, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_attachment_conversion() { + let att = MessageAttachment { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + filename: "photo.png".into(), + content_type: Some("image/png".into()), + size: 1024, + url: "https://cdn.example.com/photo.png".into(), + storage_key: None, + width: Some(800), + height: Some(600), + duration_secs: None, + blurhash: None, + spoiler: false, + created_at: Utc::now(), + }; + + let summary: AttachmentSummary = att.into(); + assert_eq!(summary.filename, "photo.png"); + assert_eq!(summary.size, 1024); + assert_eq!(summary.width, Some(800)); + } +} diff --git a/models/message_bookmark.rs b/models/message_bookmark.rs new file mode 100644 index 0000000..0b6ad21 --- /dev/null +++ b/models/message_bookmark.rs @@ -0,0 +1,46 @@ +//! User bookmark on a message — maps to `message_bookmark` table. +//! +//! Similar to browser bookmarks: user saves a message for later reference, +//! optionally with a personal note. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A user bookmark on a message for later reference. +/// +/// Maps to the `message_bookmark` table. Similar to browser bookmarks, +/// optionally with a personal note. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageBookmark { + pub id: Uuid, + pub message_id: Uuid, + pub channel_id: Uuid, + pub user_id: Uuid, + /// Personal note the user attached to this bookmark. + pub note: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bookmark_serialize() { + let bm = MessageBookmark { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + note: Some("Important reference".into()), + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let json = serde_json::to_value(&bm).unwrap(); + assert_eq!(json["note"], "Important reference"); + assert!(json["message_id"].is_string()); + } +} diff --git a/models/message_component.rs b/models/message_component.rs new file mode 100644 index 0000000..66b621b --- /dev/null +++ b/models/message_component.rs @@ -0,0 +1,95 @@ +//! Interactive message components — maps to `message_component` table. +//! +//! Discord-style interactive elements attached to messages: buttons, select +//! menus, etc. Each component belongs to a message and has its own layout row. +//! Clicking a component emits an interaction event back to the bot/webhook. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// An interactive component attached to a message (button, select menu, etc.). +/// +/// Maps to the `message_component` table. Each component belongs to a message +/// and has a layout row/position. User interactions emit callbacks. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageComponent { + pub id: Uuid, + pub message_id: Uuid, + /// Layout row within the message (0-based). + pub row: i32, + /// Position within the row (0-based). + pub position: i32, + /// "button" | "select_menu" | "text_input" + pub component_type: String, + /// Unique identifier sent back in interaction callbacks. + pub custom_id: String, + /// Display label. + pub label: Option, + /// Emoji shown on the button (unicode or `:name:id`). + pub emoji: Option, + /// Button style: "primary" | "secondary" | "success" | "danger" | "link" + pub style: Option, + /// URL for link-style buttons. + pub url: Option, + /// Whether the component is disabled. + pub disabled: bool, + /// Placeholder text for select menus. + pub placeholder: Option, + /// Min/max selections for select menus. + pub min_values: Option, + pub max_values: Option, + /// Options for select menus, stored as JSON array. + pub options: Option, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ComponentType { + #[default] + Button, + SelectMenu, + TextInput, +} + +impl ComponentType { + pub fn as_str(&self) -> &'static str { + match self { + Self::Button => "button", + Self::SelectMenu => "select_menu", + Self::TextInput => "text_input", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_component_serialize() { + let c = MessageComponent { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + row: 0, + position: 0, + component_type: ComponentType::Button.as_str().into(), + custom_id: "btn_approve".into(), + label: Some("Approve".into()), + emoji: Some("✅".into()), + style: Some("success".into()), + url: None, + disabled: false, + placeholder: None, + min_values: None, + max_values: None, + options: None, + created_at: Utc::now(), + }; + + let json = serde_json::to_value(&c).unwrap(); + assert_eq!(json["component_type"], "button"); + assert_eq!(json["style"], "success"); + } +} diff --git a/models/message_draft.rs b/models/message_draft.rs new file mode 100644 index 0000000..7140d3b --- /dev/null +++ b/models/message_draft.rs @@ -0,0 +1,77 @@ +//! Message drafts — maps to `message_draft` table. +//! +//! Stores unsent messages so they survive browser refreshes and sync across +//! the user's connected devices. One draft per (channel, user, optional thread). +//! Drafts are upserted on every keystroke debounce and deleted on send. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A user's unsent message draft in a channel. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageDraft { + pub id: Uuid, + pub channel_id: Uuid, + pub user_id: Uuid, + /// Thread the draft belongs to (NULL = top-level). + pub thread_id: Option, + /// Message this draft is replying to (NULL = new message). + pub reply_to_message_id: Option, + /// Plain text or markdown body. + pub body: String, + /// Extensible metadata (attachments to be uploaded, etc.). + pub metadata: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// Input for upserting a draft (sent from client on debounced keystroke). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DraftUpsertInput { + pub channel_id: Uuid, + pub thread_id: Option, + pub reply_to_message_id: Option, + pub body: String, + pub metadata: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_draft_upsert_input_serialize() { + let input = DraftUpsertInput { + channel_id: Uuid::now_v7(), + thread_id: None, + reply_to_message_id: None, + body: "hello, this is a draft".to_string(), + metadata: None, + }; + + let json = serde_json::to_value(&input).unwrap(); + assert_eq!(json["body"], "hello, this is a draft"); + assert!(json["thread_id"].is_null()); + } + + #[test] + fn test_draft_serialize() { + let draft = MessageDraft { + id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + thread_id: Some(Uuid::now_v7()), + reply_to_message_id: None, + body: "draft body".to_string(), + metadata: Some(serde_json::json!({"pending_attachments": []})), + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let json = serde_json::to_value(&draft).unwrap(); + assert_eq!(json["body"], "draft body"); + assert!(json["thread_id"].is_string()); + assert!(json["metadata"]["pending_attachments"].is_array()); + } +} diff --git a/models/message_edit.rs b/models/message_edit.rs new file mode 100644 index 0000000..2930690 --- /dev/null +++ b/models/message_edit.rs @@ -0,0 +1,77 @@ +//! Message edit history — maps to `message_edit` table. +//! +//! Immutable append-only log of every edit to a message. +//! Used for audit trails, "edited" indicators with hover-to-see-original, +//! and compliance / moderation review. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// One edit record. Stored every time a message body is modified. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageEdit { + pub id: Uuid, + pub message_id: Uuid, + /// Who made the edit (usually the author; can be a moderator). + pub edited_by: Uuid, + /// Body content before the edit. + pub old_body: String, + /// Body content after the edit. + pub new_body: String, + pub edited_at: DateTime, +} + +/// Lightweight summary for the "edited" tooltip (no full body). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EditSummary { + pub edit_count: i64, + pub last_edited_at: Option>, + pub last_edited_by: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_edit_serialize() { + let edit = MessageEdit { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + edited_by: Uuid::now_v7(), + old_body: "before edit".to_string(), + new_body: "after edit".to_string(), + edited_at: Utc::now(), + }; + + let json = serde_json::to_value(&edit).unwrap(); + assert_eq!(json["old_body"], "before edit"); + assert_eq!(json["new_body"], "after edit"); + } + + #[test] + fn test_edit_summary_serialize() { + let summary = EditSummary { + edit_count: 3, + last_edited_at: Some(Utc::now()), + last_edited_by: Some(Uuid::now_v7()), + }; + + let json = serde_json::to_value(&summary).unwrap(); + assert_eq!(json["edit_count"], 3); + } + + #[test] + fn test_edit_summary_no_edits() { + let summary = EditSummary { + edit_count: 0, + last_edited_at: None, + last_edited_by: None, + }; + + let json = serde_json::to_value(&summary).unwrap(); + assert_eq!(json["edit_count"], 0); + assert!(json["last_edited_at"].is_null()); + } +} diff --git a/models/message_embed.rs b/models/message_embed.rs new file mode 100644 index 0000000..e12868a --- /dev/null +++ b/models/message_embed.rs @@ -0,0 +1,163 @@ +//! Rich embed on a message — new table `message_embed` + `message_embed_field`. +//! +//! Discord-style embeds: link previews, rich cards with title/description/ +//! thumbnail/image/footer/fields. Generated by the server (link preview) +//! or sent by bots/webhooks. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A single embed attached to a message. +/// One message can have multiple embeds (e.g. link preview + bot embed). +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageEmbed { + pub id: Uuid, + pub message_id: Uuid, + + /// "link" | "article" | "image" | "video" | "rich" + pub embed_type: String, + + pub title: Option, + pub description: Option, + pub url: Option, + + /// Embed accent color as integer (Discord format: 0xRRGGBB). + pub color: Option, + + // Media + /// Main image URL. + pub image_url: Option, + pub image_width: Option, + pub image_height: Option, + /// Small thumbnail URL. + pub thumbnail_url: Option, + pub thumbnail_width: Option, + pub thumbnail_height: Option, + /// Video URL (for video embeds). + pub video_url: Option, + pub video_width: Option, + pub video_height: Option, + + // Footer + pub author_name: Option, + pub author_url: Option, + pub author_icon_url: Option, + pub footer_text: Option, + pub footer_icon_url: Option, + + /// Provider name (e.g. "YouTube", "GitHub"). + pub provider_name: Option, + pub provider_url: Option, + + pub created_at: DateTime, +} + +/// A key-value field within an embed (Discord-style field rows). +/// Stored in a separate table for flexibility. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageEmbedField { + pub id: Uuid, + pub embed_id: Uuid, + pub name: String, + pub value: String, + /// Whether this field should display inline (side by side with other inline fields). + pub inline: bool, + pub position: i32, +} + +/// An embed with its fields resolved in a single structure. +/// +/// Convenience type returned by read APIs, joining embed and fields. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbedDetail { + #[serde(flatten)] + pub embed: MessageEmbed, + pub fields: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_embed_serialize() { + let embed = MessageEmbed { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + embed_type: "link".into(), + title: Some("Example".into()), + description: Some("A link preview".into()), + url: Some("https://example.com".into()), + color: Some(0x00FF00), + image_url: None, + image_width: None, + image_height: None, + thumbnail_url: None, + thumbnail_width: None, + thumbnail_height: None, + video_url: None, + video_width: None, + video_height: None, + author_name: None, + author_url: None, + author_icon_url: None, + footer_text: None, + footer_icon_url: None, + provider_name: Some("GitHub".into()), + provider_url: None, + created_at: Utc::now(), + }; + + let json = serde_json::to_value(&embed).unwrap(); + assert_eq!(json["embed_type"], "link"); + assert_eq!(json["title"], "Example"); + } + + #[test] + fn test_embed_detail_serialize() { + let embed = MessageEmbed { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + embed_type: "rich".into(), + title: Some("Article".into()), + description: None, + url: None, + color: None, + image_url: None, + image_width: None, + image_height: None, + thumbnail_url: None, + thumbnail_width: None, + thumbnail_height: None, + video_url: None, + video_width: None, + video_height: None, + author_name: None, + author_url: None, + author_icon_url: None, + footer_text: None, + footer_icon_url: None, + provider_name: None, + provider_url: None, + created_at: Utc::now(), + }; + let field = MessageEmbedField { + id: Uuid::now_v7(), + embed_id: embed.id, + name: "Key".into(), + value: "Value".into(), + inline: true, + position: 0, + }; + + let detail = EmbedDetail { + embed, + fields: vec![field], + }; + + let json = serde_json::to_value(&detail).unwrap(); + assert_eq!(json["embed_type"], "rich"); + assert!(json["fields"].is_array()); + } +} diff --git a/models/message_forward.rs b/models/message_forward.rs new file mode 100644 index 0000000..cfca482 --- /dev/null +++ b/models/message_forward.rs @@ -0,0 +1,48 @@ +//! Message forwarding trail — maps to `message_forward` table. +//! +//! When a user forwards a message from one channel to another, this table +//! records the provenance. The forwarded message is a new message with a +//! copy of the original body, linked back via `source_message_id`. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A forwarded message linking the copy to the original. +/// +/// Maps to the `message_forward` table. Records provenance when a user +/// forwards a message from one channel to another. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageForward { + pub id: Uuid, + /// The new (forwarded) message. + pub message_id: Uuid, + /// The original message being forwarded. + pub source_message_id: Uuid, + /// The channel the original message came from. + pub source_channel_id: Uuid, + /// Who forwarded the message. + pub forwarded_by: Uuid, + pub created_at: DateTime, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_forward_serialize() { + let f = MessageForward { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + source_message_id: Uuid::now_v7(), + source_channel_id: Uuid::now_v7(), + forwarded_by: Uuid::now_v7(), + created_at: Utc::now(), + }; + + let json = serde_json::to_value(&f).unwrap(); + assert!(json["source_message_id"].is_string()); + assert!(json["source_channel_id"].is_string()); + } +} diff --git a/models/message_mention.rs b/models/message_mention.rs new file mode 100644 index 0000000..6ebb3a5 --- /dev/null +++ b/models/message_mention.rs @@ -0,0 +1,123 @@ +//! @mention tracking — maps to `message_mention` table. +//! +//! Parsed from message body on send. Used for notification and +//! "mentions" feed. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// An @mention of a user in a message. +/// +/// Maps to the `message_mention` table. Parsed from message body on send +/// and used for notifications and the mentions feed. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageMention { + pub id: Uuid, + pub message_id: Uuid, + pub channel_id: Uuid, + pub mentioned_user_id: Uuid, + pub mentioned_by: Uuid, + /// When the mentioned user read the notification. + pub read_at: Option>, + pub created_at: DateTime, +} + +/// Parse @username mentions from a message body. +/// Returns unique usernames (without the `@` prefix). +/// +/// Matches `@word` where word is `[a-zA-Z0-9_-]+`, min 2 chars. +/// Ignores `@@` (escaped) and mentions inside code spans/blocks. +pub fn parse_mentions(body: &str) -> Vec { + let mut seen = std::collections::HashSet::new(); + let mut mentions = Vec::new(); + + // Simple state machine: skip content inside backtick spans. + let mut in_code = false; + let bytes = body.as_bytes(); + let mut i = 0; + + while i < bytes.len() { + if bytes[i] == b'`' { + in_code = !in_code; + i += 1; + continue; + } + + if !in_code && bytes[i] == b'@' { + // Skip escaped @@ + if i + 1 < bytes.len() && bytes[i + 1] == b'@' { + i += 2; + continue; + } + + let start = i + 1; + let mut end = start; + while end < bytes.len() + && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_' || bytes[end] == b'-') + { + end += 1; + } + + let len = end - start; + if len >= 2 { + let name = body[start..end].to_lowercase(); + if seen.insert(name.clone()) { + mentions.push(name); + } + } + i = end; + } else { + i += 1; + } + } + + mentions +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_mentions_basic() { + let m = parse_mentions("hey @alice, cc @bob"); + assert_eq!(m, vec!["alice".to_string(), "bob".to_string()]); + } + + #[test] + fn test_parse_mentions_dedup() { + let m = parse_mentions("@alice said hi to @alice"); + assert_eq!(m, vec!["alice".to_string()]); + } + + #[test] + fn test_parse_mentions_case_insensitive() { + let m = parse_mentions("@Alice and @ALICE"); + assert_eq!(m, vec!["alice".to_string()]); + } + + #[test] + fn test_parse_mentions_skip_code_span() { + let m = parse_mentions("use `@here` to notify @alice"); + assert_eq!(m, vec!["alice".to_string()]); + } + + #[test] + fn test_parse_mentions_skip_escaped() { + let m = parse_mentions("@@alice is not a mention, but @bob is"); + assert_eq!(m, vec!["bob".to_string()]); + } + + #[test] + fn test_parse_mentions_short_name_ignored() { + let m = parse_mentions("@a is too short, @ab is ok"); + assert_eq!(m, vec!["ab".to_string()]); + } + + #[test] + fn test_parse_mentions_empty() { + assert!(parse_mentions("no mentions here").is_empty()); + assert!(parse_mentions("").is_empty()); + } +} diff --git a/models/message_notification.rs b/models/message_notification.rs new file mode 100644 index 0000000..7565ef8 --- /dev/null +++ b/models/message_notification.rs @@ -0,0 +1,97 @@ +//! Notification delivery tracking — maps to `message_notification` table. +//! +//! Records when a message triggers a notification for a user (mention, reply, +//! thread activity, etc.) and tracks the delivery/read lifecycle. +//! Separate from `message_mention` which only covers @mentions. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A notification triggered for a user by a message. +/// +/// Maps to the `message_notification` table. Records when a message triggers +/// a notification (mention, reply, thread activity) and tracks delivery/read. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageNotification { + pub id: Uuid, + pub message_id: Uuid, + pub channel_id: Uuid, + pub user_id: Uuid, + /// "mention" | "reply" | "thread" | "watch" + pub reason: String, + /// "pending" | "delivered" | "read" | "dismissed" + pub status: String, + /// Channel of delivery: "push" | "email" | "in_app" + pub delivery_channel: Option, + pub delivered_at: Option>, + pub read_at: Option>, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum NotificationReason { + #[default] + Mention, + Reply, + Thread, + Watch, +} + +impl NotificationReason { + pub fn as_str(&self) -> &'static str { + match self { + Self::Mention => "mention", + Self::Reply => "reply", + Self::Thread => "thread", + Self::Watch => "watch", + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum NotificationStatus { + #[default] + Pending, + Delivered, + Read, + Dismissed, +} + +impl NotificationStatus { + pub fn as_str(&self) -> &'static str { + match self { + Self::Pending => "pending", + Self::Delivered => "delivered", + Self::Read => "read", + Self::Dismissed => "dismissed", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_notification_serialize() { + let n = MessageNotification { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + reason: NotificationReason::Mention.as_str().into(), + status: NotificationStatus::Pending.as_str().into(), + delivery_channel: Some("push".into()), + delivered_at: None, + read_at: None, + created_at: Utc::now(), + }; + + let json = serde_json::to_value(&n).unwrap(); + assert_eq!(json["reason"], "mention"); + assert_eq!(json["status"], "pending"); + } +} diff --git a/models/message_pin.rs b/models/message_pin.rs new file mode 100644 index 0000000..bc25996 --- /dev/null +++ b/models/message_pin.rs @@ -0,0 +1,60 @@ +//! Pinned message management — maps to `message_pin` table. +//! +//! A channel can have multiple pinned messages with explicit ordering. +//! Unlike the `message.pinned` boolean (which just marks the row), +//! this table tracks *who* pinned, *when*, and the display *position*. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// One pinned message entry. Ordered by `position` ascending. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessagePin { + pub id: Uuid, + pub channel_id: Uuid, + pub message_id: Uuid, + /// Who pinned this message. + pub pinned_by: Uuid, + /// Display position in the pinned list (0 = top). + pub position: i32, + pub created_at: DateTime, +} + +/// Summary view returned by list-pins APIs (joined with message content). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PinDetail { + #[serde(flatten)] + pub pin: MessagePin, + pub message_body: String, + pub message_author_id: Uuid, + pub message_created_at: DateTime, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pin_detail_serialize() { + let pin = MessagePin { + id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + pinned_by: Uuid::now_v7(), + position: 0, + created_at: Utc::now(), + }; + + let detail = PinDetail { + pin, + message_body: "important announcement".to_string(), + message_author_id: Uuid::now_v7(), + message_created_at: Utc::now(), + }; + + let json = serde_json::to_value(&detail).unwrap(); + assert_eq!(json["position"], 0); + assert_eq!(json["message_body"], "important announcement"); + } +} diff --git a/models/message_poll.rs b/models/message_poll.rs new file mode 100644 index 0000000..134c92c --- /dev/null +++ b/models/message_poll.rs @@ -0,0 +1,170 @@ +//! Poll on a message — new tables `message_poll`, `message_poll_option`, +//! `message_poll_vote`. +//! +//! Discord-style polls: attached to a message, with multiple options, +//! optional multi-vote, and an expiry time. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// The poll itself (one per message, optional). +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessagePoll { + pub id: Uuid, + pub message_id: Uuid, + /// The question displayed to voters. + pub question: String, + /// Whether users can select multiple options. + pub allow_multiselect: bool, + /// Maximum number of options a user can select (NULL = unlimited when multiselect). + pub max_selections: Option, + /// When voting closes (NULL = no expiry). + pub expires_at: Option>, + /// Total number of votes cast (denormalized for fast reads). + pub total_votes: i64, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// A single selectable option within a poll. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessagePollOption { + pub id: Uuid, + pub poll_id: Uuid, + /// Display text for this option. + pub text: String, + /// Optional emoji prefix (Discord-style). + pub emoji: Option, + /// Number of votes this option received (denormalized). + pub vote_count: i64, + /// Display order. + pub position: i32, +} + +/// A single vote cast by a user. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessagePollVote { + pub id: Uuid, + pub poll_id: Uuid, + pub option_id: Uuid, + pub user_id: Uuid, + pub created_at: DateTime, +} + +/// Aggregated poll results for API responses. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PollResult { + pub poll: MessagePoll, + pub options: Vec, + /// Which options the current user voted for (empty if not voted). + pub my_votes: Vec, + /// Whether the poll has expired. + pub is_expired: bool, +} + +/// Option with its vote count and percentage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PollOptionResult { + #[serde(flatten)] + pub option: MessagePollOption, + /// Percentage of total votes (0.0–100.0), rounded to 1 decimal. + pub percentage: f64, +} + +impl PollResult { + /// Compute percentages from total_votes. + pub fn from_poll( + poll: MessagePoll, + options: Vec, + my_votes: Vec, + ) -> Self { + let total = poll.total_votes.max(1) as f64; + let now = Utc::now(); + let is_expired = poll.expires_at.is_some_and(|exp| now >= exp); + + let options = options + .into_iter() + .map(|opt| { + let pct = if poll.total_votes > 0 { + (opt.vote_count as f64 / total * 100.0 * 10.0).round() / 10.0 + } else { + 0.0 + }; + PollOptionResult { + option: opt, + percentage: pct, + } + }) + .collect(); + + Self { + poll, + options, + my_votes, + is_expired, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_poll_result_percentages() { + let poll = MessagePoll { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + question: "Best language?".to_string(), + allow_multiselect: false, + max_selections: None, + expires_at: None, + total_votes: 10, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let options = vec![ + MessagePollOption { + id: Uuid::now_v7(), + poll_id: poll.id, + text: "Rust".to_string(), + emoji: None, + vote_count: 7, + position: 0, + }, + MessagePollOption { + id: Uuid::now_v7(), + poll_id: poll.id, + text: "Go".to_string(), + emoji: None, + vote_count: 3, + position: 1, + }, + ]; + + let result = PollResult::from_poll(poll, options, vec![]); + assert!(!result.is_expired); + assert_eq!(result.options[0].percentage, 70.0); + assert_eq!(result.options[1].percentage, 30.0); + } + + #[test] + fn test_poll_result_zero_votes() { + let poll = MessagePoll { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + question: "Empty poll".to_string(), + allow_multiselect: false, + max_selections: None, + expires_at: None, + total_votes: 0, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let result = PollResult::from_poll(poll, vec![], vec![]); + assert!(result.options.is_empty()); + } +} diff --git a/models/message_reaction.rs b/models/message_reaction.rs new file mode 100644 index 0000000..0ff310f --- /dev/null +++ b/models/message_reaction.rs @@ -0,0 +1,66 @@ +//! Emoji reaction on a message — maps to `message_reaction` table. +//! +//! One row per (user, message, emoji). `content` stores the emoji string +//! (Unicode emoji or custom `:name:id` format like Discord). + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A single emoji reaction on a message by a user. +/// +/// Maps to the `message_reaction` table. One row per (user, message, emoji). +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageReaction { + pub id: Uuid, + pub message_id: Uuid, + pub channel_id: Uuid, + pub user_id: Uuid, + /// Emoji string: "👍", "🎉", or custom ":appks:01909a..." + pub content: String, + pub created_at: DateTime, +} + +/// Aggregated reaction count for a single emoji on a message. +/// Returned in [`MessageDetail::reactions`] and list APIs. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReactionCount { + pub content: String, + pub count: i64, + /// Whether the current user reacted with this emoji. + pub me: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_reaction_serialize() { + let reaction = MessageReaction { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + content: "👍".into(), + created_at: Utc::now(), + }; + + let json = serde_json::to_value(&reaction).unwrap(); + assert_eq!(json["content"], "👍"); + } + + #[test] + fn test_reaction_count_serialize() { + let count = ReactionCount { + content: "🎉".into(), + count: 3, + me: true, + }; + + let json = serde_json::to_value(&count).unwrap(); + assert_eq!(json["content"], "🎉"); + assert_eq!(json["count"], 3); + assert_eq!(json["me"], true); + } +} diff --git a/models/message_read_state.rs b/models/message_read_state.rs new file mode 100644 index 0000000..5867fe6 --- /dev/null +++ b/models/message_read_state.rs @@ -0,0 +1,92 @@ +//! Per-user read state — maps to `message_read_state` table. +//! +//! Tracks the last message each user has read in each channel. +//! Used for unread badges, "mark as read", and notification suppression. +//! One row per (channel_id, user_id), upserted on each read. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A user's read progress in one channel. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageReadState { + pub id: Uuid, + pub channel_id: Uuid, + pub user_id: Uuid, + /// The last message id the user has read (cursor). + pub last_read_message_id: Option, + /// When the user last opened / scrolled through this channel. + pub last_read_at: Option>, + /// Total unread message count (denormalized for fast badge display). + pub unread_count: i64, + /// Total unread @mentions for this channel (denormalized). + pub unread_mentions: i64, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// Summary of a user's read state for the client-side channel list. +/// +/// Includes unread badge count and mention count for a single channel. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReadStateSummary { + pub channel_id: Uuid, + pub unread_count: i64, + pub unread_mentions: i64, + pub has_unread: bool, +} + +impl From for ReadStateSummary { + fn from(s: MessageReadState) -> Self { + Self { + channel_id: s.channel_id, + unread_count: s.unread_count, + unread_mentions: s.unread_mentions, + has_unread: s.unread_count > 0, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_state_summary_conversion() { + let state = MessageReadState { + id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + last_read_message_id: None, + last_read_at: None, + unread_count: 5, + unread_mentions: 2, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let summary: ReadStateSummary = state.into(); + assert!(summary.has_unread); + assert_eq!(summary.unread_count, 5); + assert_eq!(summary.unread_mentions, 2); + } + + #[test] + fn test_read_state_summary_no_unread() { + let state = MessageReadState { + id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + last_read_message_id: None, + last_read_at: None, + unread_count: 0, + unread_mentions: 0, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let summary: ReadStateSummary = state.into(); + assert!(!summary.has_unread); + } +} diff --git a/models/message_scheduled.rs b/models/message_scheduled.rs new file mode 100644 index 0000000..1f00f23 --- /dev/null +++ b/models/message_scheduled.rs @@ -0,0 +1,82 @@ +//! Scheduled messages — maps to `message_scheduled` table. +//! +//! A message that the user has composed but wants to send at a future time. +//! A background job picks up rows where `scheduled_at <= now()` and dispatches +//! them through the normal send path. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A message scheduled to be sent at a future time. +/// +/// Maps to the `message_scheduled` table. A background job picks up rows +/// where `scheduled_at <= now()` and dispatches them through the normal send path. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageScheduled { + pub id: Uuid, + pub channel_id: Uuid, + pub author_id: Uuid, + pub thread_id: Option, + pub reply_to_message_id: Option, + pub body: String, + pub metadata: Option, + /// When the message should be sent. + pub scheduled_at: DateTime, + /// "pending" | "sent" | "cancelled" | "failed" + pub status: String, + /// Set after the message is dispatched; points to the sent `message.id`. + pub sent_message_id: Option, + /// Error message if dispatch failed. + pub error: Option, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ScheduledStatus { + #[default] + Pending, + Sent, + Cancelled, + Failed, +} + +impl ScheduledStatus { + pub fn as_str(&self) -> &'static str { + match self { + Self::Pending => "pending", + Self::Sent => "sent", + Self::Cancelled => "cancelled", + Self::Failed => "failed", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_scheduled_serialize() { + let s = MessageScheduled { + id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + author_id: Uuid::now_v7(), + thread_id: None, + reply_to_message_id: None, + body: "Good morning everyone!".into(), + metadata: None, + scheduled_at: Utc::now(), + status: ScheduledStatus::Pending.as_str().into(), + sent_message_id: None, + error: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let json = serde_json::to_value(&s).unwrap(); + assert_eq!(json["status"], "pending"); + } +} diff --git a/models/message_sticker.rs b/models/message_sticker.rs new file mode 100644 index 0000000..0fca37e --- /dev/null +++ b/models/message_sticker.rs @@ -0,0 +1,56 @@ +//! Sticker attachment on a message — maps to `message_sticker` table. +//! +//! Large sticker images sent in messages, distinct from emoji reactions. +//! Stickers are either workspace-level (custom, defined in appks) or +//! system-level (built-in sticker packs). + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A sticker attached to a message. +/// +/// Maps to the `message_sticker` table. Stickers are larger than emoji +/// and can be workspace-level (custom) or system-level (built-in packs). +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageSticker { + pub id: Uuid, + pub message_id: Uuid, + /// References a sticker defined in appks (workspace or system sticker). + pub sticker_id: Uuid, + /// Sticker name at time of send (snapshot for history). + pub name: String, + /// Image URL (snapshot). + pub image_url: String, + /// "png" | "apng" | "lottie" + pub format_type: String, + /// Pack name (e.g. "Wumpus" or workspace name). + pub pack_name: Option, + /// Search tags for discovery. + pub tags: Option, + pub created_at: DateTime, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sticker_serialize() { + let s = MessageSticker { + id: Uuid::now_v7(), + message_id: Uuid::now_v7(), + sticker_id: Uuid::now_v7(), + name: "Hype!".into(), + image_url: "https://cdn.example.com/stickers/hype.png".into(), + format_type: "png".into(), + pack_name: Some("Wumpus".into()), + tags: Some("excited,hype".into()), + created_at: Utc::now(), + }; + + let json = serde_json::to_value(&s).unwrap(); + assert_eq!(json["name"], "Hype!"); + assert_eq!(json["format_type"], "png"); + } +} diff --git a/models/message_thread.rs b/models/message_thread.rs new file mode 100644 index 0000000..bcb0a40 --- /dev/null +++ b/models/message_thread.rs @@ -0,0 +1,60 @@ +//! Thread metadata — maps to `message_thread` table. +//! +//! A thread is anchored by a root message. Reply messages in the same thread +//! set `message.thread_id = message_thread.id`. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A message thread anchored by a root message. +/// +/// Maps to the `message_thread` table. Reply messages set +/// `message.thread_id = message_thread.id` to join the thread. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageThread { + pub id: Uuid, + pub channel_id: Uuid, + /// The first message that started this thread. + pub root_message_id: Uuid, + pub created_by: Uuid, + pub replies_count: i64, + pub participants_count: i64, + pub last_reply_message_id: Option, + pub last_reply_at: Option>, + /// Forum-style: mark thread as resolved / answered. + pub resolved: bool, + pub resolved_by: Option, + pub resolved_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_thread_serialize() { + let thread = MessageThread { + id: Uuid::now_v7(), + channel_id: Uuid::now_v7(), + root_message_id: Uuid::now_v7(), + created_by: Uuid::now_v7(), + replies_count: 5, + participants_count: 3, + last_reply_message_id: None, + last_reply_at: None, + resolved: false, + resolved_by: None, + resolved_at: None, + created_at: Utc::now(), + updated_at: Utc::now(), + }; + + let json = serde_json::to_value(&thread).unwrap(); + assert_eq!(json["replies_count"], 5); + assert_eq!(json["resolved"], false); + assert!(json["id"].is_string()); + } +} diff --git a/models/message_thread_participant.rs b/models/message_thread_participant.rs new file mode 100644 index 0000000..7c77009 --- /dev/null +++ b/models/message_thread_participant.rs @@ -0,0 +1,71 @@ +//! Thread participant membership — maps to `message_thread_participant` table. +//! +//! Tracks which users are part of a thread. A user becomes a participant when +//! they reply in a thread, get @mentioned, or are explicitly added. +//! Without this table, `message_thread.participants_count` is un-verifiable. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A user participating in a thread. +/// +/// Maps to the `message_thread_participant` table. Users become participants +/// when they reply, get @mentioned, or are explicitly added. +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct MessageThreadParticipant { + pub id: Uuid, + pub thread_id: Uuid, + pub user_id: Uuid, + /// How the user joined the thread. + pub joined_reason: Option, + pub last_read_message_id: Option, + pub last_read_at: Option>, + pub joined_at: DateTime, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum JoinReason { + /// User sent a reply in the thread. + #[default] + Reply, + /// User was @mentioned in a thread message. + Mentioned, + /// User was explicitly added by another participant. + Added, + /// User joined the thread themselves. + Joined, +} + +impl JoinReason { + pub fn as_str(&self) -> &'static str { + match self { + Self::Reply => "reply", + Self::Mentioned => "mentioned", + Self::Added => "added", + Self::Joined => "joined", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_participant_serialize() { + let p = MessageThreadParticipant { + id: Uuid::now_v7(), + thread_id: Uuid::now_v7(), + user_id: Uuid::now_v7(), + joined_reason: Some(JoinReason::Reply.as_str().into()), + last_read_message_id: None, + last_read_at: None, + joined_at: Utc::now(), + }; + + let json = serde_json::to_value(&p).unwrap(); + assert_eq!(json["joined_reason"], "reply"); + } +} diff --git a/models/mod.rs b/models/mod.rs new file mode 100644 index 0000000..cdd065b --- /dev/null +++ b/models/mod.rs @@ -0,0 +1,43 @@ +pub mod message; +pub mod message_article; +pub mod message_attachment; +pub mod message_bookmark; +pub mod message_component; +pub mod message_draft; +pub mod message_edit; +pub mod message_embed; +pub mod message_forward; +pub mod message_mention; +pub mod message_notification; +pub mod message_pin; +pub mod message_poll; +pub mod message_reaction; +pub mod message_read_state; +pub mod message_scheduled; +pub mod message_sticker; +pub mod message_thread; +pub mod message_thread_participant; + +pub use message::{Message, MessageDetail, MessageType}; +pub use message_article::{ + ArticleCard, ArticleSort, CreateArticleInput, MessageArticle, UpdateArticleInput, +}; +pub use message_attachment::{AttachmentSummary, MessageAttachment}; +pub use message_bookmark::MessageBookmark; +pub use message_component::{ComponentType, MessageComponent}; +pub use message_draft::{DraftUpsertInput, MessageDraft}; +pub use message_edit::{EditSummary, MessageEdit}; +pub use message_embed::{EmbedDetail, MessageEmbed, MessageEmbedField}; +pub use message_forward::MessageForward; +pub use message_mention::MessageMention; +pub use message_notification::{MessageNotification, NotificationReason, NotificationStatus}; +pub use message_pin::{MessagePin, PinDetail}; +pub use message_poll::{ + MessagePoll, MessagePollOption, MessagePollVote, PollOptionResult, PollResult, +}; +pub use message_reaction::{MessageReaction, ReactionCount}; +pub use message_read_state::{MessageReadState, ReadStateSummary}; +pub use message_scheduled::{MessageScheduled, ScheduledStatus}; +pub use message_sticker::MessageSticker; +pub use message_thread::MessageThread; +pub use message_thread_participant::{JoinReason, MessageThreadParticipant}; diff --git a/pb/core.rs b/pb/core.rs index 5345620..8af5c22 100644 --- a/pb/core.rs +++ b/pb/core.rs @@ -1 +1 @@ -include!(concat!(env!("OUT_DIR"), "/appks.core.v1.rs")); \ No newline at end of file +include!(concat!(env!("OUT_DIR"), "/appks.core.v1.rs")); diff --git a/pb/mod.rs b/pb/mod.rs index 88e1db6..0f6056b 100644 --- a/pb/mod.rs +++ b/pb/mod.rs @@ -1,2 +1,2 @@ pub mod core; -pub mod im; \ No newline at end of file +pub mod im; diff --git a/repo/message_article.rs b/repo/message_article.rs new file mode 100644 index 0000000..201b0c8 --- /dev/null +++ b/repo/message_article.rs @@ -0,0 +1,263 @@ +//! Forum article CRUD operations on `MessageRepo`. +//! +//! Articles extend regular messages with forum-specific metadata (title, cover, +//! view/like stats, tags). Rendered as waterfall cards in forum channels. + +use chrono::Utc; +use sqlx::Row; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message::AuthorInfo; +use crate::models::message_article::{ArticleCard, ArticleSort, MessageArticle}; + +use super::message_repo::MessageRepo; +use super::pagination::{CursorPage, clamp_limit}; + +impl MessageRepo { + /// Create an article record linked to an existing message. + #[allow(clippy::too_many_arguments)] + pub async fn create_article( + &self, + message_id: Uuid, + title: &str, + summary: Option<&str>, + cover_url: Option<&str>, + cover_width: Option, + cover_height: Option, + cover_color: Option<&str>, + tags: Option<&serde_json::Value>, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + sqlx::query_as::<_, MessageArticle>( + r#" + INSERT INTO message_article ( + id, message_id, title, summary, cover_url, cover_width, cover_height, + cover_color, tags, view_count, like_count, bookmark_count, reply_count, + is_pinned_to_top, is_answered, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 0, 0, 0, 0, FALSE, FALSE, $10, $10) + RETURNING * + "#, + ) + .bind(id) + .bind(message_id) + .bind(title) + .bind(summary) + .bind(cover_url) + .bind(cover_width) + .bind(cover_height) + .bind(cover_color) + .bind(tags) + .bind(now) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Update an existing article's metadata. Does NOT update the message body. + pub async fn update_article( + &self, + message_id: Uuid, + title: Option<&str>, + summary: Option<&str>, + cover_url: Option<&str>, + cover_color: Option<&str>, + tags: Option<&serde_json::Value>, + ) -> ImksResult> { + let now = Utc::now(); + sqlx::query_as::<_, MessageArticle>( + r#" + UPDATE message_article + SET title = COALESCE($1, title), + summary = COALESCE($2, summary), + cover_url = COALESCE($3, cover_url), + cover_color = COALESCE($4, cover_color), + tags = COALESCE($5, tags), + updated_at = $6 + WHERE message_id = $7 + RETURNING * + "#, + ) + .bind(title) + .bind(summary) + .bind(cover_url) + .bind(cover_color) + .bind(tags) + .bind(now) + .bind(message_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } + + /// Get an article by its message_id. + pub async fn get_article(&self, message_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageArticle>("SELECT * FROM message_article WHERE message_id = $1") + .bind(message_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } + + /// List articles in a forum channel with id-based cursor pagination. + pub async fn list_articles( + &self, + channel_id: Uuid, + sort: ArticleSort, + before: Option<(i64, Uuid)>, + limit: Option, + ) -> ImksResult> { + let effective_limit = clamp_limit(limit); + let fetch_limit = effective_limit + 1; + let cursor_id = before.map(|(_, id)| id); + + let order_by = match sort { + ArticleSort::LatestActivity => "a.last_reply_at DESC NULLS LAST, m.id DESC", + ArticleSort::Newest => "m.id DESC", + ArticleSort::MostViewed => "a.view_count DESC, m.id DESC", + ArticleSort::MostLiked => "a.like_count DESC, m.id DESC", + ArticleSort::PinnedFirst => { + "a.is_pinned_to_top DESC, a.last_reply_at DESC NULLS LAST, m.id DESC" + } + }; + + let query = if cursor_id.is_some() { + format!( + r#" + SELECT a.*, m.author_id, + ( + SELECT att.url + FROM message_attachment att + WHERE att.message_id = a.message_id + AND att.content_type LIKE 'image/%' + ORDER BY att.created_at + LIMIT 1 + ) AS first_image_url + FROM message_article a + JOIN message m ON m.id = a.message_id + WHERE m.channel_id = $1 + AND m.deleted_at IS NULL + AND m.id < $2 + ORDER BY {order_by} + LIMIT $3 + "#, + ) + } else { + format!( + r#" + SELECT a.*, m.author_id, + ( + SELECT att.url + FROM message_attachment att + WHERE att.message_id = a.message_id + AND att.content_type LIKE 'image/%' + ORDER BY att.created_at + LIMIT 1 + ) AS first_image_url + FROM message_article a + JOIN message m ON m.id = a.message_id + WHERE m.channel_id = $1 + AND m.deleted_at IS NULL + ORDER BY {order_by} + LIMIT $2 + "#, + ) + }; + + let rows = if let Some(cursor) = cursor_id { + sqlx::query(sqlx::AssertSqlSafe(query.as_str())) + .bind(channel_id) + .bind(cursor) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } else { + sqlx::query(sqlx::AssertSqlSafe(query.as_str())) + .bind(channel_id) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + }; + + // Convert raw rows to MessageArticle + let articles: Vec = rows + .iter() + .map(|r| MessageArticle { + id: r.get("id"), + message_id: r.get("message_id"), + title: r.get("title"), + summary: r.get("summary"), + cover_url: r.get("cover_url"), + cover_width: r.get("cover_width"), + cover_height: r.get("cover_height"), + cover_color: r.get("cover_color"), + tags: r.get("tags"), + view_count: r.get("view_count"), + like_count: r.get("like_count"), + bookmark_count: r.get("bookmark_count"), + reply_count: r.get("reply_count"), + last_reply_message_id: r.get("last_reply_message_id"), + last_reply_at: r.get("last_reply_at"), + last_reply_user_id: r.get("last_reply_user_id"), + is_pinned_to_top: r.get("is_pinned_to_top"), + is_answered: r.get("is_answered"), + answered_by: r.get("answered_by"), + answered_at: r.get("answered_at"), + created_at: r.get("created_at"), + updated_at: r.get("updated_at"), + }) + .collect(); + + let has_more = articles.len() > effective_limit as usize; + let items: Vec = articles + .into_iter() + .take(effective_limit as usize) + .collect(); + + let next_cursor = if has_more { + items.last().map(|a| a.message_id) + } else { + None + }; + + let cards: Vec = items + .into_iter() + .zip(rows.iter()) + .map(|(article, row)| { + let author_id: Uuid = row.get("author_id"); + ArticleCard { + article, + author: AuthorInfo { + id: author_id, + username: author_id.to_string(), + display_name: None, + avatar_url: None, + is_bot: false, + }, + tag_names: Vec::new(), + first_image_url: row.get("first_image_url"), + bookmarked: false, + liked: false, + } + }) + .collect(); + + Ok(CursorPage { + items: cards, + next_cursor, + has_more, + }) + } + + /// Increment the view count for an article. + pub async fn increment_article_view(&self, message_id: Uuid) -> ImksResult<()> { + sqlx::query("UPDATE message_article SET view_count = view_count + 1 WHERE message_id = $1") + .bind(message_id) + .execute(self.pool()) + .await?; + + Ok(()) + } +} diff --git a/repo/message_attachment.rs b/repo/message_attachment.rs new file mode 100644 index 0000000..095482f --- /dev/null +++ b/repo/message_attachment.rs @@ -0,0 +1,83 @@ +//! Attachment CRUD operations on `MessageRepo`. + +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_attachment::{AttachmentSummary, MessageAttachment}; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Create a single attachment record. + #[allow(clippy::too_many_arguments)] + pub async fn create_attachment( + &self, + message_id: Uuid, + filename: &str, + content_type: Option<&str>, + size: i64, + url: &str, + storage_key: Option<&str>, + width: Option, + height: Option, + spoiler: bool, + ) -> ImksResult { + let id = Uuid::now_v7(); + + sqlx::query_as::<_, MessageAttachment>( + r#" + INSERT INTO message_attachment ( + id, message_id, filename, content_type, size, url, + storage_key, width, height, spoiler + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING * + "#, + ) + .bind(id) + .bind(message_id) + .bind(filename) + .bind(content_type) + .bind(size) + .bind(url) + .bind(storage_key) + .bind(width) + .bind(height) + .bind(spoiler) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Get all attachments for a message. + pub async fn get_attachments(&self, message_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageAttachment>( + "SELECT * FROM message_attachment WHERE message_id = $1 ORDER BY created_at", + ) + .bind(message_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } + + /// Get lightweight attachment summaries (no URLs/storage keys). + pub async fn get_attachment_summaries( + &self, + message_id: Uuid, + ) -> ImksResult> { + let attachments = self.get_attachments(message_id).await?; + Ok(attachments + .into_iter() + .map(AttachmentSummary::from) + .collect()) + } + + /// Delete a single attachment. + pub async fn delete_attachment(&self, attachment_id: Uuid) -> ImksResult { + let result = sqlx::query("DELETE FROM message_attachment WHERE id = $1") + .bind(attachment_id) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } +} diff --git a/repo/message_bookmark.rs b/repo/message_bookmark.rs new file mode 100644 index 0000000..11e3fac --- /dev/null +++ b/repo/message_bookmark.rs @@ -0,0 +1,112 @@ +//! Bookmark CRUD operations on `MessageRepo`. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_bookmark::MessageBookmark; + +use super::message_repo::MessageRepo; +use super::pagination::{CursorPage, clamp_limit}; + +impl MessageRepo { + /// Add a bookmark for a message. + pub async fn add_bookmark( + &self, + message_id: Uuid, + channel_id: Uuid, + user_id: Uuid, + note: Option<&str>, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + sqlx::query_as::<_, MessageBookmark>( + r#" + INSERT INTO message_bookmark (id, message_id, channel_id, user_id, note, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $6) + ON CONFLICT (user_id, message_id) DO UPDATE SET note = EXCLUDED.note, updated_at = EXCLUDED.updated_at + RETURNING * + "#, + ) + .bind(id) + .bind(message_id) + .bind(channel_id) + .bind(user_id) + .bind(note) + .bind(now) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Remove a bookmark. + pub async fn remove_bookmark(&self, message_id: Uuid, user_id: Uuid) -> ImksResult { + let result = + sqlx::query("DELETE FROM message_bookmark WHERE message_id = $1 AND user_id = $2") + .bind(message_id) + .bind(user_id) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + /// Check if a user has bookmarked a message. + pub async fn is_bookmarked(&self, message_id: Uuid, user_id: Uuid) -> ImksResult { + let exists: Option = sqlx::query_scalar( + "SELECT id FROM message_bookmark WHERE message_id = $1 AND user_id = $2", + ) + .bind(message_id) + .bind(user_id) + .fetch_optional(self.pool()) + .await?; + + Ok(exists.is_some()) + } + + /// List a user's bookmarks with cursor-based pagination (newest first). + pub async fn list_bookmarks( + &self, + user_id: Uuid, + before: Option, + limit: Option, + ) -> ImksResult> { + let effective_limit = clamp_limit(limit); + let fetch_limit = effective_limit + 1; + + let rows = match before { + Some(cursor) => { + sqlx::query_as::<_, MessageBookmark>( + r#" + SELECT * FROM message_bookmark + WHERE user_id = $1 AND id < $2 + ORDER BY id DESC + LIMIT $3 + "#, + ) + .bind(user_id) + .bind(cursor) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + None => { + sqlx::query_as::<_, MessageBookmark>( + r#" + SELECT * FROM message_bookmark + WHERE user_id = $1 + ORDER BY id DESC + LIMIT $2 + "#, + ) + .bind(user_id) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + }; + + Ok(CursorPage::from_raw(rows, effective_limit, |b| b.id)) + } +} diff --git a/repo/message_component.rs b/repo/message_component.rs new file mode 100644 index 0000000..ddb3611 --- /dev/null +++ b/repo/message_component.rs @@ -0,0 +1,88 @@ +//! Interactive component CRUD operations on `MessageRepo`. + +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_component::MessageComponent; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Create an interactive component (button/select menu) on a message. + #[allow(clippy::too_many_arguments)] + pub async fn create_component( + &self, + message_id: Uuid, + component_type: &str, + custom_id: &str, + label: Option<&str>, + emoji: Option<&str>, + style: Option<&str>, + url: Option<&str>, + disabled: bool, + row: i32, + position: i32, + ) -> ImksResult { + sqlx::query_as::<_, MessageComponent>( + r#" + INSERT INTO message_component ( + id, message_id, row, position, component_type, custom_id, + label, emoji, style, url, disabled + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + RETURNING * + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(row) + .bind(position) + .bind(component_type) + .bind(custom_id) + .bind(label) + .bind(emoji) + .bind(style) + .bind(url) + .bind(disabled) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Get all components on a message, ordered by layout. + pub async fn get_components(&self, message_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageComponent>( + r#" + SELECT * FROM message_component + WHERE message_id = $1 + ORDER BY row, position + "#, + ) + .bind(message_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } + + /// Update a component's label and/or disabled state (e.g. after interaction). + pub async fn update_component( + &self, + component_id: Uuid, + label: Option<&str>, + disabled: bool, + ) -> ImksResult> { + sqlx::query_as::<_, MessageComponent>( + r#" + UPDATE message_component + SET label = COALESCE($1, label), disabled = $2 + WHERE id = $3 + RETURNING * + "#, + ) + .bind(label) + .bind(disabled) + .bind(component_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } +} diff --git a/repo/message_create.rs b/repo/message_create.rs new file mode 100644 index 0000000..785aed7 --- /dev/null +++ b/repo/message_create.rs @@ -0,0 +1,116 @@ +//! Message write operations — insert, update body, soft delete. +//! +//! All mutations use parameterized queries and return the affected row(s). + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message::{Message, new_message_id}; + +use super::message_repo::MessageRepo; + +/// Input payload for creating a new message. +#[derive(Debug, Clone)] +pub struct CreateMessageInput { + /// Target channel UUID. + pub channel_id: Uuid, + /// Author (user) UUID — extracted from JWT `sub` claim. + pub author_id: Uuid, + /// Thread this message belongs to (`None` = top-level). + pub thread_id: Option, + /// Direct reply reference (`None` = not a reply). + pub reply_to_message_id: Option, + /// Discriminator: `"text"`, `"system"`, `"event"`, `"article"`. + pub message_type: String, + /// Plain text or markdown body. + pub body: String, + /// Extensible metadata (flags, locale, etc.). + pub metadata: Option, + /// Whether this is a system/bot-generated message. + pub system: bool, +} + +impl MessageRepo { + /// Insert a new message row and return it. + /// + /// The message ID is a fresh UUID v7 (time-ordered). + pub async fn create(&self, input: &CreateMessageInput) -> ImksResult { + let id = new_message_id(); + let now = Utc::now(); + + let row = sqlx::query_as::<_, Message>( + r#" + INSERT INTO message ( + id, channel_id, author_id, thread_id, reply_to_message_id, + message_type, body, metadata, pinned, system, + edited_at, deleted_at, created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, $8, FALSE, $9, + NULL, NULL, $10, $10 + ) + RETURNING * + "#, + ) + .bind(id) + .bind(input.channel_id) + .bind(input.author_id) + .bind(input.thread_id) + .bind(input.reply_to_message_id) + .bind(&input.message_type) + .bind(&input.body) + .bind(&input.metadata) + .bind(input.system) + .bind(now) + .fetch_one(self.pool()) + .await?; + + Ok(row) + } + + /// Update the body of an existing message. Sets `edited_at` and `updated_at`. + /// + /// Returns the updated row, or an error if the message is not found or deleted. + pub async fn update_body(&self, message_id: Uuid, new_body: &str) -> ImksResult { + let now = Utc::now(); + + let row = sqlx::query_as::<_, Message>( + r#" + UPDATE message + SET body = $1, edited_at = $2, updated_at = $2 + WHERE id = $3 AND deleted_at IS NULL + RETURNING * + "#, + ) + .bind(new_body) + .bind(now) + .bind(message_id) + .fetch_optional(self.pool()) + .await? + .ok_or_else(|| crate::ImksError::NotFound(format!("message {message_id}")))?; + + Ok(row) + } + + /// Soft-delete a message by setting `deleted_at`. + /// + /// Returns `Ok(())` even if the message was already deleted. + pub async fn soft_delete(&self, message_id: Uuid) -> ImksResult<()> { + let now = Utc::now(); + + sqlx::query( + r#" + UPDATE message + SET deleted_at = $1, updated_at = $1 + WHERE id = $2 AND deleted_at IS NULL + "#, + ) + .bind(now) + .bind(message_id) + .execute(self.pool()) + .await?; + + Ok(()) + } +} diff --git a/repo/message_draft.rs b/repo/message_draft.rs new file mode 100644 index 0000000..d8399ba --- /dev/null +++ b/repo/message_draft.rs @@ -0,0 +1,113 @@ +//! Draft CRUD operations on `MessageRepo`. +//! +//! One draft per (channel, user, thread). Upserted on every keystroke debounce, +//! deleted on send. Thread_id=NULL uses a dedicated conflict target. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_draft::MessageDraft; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Upsert a draft for the given (channel, user, thread) key. + /// Uses NULL-safe conflict handling via COALESCE. + pub async fn upsert_draft( + &self, + channel_id: Uuid, + user_id: Uuid, + thread_id: Option, + body: &str, + reply_to_message_id: Option, + metadata: Option, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + let query = if thread_id.is_some() { + r#" + INSERT INTO message_draft ( + id, channel_id, user_id, thread_id, reply_to_message_id, + body, metadata, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8) + ON CONFLICT (channel_id, user_id, thread_id) DO UPDATE SET + body = EXCLUDED.body, + reply_to_message_id = EXCLUDED.reply_to_message_id, + metadata = EXCLUDED.metadata, + updated_at = EXCLUDED.updated_at + RETURNING * + "# + } else { + r#" + INSERT INTO message_draft ( + id, channel_id, user_id, thread_id, reply_to_message_id, + body, metadata, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8) + ON CONFLICT (channel_id, user_id) WHERE thread_id IS NULL DO UPDATE SET + body = EXCLUDED.body, + reply_to_message_id = EXCLUDED.reply_to_message_id, + metadata = EXCLUDED.metadata, + updated_at = EXCLUDED.updated_at + RETURNING * + "# + }; + + sqlx::query_as::<_, MessageDraft>(query) + .bind(id) + .bind(channel_id) + .bind(user_id) + .bind(thread_id) + .bind(reply_to_message_id) + .bind(body) + .bind(metadata) + .bind(now) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Get a user's draft for a channel (optionally scoped to a thread). + pub async fn get_draft( + &self, + channel_id: Uuid, + user_id: Uuid, + thread_id: Option, + ) -> ImksResult> { + sqlx::query_as::<_, MessageDraft>( + r#" + SELECT * FROM message_draft + WHERE channel_id = $1 AND user_id = $2 AND thread_id IS NOT DISTINCT FROM $3 + "#, + ) + .bind(channel_id) + .bind(user_id) + .bind(thread_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } + + /// Delete a draft after the message is sent. + pub async fn delete_draft( + &self, + channel_id: Uuid, + user_id: Uuid, + thread_id: Option, + ) -> ImksResult { + let result = sqlx::query( + r#" + DELETE FROM message_draft + WHERE channel_id = $1 AND user_id = $2 AND thread_id IS NOT DISTINCT FROM $3 + "#, + ) + .bind(channel_id) + .bind(user_id) + .bind(thread_id) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } +} diff --git a/repo/message_edit.rs b/repo/message_edit.rs new file mode 100644 index 0000000..5d1959b --- /dev/null +++ b/repo/message_edit.rs @@ -0,0 +1,77 @@ +//! Edit history CRUD operations on `MessageRepo`. +//! +//! Immutable append-only log. Every message body edit creates one row. + +use chrono::Utc; +use sqlx::Row; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_edit::{EditSummary, MessageEdit}; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Record an edit to a message's body. + pub async fn record_edit( + &self, + message_id: Uuid, + edited_by: Uuid, + old_body: &str, + new_body: &str, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + sqlx::query_as::<_, MessageEdit>( + r#" + INSERT INTO message_edit (id, message_id, edited_by, old_body, new_body, edited_at) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING * + "#, + ) + .bind(id) + .bind(message_id) + .bind(edited_by) + .bind(old_body) + .bind(new_body) + .bind(now) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Get the full edit history for a message, oldest first. + pub async fn get_edit_history(&self, message_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageEdit>( + "SELECT * FROM message_edit WHERE message_id = $1 ORDER BY edited_at ASC", + ) + .bind(message_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } + + /// Get a summary of edits for a message (count + last editor). + pub async fn get_edit_summary(&self, message_id: Uuid) -> ImksResult { + let row = sqlx::query( + r#" + SELECT + COUNT(*)::BIGINT AS edit_count, + MAX(edited_at) AS last_edited_at, + (ARRAY_AGG(edited_by ORDER BY edited_at DESC))[1] AS last_edited_by + FROM message_edit + WHERE message_id = $1 + "#, + ) + .bind(message_id) + .fetch_one(self.pool()) + .await?; + + Ok(EditSummary { + edit_count: row.get("edit_count"), + last_edited_at: row.get("last_edited_at"), + last_edited_by: row.get("last_edited_by"), + }) + } +} diff --git a/repo/message_embed.rs b/repo/message_embed.rs new file mode 100644 index 0000000..84258ca --- /dev/null +++ b/repo/message_embed.rs @@ -0,0 +1,106 @@ +//! Embed CRUD operations on `MessageRepo`. + +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_embed::{EmbedDetail, MessageEmbed, MessageEmbedField}; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Create an embed with its fields. Returns the embed (fields fetched separately). + #[allow(clippy::too_many_arguments)] + pub async fn create_embed( + &self, + message_id: Uuid, + embed_type: &str, + title: Option<&str>, + description: Option<&str>, + url: Option<&str>, + color: Option, + image_url: Option<&str>, + author_name: Option<&str>, + author_url: Option<&str>, + footer_text: Option<&str>, + provider_name: Option<&str>, + fields: &[(String, String, bool)], + ) -> ImksResult { + let embed_id = Uuid::now_v7(); + + let embed = sqlx::query_as::<_, MessageEmbed>( + r#" + INSERT INTO message_embed ( + id, message_id, embed_type, title, description, url, color, + image_url, author_name, author_url, footer_text, provider_name + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + RETURNING * + "#, + ) + .bind(embed_id) + .bind(message_id) + .bind(embed_type) + .bind(title) + .bind(description) + .bind(url) + .bind(color) + .bind(image_url) + .bind(author_name) + .bind(author_url) + .bind(footer_text) + .bind(provider_name) + .fetch_one(self.pool()) + .await?; + + for (i, (name, value, inline)) in fields.iter().enumerate() { + sqlx::query( + r#" + INSERT INTO message_embed_field (id, embed_id, name, value, inline, position) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + ) + .bind(Uuid::now_v7()) + .bind(embed_id) + .bind(name) + .bind(value) + .bind(inline) + .bind(i as i32) + .execute(self.pool()) + .await?; + } + + Ok(embed) + } + + /// Get all embeds for a message, including their fields. + pub async fn get_embeds(&self, message_id: Uuid) -> ImksResult> { + let embeds: Vec = + sqlx::query_as("SELECT * FROM message_embed WHERE message_id = $1 ORDER BY created_at") + .bind(message_id) + .fetch_all(self.pool()) + .await?; + + let mut result = Vec::with_capacity(embeds.len()); + for embed in embeds { + let fields: Vec = sqlx::query_as( + "SELECT * FROM message_embed_field WHERE embed_id = $1 ORDER BY position", + ) + .bind(embed.id) + .fetch_all(self.pool()) + .await?; + + result.push(EmbedDetail { embed, fields }); + } + + Ok(result) + } + + /// Delete an embed (fields cascade via FK). + pub async fn delete_embed(&self, embed_id: Uuid) -> ImksResult { + let result = sqlx::query("DELETE FROM message_embed WHERE id = $1") + .bind(embed_id) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } +} diff --git a/repo/message_forward.rs b/repo/message_forward.rs new file mode 100644 index 0000000..01fca33 --- /dev/null +++ b/repo/message_forward.rs @@ -0,0 +1,44 @@ +//! Message forward CRUD operations on `MessageRepo`. + +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_forward::MessageForward; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Record a forwarded message's provenance. + pub async fn record_forward( + &self, + message_id: Uuid, + source_message_id: Uuid, + source_channel_id: Uuid, + forwarded_by: Uuid, + ) -> ImksResult { + sqlx::query_as::<_, MessageForward>( + r#" + INSERT INTO message_forward (id, message_id, source_message_id, source_channel_id, forwarded_by) + VALUES ($1, $2, $3, $4, $5) + RETURNING * + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(source_message_id) + .bind(source_channel_id) + .bind(forwarded_by) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Get forwarding info for a message. + pub async fn get_forward_info(&self, message_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageForward>("SELECT * FROM message_forward WHERE message_id = $1") + .bind(message_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } +} diff --git a/repo/message_mention.rs b/repo/message_mention.rs new file mode 100644 index 0000000..ce69993 --- /dev/null +++ b/repo/message_mention.rs @@ -0,0 +1,111 @@ +//! Mention CRUD operations on `MessageRepo`. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_mention::MessageMention; + +use super::message_repo::MessageRepo; +use super::pagination::{CursorPage, clamp_limit}; + +impl MessageRepo { + /// Bulk record mentions for a message. Called right after message creation. + pub async fn record_mentions( + &self, + message_id: Uuid, + channel_id: Uuid, + mentioned_by: Uuid, + mentioned_user_ids: &[Uuid], + ) -> ImksResult<()> { + if mentioned_user_ids.is_empty() { + return Ok(()); + } + + let now = Utc::now(); + for &user_id in mentioned_user_ids { + sqlx::query( + r#" + INSERT INTO message_mention (id, message_id, channel_id, mentioned_user_id, mentioned_by, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(channel_id) + .bind(user_id) + .bind(mentioned_by) + .bind(now) + .execute(self.pool()) + .await?; + } + + Ok(()) + } + + /// List mentions for a specific user, newest first. + pub async fn list_mentions_for_user( + &self, + user_id: Uuid, + before: Option, + limit: Option, + ) -> ImksResult> { + let effective_limit = clamp_limit(limit); + let fetch_limit = effective_limit + 1; + + let rows = match before { + Some(cursor) => { + sqlx::query_as::<_, MessageMention>( + r#" + SELECT * FROM message_mention + WHERE mentioned_user_id = $1 AND id < $2 + ORDER BY id DESC + LIMIT $3 + "#, + ) + .bind(user_id) + .bind(cursor) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + None => { + sqlx::query_as::<_, MessageMention>( + r#" + SELECT * FROM message_mention + WHERE mentioned_user_id = $1 + ORDER BY id DESC + LIMIT $2 + "#, + ) + .bind(user_id) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + }; + + Ok(CursorPage::from_raw(rows, effective_limit, |m| m.id)) + } + + /// Mark a mention as read. + pub async fn mark_mention_read(&self, mention_id: Uuid) -> ImksResult<()> { + sqlx::query("UPDATE message_mention SET read_at = $1 WHERE id = $2") + .bind(Utc::now()) + .bind(mention_id) + .execute(self.pool()) + .await?; + + Ok(()) + } + + /// Delete all mentions for a message (e.g. on message delete). + pub async fn delete_mentions(&self, message_id: Uuid) -> ImksResult<()> { + sqlx::query("DELETE FROM message_mention WHERE message_id = $1") + .bind(message_id) + .execute(self.pool()) + .await?; + + Ok(()) + } +} diff --git a/repo/message_notification.rs b/repo/message_notification.rs new file mode 100644 index 0000000..ef12307 --- /dev/null +++ b/repo/message_notification.rs @@ -0,0 +1,140 @@ +//! Notification CRUD operations on `MessageRepo`. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_notification::MessageNotification; + +use super::message_repo::MessageRepo; +use super::pagination::{CursorPage, clamp_limit}; + +impl MessageRepo { + /// Create a notification for a user triggered by a message. + pub async fn create_notification( + &self, + message_id: Uuid, + channel_id: Uuid, + user_id: Uuid, + reason: &str, + delivery_channel: Option<&str>, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + sqlx::query_as::<_, MessageNotification>( + r#" + INSERT INTO message_notification ( + id, message_id, channel_id, user_id, reason, status, delivery_channel, created_at + ) VALUES ($1, $2, $3, $4, $5, 'pending', $6, $7) + RETURNING * + "#, + ) + .bind(id) + .bind(message_id) + .bind(channel_id) + .bind(user_id) + .bind(reason) + .bind(delivery_channel) + .bind(now) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Mark a notification as read. + pub async fn mark_notification_read(&self, notification_id: Uuid) -> ImksResult<()> { + let now = Utc::now(); + sqlx::query( + r#" + UPDATE message_notification + SET status = 'read', read_at = $1 + WHERE id = $2 + "#, + ) + .bind(now) + .bind(notification_id) + .execute(self.pool()) + .await?; + + Ok(()) + } + + /// Mark all of a user's notifications as read. + pub async fn mark_all_notifications_read(&self, user_id: Uuid) -> ImksResult { + let now = Utc::now(); + let result = sqlx::query( + r#" + UPDATE message_notification + SET status = 'read', read_at = $1 + WHERE user_id = $2 AND status != 'read' + "#, + ) + .bind(now) + .bind(user_id) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected()) + } + + /// List notifications for a user, newest first. + pub async fn list_notifications( + &self, + user_id: Uuid, + before: Option, + limit: Option, + ) -> ImksResult> { + let effective_limit = clamp_limit(limit); + let fetch_limit = effective_limit + 1; + + let rows = match before { + Some(cursor) => { + sqlx::query_as::<_, MessageNotification>( + r#" + SELECT * FROM message_notification + WHERE user_id = $1 AND id < $2 + ORDER BY id DESC + LIMIT $3 + "#, + ) + .bind(user_id) + .bind(cursor) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + None => { + sqlx::query_as::<_, MessageNotification>( + r#" + SELECT * FROM message_notification + WHERE user_id = $1 + ORDER BY id DESC + LIMIT $2 + "#, + ) + .bind(user_id) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + }; + + Ok(CursorPage::from_raw(rows, effective_limit, |n| n.id)) + } + + /// Get unread notification count for a user. + pub async fn get_unread_notification_count(&self, user_id: Uuid) -> ImksResult { + let count: i64 = sqlx::query_scalar( + r#" + SELECT COUNT(*)::BIGINT FROM message_notification + WHERE user_id = $1 AND status = 'pending' + "#, + ) + .bind(user_id) + .fetch_one(self.pool()) + .await?; + + Ok(count) + } +} diff --git a/repo/message_pin.rs b/repo/message_pin.rs new file mode 100644 index 0000000..e3d3b46 --- /dev/null +++ b/repo/message_pin.rs @@ -0,0 +1,110 @@ +//! Pin CRUD operations on `MessageRepo`. +//! +//! One row per pinned message. Channels can have multiple pinned messages. +//! `position` auto-calculated as `MAX(position) + 1` within the channel. + +use chrono::Utc; +use sqlx::Row; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_pin::{MessagePin, PinDetail}; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Pin a message in a channel. Computes the next position automatically. + pub async fn pin_message( + &self, + channel_id: Uuid, + message_id: Uuid, + pinned_by: Uuid, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + let mut tx = self.pool().begin().await?; + + sqlx::query("SELECT pg_advisory_xact_lock(hashtextextended($1, 0))") + .bind(channel_id.to_string()) + .execute(&mut *tx) + .await?; + + let max_pos: Option = sqlx::query_scalar( + "SELECT COALESCE(MAX(position), -1) FROM message_pin WHERE channel_id = $1", + ) + .bind(channel_id) + .fetch_one(&mut *tx) + .await?; + + let position = max_pos.unwrap_or(-1) + 1; + + let pin = sqlx::query_as::<_, MessagePin>( + r#" + INSERT INTO message_pin (id, channel_id, message_id, pinned_by, position, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (channel_id, message_id) DO NOTHING + RETURNING * + "#, + ) + .bind(id) + .bind(channel_id) + .bind(message_id) + .bind(pinned_by) + .bind(position) + .bind(now) + .fetch_optional(&mut *tx) + .await? + .ok_or_else(|| crate::ImksError::InvalidInput("Message already pinned".into()))?; + + tx.commit().await?; + Ok(pin) + } + + /// Unpin a message from a channel. + pub async fn unpin_message(&self, channel_id: Uuid, message_id: Uuid) -> ImksResult { + let result = + sqlx::query("DELETE FROM message_pin WHERE channel_id = $1 AND message_id = $2") + .bind(channel_id) + .bind(message_id) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + /// List all pinned messages in a channel, newest first, joined with message content. + pub async fn list_pins(&self, channel_id: Uuid) -> ImksResult> { + let rows = sqlx::query( + r#" + SELECT p.*, m.body AS message_body, m.author_id AS message_author_id, + m.created_at AS message_created_at + FROM message_pin p + JOIN message m ON m.id = p.message_id + WHERE p.channel_id = $1 AND m.deleted_at IS NULL + ORDER BY p.position ASC + "#, + ) + .bind(channel_id) + .fetch_all(self.pool()) + .await?; + + let result = rows + .into_iter() + .map(|row| PinDetail { + pin: MessagePin { + id: row.get("id"), + channel_id: row.get("channel_id"), + message_id: row.get("message_id"), + pinned_by: row.get("pinned_by"), + position: row.get("position"), + created_at: row.get("created_at"), + }, + message_body: row.get("message_body"), + message_author_id: row.get("message_author_id"), + message_created_at: row.get("message_created_at"), + }) + .collect(); + + Ok(result) + } +} diff --git a/repo/message_poll.rs b/repo/message_poll.rs new file mode 100644 index 0000000..bd07a33 --- /dev/null +++ b/repo/message_poll.rs @@ -0,0 +1,396 @@ +//! Poll CRUD operations on `MessageRepo`. +//! +//! Handles poll creation (with options), voting (with denormalized counts), +//! and result retrieval. + +use chrono::Utc; +use sqlx::Row; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_poll::{MessagePoll, MessagePollOption, MessagePollVote, PollResult}; + +use super::message_repo::MessageRepo; + +/// Canonical poll target resolved from the database. +#[derive(Debug, Clone, Copy)] +pub struct PollTarget { + pub poll_id: Uuid, + pub message_id: Uuid, + pub channel_id: Uuid, +} + +impl MessageRepo { + /// Resolve and validate a poll option's canonical message/channel target. + pub async fn get_poll_target(&self, poll_id: Uuid, option_id: Uuid) -> ImksResult { + let row = sqlx::query( + r#" + SELECT p.id AS poll_id, p.message_id, m.channel_id, o.id AS option_id + FROM message_poll p + JOIN message m ON m.id = p.message_id + JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2 + WHERE p.id = $1 AND m.deleted_at IS NULL + "#, + ) + .bind(poll_id) + .bind(option_id) + .fetch_optional(self.pool()) + .await? + .ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?; + + Ok(PollTarget { + poll_id: row.get("poll_id"), + message_id: row.get("message_id"), + channel_id: row.get("channel_id"), + }) + } + + /// Create a poll with its options. Returns the poll (options fetched separately). + pub async fn create_poll( + &self, + message_id: Uuid, + question: &str, + allow_multiselect: bool, + max_selections: Option, + expires_at: Option>, + options: &[(String, Option)], + ) -> ImksResult { + let poll_id = Uuid::now_v7(); + let now = Utc::now(); + + let poll = sqlx::query_as::<_, MessagePoll>( + r#" + INSERT INTO message_poll ( + id, message_id, question, allow_multiselect, + max_selections, expires_at, total_votes, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, 0, $7, $7) + RETURNING * + "#, + ) + .bind(poll_id) + .bind(message_id) + .bind(question) + .bind(allow_multiselect) + .bind(max_selections) + .bind(expires_at) + .bind(now) + .fetch_one(self.pool()) + .await?; + + for (i, (text, emoji)) in options.iter().enumerate() { + let opt_id = Uuid::now_v7(); + sqlx::query( + r#" + INSERT INTO message_poll_option (id, poll_id, text, emoji, vote_count, position) + VALUES ($1, $2, $3, $4, 0, $5) + "#, + ) + .bind(opt_id) + .bind(poll_id) + .bind(text) + .bind(emoji.as_deref()) + .bind(i as i32) + .execute(self.pool()) + .await?; + } + + Ok(poll) + } + + /// Cast a validated vote and return the canonical message/channel target. + pub async fn cast_vote_checked( + &self, + poll_id: Uuid, + option_id: Uuid, + user_id: Uuid, + ) -> ImksResult { + let mut tx = self.pool().begin().await?; + let now = Utc::now(); + + let row = sqlx::query( + r#" + SELECT p.id AS poll_id, p.message_id, p.allow_multiselect, p.max_selections, + p.expires_at, m.channel_id, o.id AS option_id + FROM message_poll p + JOIN message m ON m.id = p.message_id + JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2 + WHERE p.id = $1 AND m.deleted_at IS NULL + FOR UPDATE OF p + "#, + ) + .bind(poll_id) + .bind(option_id) + .fetch_optional(&mut *tx) + .await? + .ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?; + + let expires_at: Option> = row.get("expires_at"); + if expires_at.is_some_and(|exp| now >= exp) { + return Err(crate::ImksError::InvalidInput("Poll has expired".into())); + } + + let allow_multiselect: bool = row.get("allow_multiselect"); + let max_selections: Option = row.get("max_selections"); + let current_votes: Vec = sqlx::query_scalar( + "SELECT option_id FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2", + ) + .bind(poll_id) + .bind(user_id) + .fetch_all(&mut *tx) + .await?; + + if current_votes.contains(&option_id) { + return Err(crate::ImksError::InvalidInput( + "Already voted for this option".into(), + )); + } + if !allow_multiselect && !current_votes.is_empty() { + return Err(crate::ImksError::InvalidInput( + "Poll allows only one selection".into(), + )); + } + if let Some(max) = max_selections + && current_votes.len() >= max.max(1) as usize + { + return Err(crate::ImksError::InvalidInput( + "Poll selection limit exceeded".into(), + )); + } + + let vote_id = Uuid::now_v7(); + sqlx::query( + r#" + INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, created_at) + VALUES ($1, $2, $3, $4, $5) + "#, + ) + .bind(vote_id) + .bind(poll_id) + .bind(option_id) + .bind(user_id) + .bind(now) + .execute(&mut *tx) + .await?; + + sqlx::query("UPDATE message_poll_option SET vote_count = vote_count + 1 WHERE id = $1") + .bind(option_id) + .execute(&mut *tx) + .await?; + + sqlx::query( + "UPDATE message_poll SET total_votes = total_votes + 1, updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(poll_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(PollTarget { + poll_id, + message_id: row.get("message_id"), + channel_id: row.get("channel_id"), + }) + } + + /// Remove a validated vote and return the canonical message/channel target. + pub async fn remove_vote_checked( + &self, + poll_id: Uuid, + option_id: Uuid, + user_id: Uuid, + ) -> ImksResult> { + let mut tx = self.pool().begin().await?; + let now = Utc::now(); + + let row = sqlx::query( + r#" + SELECT p.id AS poll_id, p.message_id, m.channel_id, o.id AS option_id + FROM message_poll p + JOIN message m ON m.id = p.message_id + JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2 + WHERE p.id = $1 AND m.deleted_at IS NULL + FOR UPDATE OF p + "#, + ) + .bind(poll_id) + .bind(option_id) + .fetch_optional(&mut *tx) + .await? + .ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?; + + let result = sqlx::query( + r#" + DELETE FROM message_poll_vote + WHERE poll_id = $1 AND option_id = $2 AND user_id = $3 + "#, + ) + .bind(poll_id) + .bind(option_id) + .bind(user_id) + .execute(&mut *tx) + .await?; + + if result.rows_affected() == 0 { + tx.commit().await?; + return Ok(None); + } + + sqlx::query( + "UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1", + ) + .bind(option_id) + .execute(&mut *tx) + .await?; + + sqlx::query( + "UPDATE message_poll SET total_votes = GREATEST(total_votes - 1, 0), updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(poll_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(Some(PollTarget { + poll_id, + message_id: row.get("message_id"), + channel_id: row.get("channel_id"), + })) + } + + /// Cast a vote. Increments denormalized counts atomically. + pub async fn vote( + &self, + poll_id: Uuid, + option_id: Uuid, + user_id: Uuid, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + let vote = sqlx::query_as::<_, MessagePollVote>( + r#" + INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, created_at) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (poll_id, user_id, option_id) DO NOTHING + RETURNING * + "#, + ) + .bind(id) + .bind(poll_id) + .bind(option_id) + .bind(user_id) + .bind(now) + .fetch_optional(self.pool()) + .await? + .ok_or_else(|| crate::ImksError::InvalidInput("Already voted for this option".into()))?; + + sqlx::query("UPDATE message_poll_option SET vote_count = vote_count + 1 WHERE id = $1") + .bind(option_id) + .execute(self.pool()) + .await?; + + sqlx::query( + "UPDATE message_poll SET total_votes = total_votes + 1, updated_at = $1 WHERE id = $2", + ) + .bind(now) + .bind(poll_id) + .execute(self.pool()) + .await?; + + Ok(vote) + } + + /// Remove a vote. Decrements denormalized counts. + pub async fn remove_vote( + &self, + poll_id: Uuid, + option_id: Uuid, + user_id: Uuid, + ) -> ImksResult { + let result = sqlx::query( + r#" + DELETE FROM message_poll_vote + WHERE poll_id = $1 AND option_id = $2 AND user_id = $3 + "#, + ) + .bind(poll_id) + .bind(option_id) + .bind(user_id) + .execute(self.pool()) + .await?; + + if result.rows_affected() == 0 { + return Ok(false); + } + + sqlx::query( + "UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1", + ) + .bind(option_id) + .execute(self.pool()) + .await?; + + sqlx::query( + "UPDATE message_poll SET total_votes = GREATEST(total_votes - 1, 0), updated_at = $1 WHERE id = $2", + ) + .bind(Utc::now()) + .bind(poll_id) + .execute(self.pool()) + .await?; + + Ok(true) + } + + /// Get full poll results including options, vote counts, and the given user's votes. + pub async fn get_poll_result( + &self, + message_id: Uuid, + user_id: Uuid, + ) -> ImksResult> { + let poll = + sqlx::query_as::<_, MessagePoll>("SELECT * FROM message_poll WHERE message_id = $1") + .bind(message_id) + .fetch_optional(self.pool()) + .await?; + + let Some(poll) = poll else { + return Ok(None); + }; + + let options: Vec = sqlx::query_as( + "SELECT * FROM message_poll_option WHERE poll_id = $1 ORDER BY position", + ) + .bind(poll.id) + .fetch_all(self.pool()) + .await?; + + let my_votes: Vec = sqlx::query_scalar( + "SELECT option_id FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2", + ) + .bind(poll.id) + .bind(user_id) + .fetch_all(self.pool()) + .await?; + + Ok(Some(PollResult::from_poll(poll, options, my_votes))) + } + + /// Close a poll by setting its expiration to now. + pub async fn close_poll(&self, message_id: Uuid) -> ImksResult<()> { + let now = Utc::now(); + sqlx::query( + r#" + UPDATE message_poll + SET expires_at = $1, updated_at = $1 + WHERE message_id = $2 AND (expires_at IS NULL OR expires_at > $1) + "#, + ) + .bind(now) + .bind(message_id) + .execute(self.pool()) + .await?; + Ok(()) + } +} diff --git a/repo/message_query.rs b/repo/message_query.rs new file mode 100644 index 0000000..1619ca1 --- /dev/null +++ b/repo/message_query.rs @@ -0,0 +1,169 @@ +//! Message read operations — get single, list by channel, list by thread. +//! +//! All list queries use UUID v7 cursor-based pagination (no OFFSET). + +use sqlx::Row; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message::Message; + +use super::message_repo::MessageRepo; +use super::pagination::{CursorPage, clamp_limit}; + +impl MessageRepo { + /// Fetch a single message by ID. + /// + /// Returns `None` if the message doesn't exist or has been soft-deleted. + pub async fn get(&self, message_id: Uuid) -> ImksResult> { + let row = sqlx::query_as::<_, Message>( + "SELECT * FROM message WHERE id = $1 AND deleted_at IS NULL", + ) + .bind(message_id) + .fetch_optional(self.pool()) + .await?; + + Ok(row) + } + + /// List messages in a channel with cursor-based pagination. + /// + /// Returns messages in reverse chronological order (newest first). + /// Pass `before` as the last message ID from the previous page to + /// fetch the next page. + pub async fn list_by_channel( + &self, + channel_id: Uuid, + before: Option, + limit: Option, + ) -> ImksResult> { + let effective_limit = clamp_limit(limit); + // Fetch one extra row to determine `has_more`. + let fetch_limit = effective_limit + 1; + + let rows = match before { + Some(cursor) => { + sqlx::query_as::<_, Message>( + r#" + SELECT * FROM message + WHERE channel_id = $1 + AND deleted_at IS NULL + AND id < $2 + ORDER BY id DESC + LIMIT $3 + "#, + ) + .bind(channel_id) + .bind(cursor) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + None => { + sqlx::query_as::<_, Message>( + r#" + SELECT * FROM message + WHERE channel_id = $1 + AND deleted_at IS NULL + ORDER BY id DESC + LIMIT $2 + "#, + ) + .bind(channel_id) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + }; + + Ok(CursorPage::from_raw(rows, effective_limit, |m| m.id)) + } + + /// List messages in a thread with cursor-based pagination. + pub async fn list_by_thread( + &self, + thread_id: Uuid, + before: Option, + limit: Option, + ) -> ImksResult> { + let effective_limit = clamp_limit(limit); + let fetch_limit = effective_limit + 1; + + let rows = match before { + Some(cursor) => { + sqlx::query_as::<_, Message>( + r#" + SELECT * FROM message + WHERE thread_id = $1 + AND deleted_at IS NULL + AND id < $2 + ORDER BY id DESC + LIMIT $3 + "#, + ) + .bind(thread_id) + .bind(cursor) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + None => { + sqlx::query_as::<_, Message>( + r#" + SELECT * FROM message + WHERE thread_id = $1 + AND deleted_at IS NULL + ORDER BY id DESC + LIMIT $2 + "#, + ) + .bind(thread_id) + .bind(fetch_limit) + .fetch_all(self.pool()) + .await? + } + }; + + Ok(CursorPage::from_raw(rows, effective_limit, |m| m.id)) + } + + /// Get message reaction counts grouped by emoji content. + /// + /// Returns `(content, count)` pairs for the given message. + pub async fn get_reaction_counts(&self, message_id: Uuid) -> ImksResult> { + let rows = sqlx::query( + r#" + SELECT content, COUNT(*)::BIGINT AS cnt + FROM message_reaction + WHERE message_id = $1 + GROUP BY content + "#, + ) + .bind(message_id) + .fetch_all(self.pool()) + .await?; + + Ok(rows + .into_iter() + .map(|r| (r.get("content"), r.get("cnt"))) + .collect()) + } + + /// Count attachments and embeds for a message. + pub async fn get_content_counts(&self, message_id: Uuid) -> ImksResult<(i64, i64)> { + let att_row = sqlx::query( + "SELECT COUNT(*)::BIGINT AS cnt FROM message_attachment WHERE message_id = $1", + ) + .bind(message_id) + .fetch_one(self.pool()) + .await?; + + let emb_row = + sqlx::query("SELECT COUNT(*)::BIGINT AS cnt FROM message_embed WHERE message_id = $1") + .bind(message_id) + .fetch_one(self.pool()) + .await?; + + Ok((att_row.get("cnt"), emb_row.get("cnt"))) + } +} diff --git a/repo/message_reaction.rs b/repo/message_reaction.rs new file mode 100644 index 0000000..b238497 --- /dev/null +++ b/repo/message_reaction.rs @@ -0,0 +1,94 @@ +//! Reaction CRUD operations on `MessageRepo`. +//! +//! Each (message, user, content) tuple is unique via ON CONFLICT. +//! Toggle semantics: same request adds/removes the reaction. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_reaction::MessageReaction; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Add or toggle a reaction. Returns the reaction if added, None if already exists. + pub async fn add_reaction( + &self, + message_id: Uuid, + channel_id: Uuid, + user_id: Uuid, + content: &str, + ) -> ImksResult> { + let id = Uuid::now_v7(); + let now = Utc::now(); + + let row = sqlx::query_as::<_, MessageReaction>( + r#" + INSERT INTO message_reaction (id, message_id, channel_id, user_id, content, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (message_id, user_id, content) DO NOTHING + RETURNING * + "#, + ) + .bind(id) + .bind(message_id) + .bind(channel_id) + .bind(user_id) + .bind(content) + .bind(now) + .fetch_optional(self.pool()) + .await?; + + Ok(row) + } + + /// Remove a user's reaction from a message. + pub async fn remove_reaction( + &self, + message_id: Uuid, + user_id: Uuid, + content: &str, + ) -> ImksResult { + let result = sqlx::query( + r#" + DELETE FROM message_reaction + WHERE message_id = $1 AND user_id = $2 AND content = $3 + "#, + ) + .bind(message_id) + .bind(user_id) + .bind(content) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + /// Get all reactions on a message. + pub async fn get_reactions(&self, message_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageReaction>( + "SELECT * FROM message_reaction WHERE message_id = $1 ORDER BY created_at", + ) + .bind(message_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } + + /// Get a specific user's reactions on a message. + pub async fn get_user_reactions( + &self, + message_id: Uuid, + user_id: Uuid, + ) -> ImksResult> { + sqlx::query_as::<_, MessageReaction>( + "SELECT * FROM message_reaction WHERE message_id = $1 AND user_id = $2 ORDER BY created_at", + ) + .bind(message_id) + .bind(user_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } +} diff --git a/repo/message_read_state.rs b/repo/message_read_state.rs new file mode 100644 index 0000000..ce4ba87 --- /dev/null +++ b/repo/message_read_state.rs @@ -0,0 +1,110 @@ +//! Read state CRUD operations on `MessageRepo`. +//! +//! One row per (channel, user). Upserted on each read; ON CONFLICT DO UPDATE +//! advances the cursor and recalculates unread counts. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_read_state::MessageReadState; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Mark a channel as read up to a given message for a user. + /// Recalculates unread_count and unread_mentions from the database. + pub async fn mark_read( + &self, + channel_id: Uuid, + user_id: Uuid, + last_read_message_id: Uuid, + ) -> ImksResult { + let now = Utc::now(); + + let unread_count: i64 = sqlx::query_scalar( + r#" + SELECT COUNT(*)::BIGINT + FROM message + WHERE channel_id = $1 + AND deleted_at IS NULL + AND id > $2 + AND author_id != $3 + "#, + ) + .bind(channel_id) + .bind(last_read_message_id) + .bind(user_id) + .fetch_one(self.pool()) + .await?; + + let unread_mentions: i64 = sqlx::query_scalar( + r#" + SELECT COUNT(*)::BIGINT + FROM message_mention mm + WHERE mm.channel_id = $1 + AND mm.mentioned_user_id = $2 + AND mm.message_id > $3 + AND mm.read_at IS NULL + "#, + ) + .bind(channel_id) + .bind(user_id) + .bind(last_read_message_id) + .fetch_one(self.pool()) + .await?; + + let id = Uuid::now_v7(); + + sqlx::query_as::<_, MessageReadState>( + r#" + INSERT INTO message_read_state ( + id, channel_id, user_id, last_read_message_id, last_read_at, + unread_count, unread_mentions, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8) + ON CONFLICT (channel_id, user_id) DO UPDATE SET + last_read_message_id = EXCLUDED.last_read_message_id, + last_read_at = EXCLUDED.last_read_at, + unread_count = EXCLUDED.unread_count, + unread_mentions = EXCLUDED.unread_mentions, + updated_at = EXCLUDED.updated_at + RETURNING * + "#, + ) + .bind(id) + .bind(channel_id) + .bind(user_id) + .bind(last_read_message_id) + .bind(now) + .bind(unread_count) + .bind(unread_mentions) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Get a user's read state for a channel. + pub async fn get_read_state( + &self, + channel_id: Uuid, + user_id: Uuid, + ) -> ImksResult> { + sqlx::query_as::<_, MessageReadState>( + "SELECT * FROM message_read_state WHERE channel_id = $1 AND user_id = $2", + ) + .bind(channel_id) + .bind(user_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } + + /// Get read state summaries for all channels a user participates in. + pub async fn get_user_read_states(&self, user_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageReadState>("SELECT * FROM message_read_state WHERE user_id = $1") + .bind(user_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } +} diff --git a/repo/message_repo.rs b/repo/message_repo.rs new file mode 100644 index 0000000..f2f22ae --- /dev/null +++ b/repo/message_repo.rs @@ -0,0 +1,27 @@ +//! Message repository — struct definition and pool accessor. +//! +//! CRUD operations are split across `message_create.rs` (writes) +//! and `message_query.rs` (reads) as separate `impl` blocks. + +use sqlx::PgPool; + +/// Repository for message CRUD operations. +/// +/// All queries use parameterized statements via sqlx. +/// IDs are UUID v7 (time-ordered) for efficient cursor pagination. +#[derive(Clone)] +pub struct MessageRepo { + pool: PgPool, +} + +impl MessageRepo { + /// Create a new repository backed by the given connection pool. + pub fn new(pool: PgPool) -> Self { + Self { pool } + } + + /// Access the inner `PgPool` for advanced queries. + pub fn pool(&self) -> &PgPool { + &self.pool + } +} diff --git a/repo/message_scheduled.rs b/repo/message_scheduled.rs new file mode 100644 index 0000000..2f5d0ea --- /dev/null +++ b/repo/message_scheduled.rs @@ -0,0 +1,159 @@ +//! Scheduled message CRUD operations on `MessageRepo`. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_scheduled::MessageScheduled; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Schedule a message to be sent later. + #[allow(clippy::too_many_arguments)] + pub async fn schedule_message( + &self, + channel_id: Uuid, + author_id: Uuid, + thread_id: Option, + reply_to_message_id: Option, + body: &str, + metadata: Option, + scheduled_at: chrono::DateTime, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + sqlx::query_as::<_, MessageScheduled>( + r#" + INSERT INTO message_scheduled ( + id, channel_id, author_id, thread_id, reply_to_message_id, + body, metadata, scheduled_at, status, created_at, updated_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'pending', $9, $9) + RETURNING * + "#, + ) + .bind(id) + .bind(channel_id) + .bind(author_id) + .bind(thread_id) + .bind(reply_to_message_id) + .bind(body) + .bind(metadata) + .bind(scheduled_at) + .bind(now) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Cancel a scheduled message (only if still pending). + pub async fn cancel_scheduled(&self, scheduled_id: Uuid) -> ImksResult { + let result = sqlx::query( + r#" + UPDATE message_scheduled + SET status = 'cancelled', updated_at = $1 + WHERE id = $2 AND status = 'pending' + "#, + ) + .bind(Utc::now()) + .bind(scheduled_id) + .execute(self.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + /// List a user's scheduled messages. + pub async fn list_scheduled( + &self, + channel_id: Uuid, + author_id: Uuid, + ) -> ImksResult> { + sqlx::query_as::<_, MessageScheduled>( + r#" + SELECT * FROM message_scheduled + WHERE channel_id = $1 AND author_id = $2 AND status = 'pending' + ORDER BY scheduled_at ASC + "#, + ) + .bind(channel_id) + .bind(author_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } + + /// Atomically claim due scheduled messages for background dispatch. + pub async fn claim_due_scheduled(&self) -> ImksResult> { + let mut tx = self.pool().begin().await?; + let now = Utc::now(); + + let rows = sqlx::query_as::<_, MessageScheduled>( + r#" + UPDATE message_scheduled + SET status = 'processing', updated_at = $1 + WHERE id IN ( + SELECT id + FROM message_scheduled + WHERE status = 'pending' AND scheduled_at <= $1 + ORDER BY scheduled_at ASC + LIMIT 100 + FOR UPDATE SKIP LOCKED + ) + RETURNING * + "#, + ) + .bind(now) + .fetch_all(&mut *tx) + .await?; + + tx.commit().await?; + Ok(rows) + } + + /// Get all pending scheduled messages whose time has come (for background dispatch). + pub async fn get_due_scheduled(&self) -> ImksResult> { + self.claim_due_scheduled().await + } + + /// Mark a scheduled message as sent. + pub async fn mark_scheduled_sent( + &self, + scheduled_id: Uuid, + sent_message_id: Uuid, + ) -> ImksResult<()> { + sqlx::query( + r#" + UPDATE message_scheduled + SET status = 'sent', sent_message_id = $1, updated_at = $2 + WHERE id = $3 AND status = 'processing' + "#, + ) + .bind(sent_message_id) + .bind(Utc::now()) + .bind(scheduled_id) + .execute(self.pool()) + .await?; + + Ok(()) + } + + /// Mark a scheduled message as failed. + pub async fn mark_scheduled_failed(&self, scheduled_id: Uuid, error: &str) -> ImksResult<()> { + sqlx::query( + r#" + UPDATE message_scheduled + SET status = 'failed', error = $1, updated_at = $2 + WHERE id = $3 AND status = 'processing' + "#, + ) + .bind(error) + .bind(Utc::now()) + .bind(scheduled_id) + .execute(self.pool()) + .await?; + + Ok(()) + } +} diff --git a/repo/message_sticker.rs b/repo/message_sticker.rs new file mode 100644 index 0000000..0b05cf0 --- /dev/null +++ b/repo/message_sticker.rs @@ -0,0 +1,53 @@ +//! Sticker CRUD operations on `MessageRepo`. + +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_sticker::MessageSticker; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Record a sticker used in a message. + #[allow(clippy::too_many_arguments)] + pub async fn record_sticker( + &self, + message_id: Uuid, + sticker_id: Uuid, + name: &str, + image_url: &str, + format_type: &str, + pack_name: Option<&str>, + tags: Option<&str>, + ) -> ImksResult { + sqlx::query_as::<_, MessageSticker>( + r#" + INSERT INTO message_sticker (id, message_id, sticker_id, name, image_url, format_type, pack_name, tags) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING * + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(sticker_id) + .bind(name) + .bind(image_url) + .bind(format_type) + .bind(pack_name) + .bind(tags) + .fetch_one(self.pool()) + .await + .map_err(Into::into) + } + + /// Get all stickers on a message. + pub async fn get_stickers(&self, message_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageSticker>( + "SELECT * FROM message_sticker WHERE message_id = $1 ORDER BY created_at", + ) + .bind(message_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } +} diff --git a/repo/message_thread.rs b/repo/message_thread.rs new file mode 100644 index 0000000..d93259b --- /dev/null +++ b/repo/message_thread.rs @@ -0,0 +1,218 @@ +//! Thread CRUD operations on `MessageRepo`. + +use chrono::Utc; +use uuid::Uuid; + +use crate::ImksResult; +use crate::models::message_thread::MessageThread; +use crate::models::message_thread_participant::MessageThreadParticipant; + +use super::message_repo::MessageRepo; + +impl MessageRepo { + /// Create a new thread anchored on a root message. + pub async fn create_thread( + &self, + root_message_id: Uuid, + channel_id: Uuid, + created_by: Uuid, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + sqlx::query_as::<_, MessageThread>( + r#" + INSERT INTO message_thread ( + id, channel_id, root_message_id, created_by, + replies_count, participants_count, resolved, + created_at, updated_at + ) VALUES ($1, $2, $3, $4, 0, 0, FALSE, $5, $5) + ON CONFLICT (root_message_id) DO NOTHING + RETURNING * + "#, + ) + .bind(id) + .bind(channel_id) + .bind(root_message_id) + .bind(created_by) + .bind(now) + .fetch_optional(self.pool()) + .await? + .ok_or_else(|| crate::ImksError::InvalidInput("Thread already exists".into())) + } + + /// Get a thread by its ID. + pub async fn get_thread(&self, thread_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageThread>("SELECT * FROM message_thread WHERE id = $1") + .bind(thread_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } + + /// Get a thread by its root message ID. + pub async fn get_thread_by_root( + &self, + root_message_id: Uuid, + ) -> ImksResult> { + sqlx::query_as::<_, MessageThread>( + "SELECT * FROM message_thread WHERE root_message_id = $1", + ) + .bind(root_message_id) + .fetch_optional(self.pool()) + .await + .map_err(Into::into) + } + + /// List threads in a channel. + pub async fn list_threads(&self, channel_id: Uuid) -> ImksResult> { + sqlx::query_as::<_, MessageThread>( + r#" + SELECT * FROM message_thread + WHERE channel_id = $1 + ORDER BY last_reply_at DESC NULLS LAST + "#, + ) + .bind(channel_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } + + /// Increment thread reply counter and update last reply info. + pub async fn bump_thread(&self, thread_id: Uuid, message_id: Uuid) -> ImksResult<()> { + let now = Utc::now(); + sqlx::query( + r#" + UPDATE message_thread + SET replies_count = replies_count + 1, + last_reply_message_id = $1, + last_reply_at = $2, + updated_at = $2 + WHERE id = $3 + "#, + ) + .bind(message_id) + .bind(now) + .bind(thread_id) + .execute(self.pool()) + .await?; + + Ok(()) + } + + /// Resolve or unresolve a thread. + pub async fn resolve_thread( + &self, + thread_id: Uuid, + resolved_by: Uuid, + resolved: bool, + ) -> ImksResult<()> { + if resolved { + sqlx::query( + r#" + UPDATE message_thread + SET resolved = TRUE, resolved_by = $1, resolved_at = $2, updated_at = $2 + WHERE id = $3 + "#, + ) + .bind(resolved_by) + .bind(Utc::now()) + .bind(thread_id) + .execute(self.pool()) + .await?; + } else { + sqlx::query( + r#" + UPDATE message_thread + SET resolved = FALSE, resolved_by = NULL, resolved_at = NULL, updated_at = $1 + WHERE id = $2 + "#, + ) + .bind(Utc::now()) + .bind(thread_id) + .execute(self.pool()) + .await?; + } + + Ok(()) + } +} + +impl MessageRepo { + /// Add a participant to a thread (or update their join reason). + pub async fn add_thread_participant( + &self, + thread_id: Uuid, + user_id: Uuid, + joined_reason: &str, + ) -> ImksResult { + let id = Uuid::now_v7(); + let now = Utc::now(); + + let participant = sqlx::query_as::<_, MessageThreadParticipant>( + r#" + INSERT INTO message_thread_participant (id, thread_id, user_id, joined_reason, joined_at) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (thread_id, user_id) DO UPDATE SET joined_reason = EXCLUDED.joined_reason + RETURNING * + "#, + ) + .bind(id) + .bind(thread_id) + .bind(user_id) + .bind(joined_reason) + .bind(now) + .fetch_one(self.pool()) + .await?; + + sqlx::query( + "UPDATE message_thread SET participants_count = (SELECT COUNT(*) FROM message_thread_participant WHERE thread_id = $1) WHERE id = $1", + ) + .bind(thread_id) + .execute(self.pool()) + .await?; + + Ok(participant) + } + + /// Remove a participant from a thread. + pub async fn remove_thread_participant( + &self, + thread_id: Uuid, + user_id: Uuid, + ) -> ImksResult { + let result = sqlx::query( + "DELETE FROM message_thread_participant WHERE thread_id = $1 AND user_id = $2", + ) + .bind(thread_id) + .bind(user_id) + .execute(self.pool()) + .await?; + + if result.rows_affected() > 0 { + sqlx::query( + "UPDATE message_thread SET participants_count = GREATEST(participants_count - 1, 0) WHERE id = $1", + ) + .bind(thread_id) + .execute(self.pool()) + .await?; + } + + Ok(result.rows_affected() > 0) + } + + /// List all participants in a thread. + pub async fn list_thread_participants( + &self, + thread_id: Uuid, + ) -> ImksResult> { + sqlx::query_as::<_, MessageThreadParticipant>( + "SELECT * FROM message_thread_participant WHERE thread_id = $1 ORDER BY joined_at", + ) + .bind(thread_id) + .fetch_all(self.pool()) + .await + .map_err(Into::into) + } +} diff --git a/repo/mod.rs b/repo/mod.rs new file mode 100644 index 0000000..fd6cf10 --- /dev/null +++ b/repo/mod.rs @@ -0,0 +1,25 @@ +pub mod message_article; +pub mod message_attachment; +pub mod message_bookmark; +pub mod message_component; +pub mod message_create; +pub mod message_draft; +pub mod message_edit; +pub mod message_embed; +pub mod message_forward; +pub mod message_mention; +pub mod message_notification; +pub mod message_pin; +pub mod message_poll; +pub mod message_query; +pub mod message_reaction; +pub mod message_read_state; +pub mod message_repo; +pub mod message_scheduled; +pub mod message_sticker; +pub mod message_thread; +pub mod pagination; + +pub use message_create::CreateMessageInput; +pub use message_repo::MessageRepo; +pub use pagination::CursorPage; diff --git a/repo/pagination.rs b/repo/pagination.rs new file mode 100644 index 0000000..f441d52 --- /dev/null +++ b/repo/pagination.rs @@ -0,0 +1,114 @@ +//! Cursor-based pagination helpers for repository queries. +//! +//! UUID v7 IDs are time-ordered, so `WHERE id < $cursor ORDER BY id DESC` +//! naturally yields reverse-chronological pages without OFFSET. + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Default number of items per page. +pub const DEFAULT_PAGE_SIZE: i64 = 50; +/// Hard upper bound on page size to prevent abuse. +pub const MAX_PAGE_SIZE: i64 = 100; + +/// Generic cursor-based page response. +/// +/// Returned by list operations in the repo layer. The `next_cursor` +/// is the last item's UUID — pass it as `before` in the next request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CursorPage { + /// Items in this page (ordered by `id DESC`). + pub items: Vec, + /// Opaque cursor for the next page. `None` when no more results exist. + pub next_cursor: Option, + /// Whether there are more results beyond this page. + pub has_more: bool, +} + +impl CursorPage { + /// Build a page from a raw result set that may contain one extra row. + /// + /// If `raw_items.len() > limit`, the extra row is dropped and + /// `has_more` is set to `true`. + pub fn from_raw(mut raw_items: Vec, limit: i64, get_id: impl Fn(&T) -> Uuid) -> Self { + let has_more = raw_items.len() > limit as usize; + if has_more { + raw_items.truncate(limit as usize); + } + let next_cursor = if has_more { + raw_items.last().map(get_id) + } else { + None + }; + Self { + items: raw_items, + next_cursor, + has_more, + } + } + + /// Empty page (no results). + pub fn empty() -> Self { + Self { + items: Vec::new(), + next_cursor: None, + has_more: false, + } + } +} + +/// Clamp a caller-requested limit to `[1, MAX_PAGE_SIZE]`, defaulting to `DEFAULT_PAGE_SIZE`. +pub fn clamp_limit(limit: Option) -> i64 { + match limit { + Some(n) if n < 1 => DEFAULT_PAGE_SIZE, + Some(n) if n > MAX_PAGE_SIZE => MAX_PAGE_SIZE, + Some(n) => n, + None => DEFAULT_PAGE_SIZE, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_clamp_limit_none() { + assert_eq!(clamp_limit(None), DEFAULT_PAGE_SIZE); + } + + #[test] + fn test_clamp_limit_zero() { + assert_eq!(clamp_limit(Some(0)), DEFAULT_PAGE_SIZE); + } + + #[test] + fn test_clamp_limit_negative() { + assert_eq!(clamp_limit(Some(-5)), DEFAULT_PAGE_SIZE); + } + + #[test] + fn test_clamp_limit_over_max() { + assert_eq!(clamp_limit(Some(200)), MAX_PAGE_SIZE); + } + + #[test] + fn test_clamp_limit_valid() { + assert_eq!(clamp_limit(Some(25)), 25); + } + + #[test] + fn test_cursor_page_empty() { + let page: CursorPage = CursorPage::empty(); + assert!(page.items.is_empty()); + assert!(!page.has_more); + assert!(page.next_cursor.is_none()); + } + + #[test] + fn test_cursor_page_from_raw_no_overflow() { + let items = vec!["a".to_string(), "b".to_string()]; + let page = CursorPage::from_raw(items, 5, |_| Uuid::nil()); + assert_eq!(page.items.len(), 2); + assert!(!page.has_more); + } +} diff --git a/rpc/clients.rs b/rpc/clients.rs new file mode 100644 index 0000000..b13ba59 --- /dev/null +++ b/rpc/clients.rs @@ -0,0 +1,108 @@ +//! Aggregate gRPC client holder for all appks core services. +//! +//! A single TCP `Channel` is shared across all four service clients +//! to avoid redundant connections. + +use std::fs; +use std::time::Duration; + +use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity}; + +use crate::pb::core::token_service_client::TokenServiceClient; +use crate::pb::im::{ + channel_service_client::ChannelServiceClient, member_service_client::MemberServiceClient, + permission_service_client::PermissionServiceClient, +}; +use crate::{ImksError, ImksResult}; + +use super::config::RpcConfig; + +/// Holds gRPC clients for all appks core services consumed by imks. +/// +/// Cheaply cloneable — each inner client wraps a shared `Arc`. +#[derive(Clone)] +pub struct AppksClients { + /// JWT token lifecycle: issue, refresh, revoke, verify, signing keys. + pub token: TokenServiceClient, + /// Channel and category CRUD + statistics. + pub channel: ChannelServiceClient, + /// Channel member invite / kick / join / leave. + pub member: MemberServiceClient, + /// Permission checks and overwrite rules. + pub permission: PermissionServiceClient, +} + +impl AppksClients { + /// Connect to all appks services using a shared gRPC channel. + pub async fn connect(config: &RpcConfig) -> ImksResult { + let mut endpoint = Endpoint::from_shared(config.appks_addr.clone()) + .map_err(|e| ImksError::Internal(format!("Invalid gRPC endpoint: {e}")))? + .connect_timeout(Duration::from_secs(config.connect_timeout_secs)); + + if config.appks_addr.starts_with("https://") + || config.tls_ca_cert_path.is_some() + || config.tls_client_cert_path.is_some() + || config.tls_client_key_path.is_some() + { + endpoint = endpoint.tls_config(build_tls_config(config)?)?; + } + + let channel = endpoint + .connect() + .await + .map_err(|e| ImksError::Internal(format!("gRPC connect failed: {e}")))?; + + tracing::info!(addr = %config.appks_addr, "Connected to appks gRPC services"); + + Ok(Self { + token: TokenServiceClient::new(channel.clone()), + channel: ChannelServiceClient::new(channel.clone()), + member: MemberServiceClient::new(channel.clone()), + permission: PermissionServiceClient::new(channel), + }) + } + + /// Build from pre-connected clients (useful for tests with mock servers). + pub fn new( + token: TokenServiceClient, + channel: ChannelServiceClient, + member: MemberServiceClient, + permission: PermissionServiceClient, + ) -> Self { + Self { + token, + channel, + member, + permission, + } + } +} + +fn build_tls_config(config: &RpcConfig) -> ImksResult { + let mut tls = ClientTlsConfig::new(); + + if let Some(domain) = &config.tls_domain_name { + tls = tls.domain_name(domain); + } + + if let Some(path) = &config.tls_ca_cert_path { + let pem = fs::read(path)?; + tls = tls.ca_certificate(Certificate::from_pem(pem)); + } + + match (&config.tls_client_cert_path, &config.tls_client_key_path) { + (Some(cert_path), Some(key_path)) => { + let cert = fs::read(cert_path)?; + let key = fs::read(key_path)?; + tls = tls.identity(Identity::from_pem(cert, key)); + } + (None, None) => {} + _ => { + return Err(ImksError::InvalidInput( + "Both APPKS_GRPC_TLS_CLIENT_CERT and APPKS_GRPC_TLS_CLIENT_KEY are required for mTLS".into(), + )); + } + } + + Ok(tls) +} diff --git a/rpc/config.rs b/rpc/config.rs new file mode 100644 index 0000000..abba5ec --- /dev/null +++ b/rpc/config.rs @@ -0,0 +1,65 @@ +//! gRPC client configuration for connecting to appks core services. +//! +//! Reads the appks address and timeout from environment variables. + +use std::env; + +/// Configuration for appks gRPC connections. +#[derive(Debug, Clone)] +pub struct RpcConfig { + /// appks gRPC endpoint, e.g. `http://localhost:50051`. + pub appks_addr: String, + /// Connection establishment timeout (seconds). + pub connect_timeout_secs: u64, + /// Optional CA certificate PEM path for appks mTLS. + pub tls_ca_cert_path: Option, + /// Optional client certificate PEM path for appks mTLS. + pub tls_client_cert_path: Option, + /// Optional client private key PEM path for appks mTLS. + pub tls_client_key_path: Option, + /// TLS domain name used for certificate verification. + pub tls_domain_name: Option, +} + +impl RpcConfig { + /// Build config from environment variables with defaults. + pub fn from_env() -> Self { + Self { + appks_addr: env::var("APPKS_GRPC_ADDR") + .unwrap_or_else(|_| "http://localhost:50051".to_string()), + connect_timeout_secs: env::var("APPKS_GRPC_TIMEOUT") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(10), + tls_ca_cert_path: env::var("APPKS_GRPC_TLS_CA_CERT").ok(), + tls_client_cert_path: env::var("APPKS_GRPC_TLS_CLIENT_CERT").ok(), + tls_client_key_path: env::var("APPKS_GRPC_TLS_CLIENT_KEY").ok(), + tls_domain_name: env::var("APPKS_GRPC_TLS_DOMAIN").ok(), + } + } +} + +impl Default for RpcConfig { + fn default() -> Self { + Self::from_env() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let cfg = RpcConfig { + appks_addr: "http://localhost:50051".to_string(), + connect_timeout_secs: 10, + tls_ca_cert_path: None, + tls_client_cert_path: None, + tls_client_key_path: None, + tls_domain_name: None, + }; + assert_eq!(cfg.connect_timeout_secs, 10); + assert!(cfg.appks_addr.starts_with("http")); + } +} diff --git a/rpc/mod.rs b/rpc/mod.rs new file mode 100644 index 0000000..8c8e8a4 --- /dev/null +++ b/rpc/mod.rs @@ -0,0 +1,5 @@ +pub mod clients; +pub mod config; + +pub use clients::AppksClients; +pub use config::RpcConfig; diff --git a/socket/adapter/local.rs b/socket/adapter/local.rs index daab46a..9d2d072 100644 --- a/socket/adapter/local.rs +++ b/socket/adapter/local.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use dashmap::DashMap; use uuid::Uuid; -use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, SocketInfo}; +use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, LocalSendFn, SocketInfo}; use crate::socket::packet::Packet; pub struct LocalAdapter { @@ -16,7 +16,7 @@ pub struct LocalAdapter { pub socket_sids: Arc>, /// socket_sid → namespace path socket_namespace: Arc>, - send_fn: Arc Result<(), String> + Send + Sync>, + send_fn: LocalSendFn, } impl LocalAdapter { @@ -68,7 +68,11 @@ impl LocalAdapter { #[async_trait] impl Adapter for LocalAdapter { - async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { + async fn broadcast( + &self, + packet: &Packet, + opts: &BroadcastOptions, + ) -> Result<(), AdapterError> { let namespace = &packet.namespace; let sids = self.collect_matching_sids(opts, namespace); for sid in &sids { @@ -87,9 +91,16 @@ impl Adapter for LocalAdapter { Ok(()) } - async fn register(&self, socket_sid: &str, engine_sid: &str, ns: &str) -> Result<(), AdapterError> { - self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string()); - self.socket_namespace.insert(socket_sid.to_string(), ns.to_string()); + async fn register( + &self, + socket_sid: &str, + engine_sid: &str, + ns: &str, + ) -> Result<(), AdapterError> { + self.socket_sids + .insert(socket_sid.to_string(), engine_sid.to_string()); + self.socket_namespace + .insert(socket_sid.to_string(), ns.to_string()); Ok(()) } @@ -99,8 +110,16 @@ impl Adapter for LocalAdapter { async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> { let key = Self::room_key(ns, room); - self.rooms.entry(key).or_insert_with(HashSet::new).value_mut().insert(sid.to_string()); - self.socket_rooms.entry(sid.to_string()).or_insert_with(HashSet::new).value_mut().insert(room.to_string()); + self.rooms + .entry(key) + .or_default() + .value_mut() + .insert(sid.to_string()); + self.socket_rooms + .entry(sid.to_string()) + .or_default() + .value_mut() + .insert(room.to_string()); Ok(()) } @@ -137,10 +156,14 @@ impl Adapter for LocalAdapter { } } self.socket_sids.remove(sid); + self.socket_namespace.remove(sid); Ok(()) } - async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result, AdapterError> { + async fn fetch_sockets( + &self, + opts: &BroadcastOptions, + ) -> Result, AdapterError> { // fetch_sockets needs namespace context; use an empty namespace to match all // (this method is typically called for inspection, not delivery) let sids: Vec = if opts.rooms.is_empty() { @@ -164,11 +187,13 @@ impl Adapter for LocalAdapter { continue; } if self.socket_sids.contains_key(sid) { - let namespace = self.socket_namespace + let namespace = self + .socket_namespace .get(sid) .map(|r| r.value().clone()) .unwrap_or_default(); - let rooms = self.socket_rooms + let rooms = self + .socket_rooms .get(sid) .map(|r| r.value().clone()) .unwrap_or_default(); @@ -183,7 +208,8 @@ impl Adapter for LocalAdapter { } async fn socket_rooms(&self, sid: &str) -> Result, AdapterError> { - Ok(self.socket_rooms + Ok(self + .socket_rooms .get(sid) .map(|r| r.value().clone()) .unwrap_or_default()) @@ -196,4 +222,4 @@ impl Adapter for LocalAdapter { async fn close(&self) -> Result<(), AdapterError> { Ok(()) } -} \ No newline at end of file +} diff --git a/socket/adapter/mod.rs b/socket/adapter/mod.rs index b46eeed..61bb447 100644 --- a/socket/adapter/mod.rs +++ b/socket/adapter/mod.rs @@ -1,14 +1,20 @@ pub mod local; -pub mod redis; pub mod nats; +pub mod redis; use std::collections::HashSet; +use std::sync::Arc; use async_trait::async_trait; use thiserror::Error; use crate::socket::packet::Packet; +/// Alias for cross-node broadcast callback functions. +pub type LocalBroadcastFn = Arc; +/// Alias for local send-to-socket callback functions. +pub type LocalSendFn = Arc Result<(), String> + Send + Sync>; + #[derive(Error, Debug)] pub enum AdapterError { #[error("Redis error: {0}")] @@ -72,11 +78,13 @@ pub enum BusMessage { #[async_trait] pub trait Adapter: Send + Sync + 'static { - async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError>; + async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) + -> Result<(), AdapterError>; async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>; async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>; async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError>; - async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result, AdapterError>; + async fn fetch_sockets(&self, opts: &BroadcastOptions) + -> Result, AdapterError>; async fn socket_rooms(&self, sid: &str) -> Result, AdapterError>; fn server_id(&self) -> &str; async fn close(&self) -> Result<(), AdapterError>; @@ -84,7 +92,12 @@ pub trait Adapter: Send + Sync + 'static { /// Register a socket SID → engine SID mapping in the adapter. /// Must be called when a socket first connects, before any room operations. /// The `ns` parameter is the namespace path this socket belongs to. - async fn register(&self, _socket_sid: &str, _engine_sid: &str, _ns: &str) -> Result<(), AdapterError> { + async fn register( + &self, + _socket_sid: &str, + _engine_sid: &str, + _ns: &str, + ) -> Result<(), AdapterError> { Ok(()) } @@ -95,5 +108,5 @@ pub trait Adapter: Send + Sync + 'static { } pub use local::LocalAdapter; +pub use nats::NatsAdapter; pub use redis::RedisAdapter; -pub use nats::NatsAdapter; \ No newline at end of file diff --git a/socket/adapter/nats.rs b/socket/adapter/nats.rs index 8cd86c1..a71faba 100644 --- a/socket/adapter/nats.rs +++ b/socket/adapter/nats.rs @@ -5,7 +5,9 @@ use async_trait::async_trait; use dashmap::DashMap; use tokio::sync::mpsc; -use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo}; +use crate::socket::adapter::{ + Adapter, AdapterError, BroadcastOptions, BusMessage, LocalBroadcastFn, SocketInfo, +}; use crate::socket::message_bus::MessageBus; use crate::socket::packet::Packet; use crate::socket::parser; @@ -15,11 +17,16 @@ use crate::socket::socket::Socket; /// Only performs local dispatch — no remote state writes needed. async fn handle_bus_message( msg: BusMessage, - on_local_broadcast: &Arc, + on_local_broadcast: &LocalBroadcastFn, server_id: &str, ) { match msg { - BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => { + BusMessage::Broadcast { + namespace: _, + packet, + opts, + server_id: sender_id, + } => { if sender_id == server_id { return; } @@ -29,13 +36,18 @@ async fn handle_bus_message( } // NATS adapter manages room state locally; cross-server join/leave/disconnect // are informational only and don't require duplicate state writes. - BusMessage::SocketJoin { server_id: sender_id, .. } - | BusMessage::SocketLeave { server_id: sender_id, .. } - | BusMessage::SocketDisconnect { server_id: sender_id, .. } => { - if sender_id == server_id { - return; - } + BusMessage::SocketJoin { + server_id: sender_id, + .. } + | BusMessage::SocketLeave { + server_id: sender_id, + .. + } + | BusMessage::SocketDisconnect { + server_id: sender_id, + .. + } => if sender_id == server_id {}, } } @@ -51,7 +63,7 @@ pub struct NatsAdapter { sockets: DashMap>, server_id: String, namespace: String, - on_local_broadcast: Arc, + on_local_broadcast: LocalBroadcastFn, } impl NatsAdapter { @@ -59,7 +71,7 @@ impl NatsAdapter { message_bus: Arc, server_id: String, namespace: String, - on_local_broadcast: Arc, + on_local_broadcast: LocalBroadcastFn, ) -> Self { Self { message_bus, @@ -133,7 +145,11 @@ impl NatsAdapter { #[async_trait] impl Adapter for NatsAdapter { - async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { + async fn broadcast( + &self, + packet: &Packet, + opts: &BroadcastOptions, + ) -> Result<(), AdapterError> { if opts.flags.local_only { (self.on_local_broadcast)(packet, opts); return Ok(()); @@ -146,8 +162,8 @@ impl Adapter for NatsAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:broadcast", self.namespace), &payload) @@ -158,20 +174,30 @@ impl Adapter for NatsAdapter { Ok(()) } - async fn register(&self, socket_sid: &str, engine_sid: &str, _ns: &str) -> Result<(), AdapterError> { - self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string()); + async fn register( + &self, + socket_sid: &str, + engine_sid: &str, + _ns: &str, + ) -> Result<(), AdapterError> { + self.socket_sids + .insert(socket_sid.to_string(), engine_sid.to_string()); Ok(()) } async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> { self.socket_rooms .entry(sid.to_string()) - .and_modify(|set| { set.insert(room.to_string()); }) + .and_modify(|set| { + set.insert(room.to_string()); + }) .or_insert_with(|| HashSet::from([room.to_string()])); self.rooms .entry(room.to_string()) - .and_modify(|set| { set.insert(sid.to_string()); }) + .and_modify(|set| { + set.insert(sid.to_string()); + }) .or_insert_with(|| HashSet::from([sid.to_string()])); let msg = BusMessage::SocketJoin { @@ -181,8 +207,8 @@ impl Adapter for NatsAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:join", self.namespace), &payload) @@ -196,14 +222,24 @@ impl Adapter for NatsAdapter { if let Some(mut entry) = self.socket_rooms.get_mut(sid) { entry.value_mut().remove(room); } - if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) { + if self + .socket_rooms + .get(sid) + .map(|e| e.value().is_empty()) + .unwrap_or(true) + { self.socket_rooms.remove(sid); } if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } - if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { + if self + .rooms + .get(room) + .map(|e| e.value().is_empty()) + .unwrap_or(true) + { self.rooms.remove(room); } @@ -214,8 +250,8 @@ impl Adapter for NatsAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus .publish(&format!("socket.io:{}:leave", self.namespace), &payload) @@ -231,7 +267,12 @@ impl Adapter for NatsAdapter { if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } - if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { + if self + .rooms + .get(room) + .map(|e| e.value().is_empty()) + .unwrap_or(true) + { self.rooms.remove(room); } } @@ -246,18 +287,24 @@ impl Adapter for NatsAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus - .publish(&format!("socket.io:{}:disconnect", self.namespace), &payload) + .publish( + &format!("socket.io:{}:disconnect", self.namespace), + &payload, + ) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } - async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result, AdapterError> { + async fn fetch_sockets( + &self, + opts: &BroadcastOptions, + ) -> Result, AdapterError> { let mut result = Vec::new(); let target_sids: HashSet = if opts.rooms.is_empty() { @@ -276,7 +323,11 @@ impl Adapter for NatsAdapter { if opts.except.contains(&sid) { continue; } - let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default(); + let rooms = self + .socket_rooms + .get(&sid) + .map(|e| e.value().clone()) + .unwrap_or_default(); result.push(SocketInfo { sid: sid.clone(), namespace: self.namespace.clone(), @@ -288,7 +339,11 @@ impl Adapter for NatsAdapter { } async fn socket_rooms(&self, sid: &str) -> Result, AdapterError> { - Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default()) + Ok(self + .socket_rooms + .get(sid) + .map(|e| e.value().clone()) + .unwrap_or_default()) } fn server_id(&self) -> &str { @@ -296,7 +351,10 @@ impl Adapter for NatsAdapter { } async fn close(&self) -> Result<(), AdapterError> { - self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?; + self.message_bus + .close() + .await + .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } } diff --git a/socket/adapter/redis.rs b/socket/adapter/redis.rs index 98f4591..0015334 100644 --- a/socket/adapter/redis.rs +++ b/socket/adapter/redis.rs @@ -7,7 +7,9 @@ use fred::clients::Client; use fred::interfaces::{KeysInterface, SetsInterface}; use tokio::sync::mpsc; -use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo}; +use crate::socket::adapter::{ + Adapter, AdapterError, BroadcastOptions, BusMessage, LocalBroadcastFn, SocketInfo, +}; use crate::socket::message_bus::MessageBus; use crate::socket::packet::Packet; use crate::socket::parser; @@ -28,11 +30,16 @@ fn socket_rooms_key(ns: &str, sid: &str) -> String { /// Only performs local state updates — the remote server already wrote to Redis. async fn handle_bus_message( msg: BusMessage, - on_local_broadcast: &Arc, + on_local_broadcast: &LocalBroadcastFn, server_id: &str, ) { match msg { - BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => { + BusMessage::Broadcast { + namespace: _, + packet, + opts, + server_id: sender_id, + } => { if sender_id == server_id { return; } @@ -40,13 +47,20 @@ async fn handle_bus_message( on_local_broadcast(&decoded_packet, &opts); } } - BusMessage::SocketJoin { server_id: sender_id, .. } - | BusMessage::SocketLeave { server_id: sender_id, .. } - | BusMessage::SocketDisconnect { server_id: sender_id, .. } => { + BusMessage::SocketJoin { + server_id: sender_id, + .. + } + | BusMessage::SocketLeave { + server_id: sender_id, + .. + } + | BusMessage::SocketDisconnect { + server_id: sender_id, + .. + } => { // Skip messages from this server; remote server already updated Redis - if sender_id == server_id { - return; - } + if sender_id == server_id {} // No duplicate Redis writes — the sender already persisted the state change } } @@ -58,10 +72,12 @@ pub struct RedisAdapter { room_subscribers: DashMap>>, socket_rooms: DashMap>, rooms: DashMap>, + /// socket_sid → engine_sid mapping for local inspection. + socket_sids: DashMap, sockets: DashMap>, server_id: String, namespace: String, - on_local_broadcast: Arc, + on_local_broadcast: LocalBroadcastFn, } impl RedisAdapter { @@ -70,7 +86,7 @@ impl RedisAdapter { redis_client: Client, server_id: String, namespace: String, - on_local_broadcast: Arc, + on_local_broadcast: LocalBroadcastFn, ) -> Self { Self { message_bus, @@ -81,6 +97,7 @@ impl RedisAdapter { room_subscribers: DashMap::new(), socket_rooms: DashMap::new(), rooms: DashMap::new(), + socket_sids: DashMap::new(), sockets: DashMap::new(), } } @@ -144,7 +161,11 @@ impl RedisAdapter { #[async_trait] impl Adapter for RedisAdapter { - async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { + async fn broadcast( + &self, + packet: &Packet, + opts: &BroadcastOptions, + ) -> Result<(), AdapterError> { if opts.flags.local_only { (self.on_local_broadcast)(packet, opts); return Ok(()); @@ -157,11 +178,11 @@ impl Adapter for RedisAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus - .publish(&format!("socket.io:{}:broadcast", packet.namespace), &payload) + .publish(&format!("socket.io:{}:broadcast", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; @@ -185,12 +206,16 @@ impl Adapter for RedisAdapter { self.socket_rooms .entry(sid.to_string()) - .and_modify(|set| { set.insert(room.to_string()); }) + .and_modify(|set| { + set.insert(room.to_string()); + }) .or_insert_with(|| HashSet::from([room.to_string()])); self.rooms .entry(room.to_string()) - .and_modify(|set| { set.insert(sid.to_string()); }) + .and_modify(|set| { + set.insert(sid.to_string()); + }) .or_insert_with(|| HashSet::from([sid.to_string()])); let msg = BusMessage::SocketJoin { @@ -200,11 +225,11 @@ impl Adapter for RedisAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus - .publish(&format!("socket.io:{}:join", ns), &payload) + .publish(&format!("socket.io:{}:join", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; @@ -228,14 +253,24 @@ impl Adapter for RedisAdapter { if let Some(mut entry) = self.socket_rooms.get_mut(sid) { entry.value_mut().remove(room); } - if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) { + if self + .socket_rooms + .get(sid) + .map(|e| e.value().is_empty()) + .unwrap_or(true) + { self.socket_rooms.remove(sid); } if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } - if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { + if self + .rooms + .get(room) + .map(|e| e.value().is_empty()) + .unwrap_or(true) + { self.rooms.remove(room); } @@ -246,11 +281,11 @@ impl Adapter for RedisAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus - .publish(&format!("socket.io:{}:leave", ns), &payload) + .publish(&format!("socket.io:{}:leave", self.namespace), &payload) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; @@ -263,7 +298,12 @@ impl Adapter for RedisAdapter { if let Some(mut entry) = self.rooms.get_mut(room) { entry.value_mut().remove(sid); } - if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { + if self + .rooms + .get(room) + .map(|e| e.value().is_empty()) + .unwrap_or(true) + { self.rooms.remove(room); } @@ -280,6 +320,7 @@ impl Adapter for RedisAdapter { .await .map_err(|e| AdapterError::Redis(e.to_string()))?; + self.socket_sids.remove(sid); self.sockets.remove(sid); let msg = BusMessage::SocketDisconnect { @@ -288,22 +329,43 @@ impl Adapter for RedisAdapter { server_id: self.server_id.clone(), }; - let payload = serde_json::to_vec(&msg) - .map_err(|e| AdapterError::Serialization(e.to_string()))?; + let payload = + serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?; self.message_bus - .publish(&format!("socket.io:{}:disconnect", ns), &payload) + .publish( + &format!("socket.io:{}:disconnect", self.namespace), + &payload, + ) .await .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } - async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result, AdapterError> { + async fn register( + &self, + socket_sid: &str, + engine_sid: &str, + _ns: &str, + ) -> Result<(), AdapterError> { + self.socket_sids + .insert(socket_sid.to_string(), engine_sid.to_string()); + Ok(()) + } + + async fn unregister(&self, socket_sid: &str, ns: &str) -> Result<(), AdapterError> { + self.del_all(socket_sid, ns).await + } + + async fn fetch_sockets( + &self, + opts: &BroadcastOptions, + ) -> Result, AdapterError> { let mut result = Vec::new(); let target_sids: HashSet = if opts.rooms.is_empty() { - self.sockets.iter().map(|e| e.key().clone()).collect() + self.socket_sids.iter().map(|e| e.key().clone()).collect() } else { let mut sids = HashSet::new(); for room in &opts.rooms { @@ -318,7 +380,11 @@ impl Adapter for RedisAdapter { if opts.except.contains(&sid) { continue; } - let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default(); + let rooms = self + .socket_rooms + .get(&sid) + .map(|e| e.value().clone()) + .unwrap_or_default(); result.push(SocketInfo { sid: sid.clone(), namespace: self.namespace.clone(), @@ -330,7 +396,11 @@ impl Adapter for RedisAdapter { } async fn socket_rooms(&self, sid: &str) -> Result, AdapterError> { - Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default()) + Ok(self + .socket_rooms + .get(sid) + .map(|e| e.value().clone()) + .unwrap_or_default()) } fn server_id(&self) -> &str { @@ -338,7 +408,10 @@ impl Adapter for RedisAdapter { } async fn close(&self) -> Result<(), AdapterError> { - self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?; + self.message_bus + .close() + .await + .map_err(|e| AdapterError::MessageBus(e.to_string()))?; Ok(()) } } diff --git a/socket/message_bus/mod.rs b/socket/message_bus/mod.rs index b9c79f2..cff9235 100644 --- a/socket/message_bus/mod.rs +++ b/socket/message_bus/mod.rs @@ -1,5 +1,5 @@ -pub mod redis; pub mod nats; +pub mod redis; use async_trait::async_trait; use thiserror::Error; @@ -27,5 +27,5 @@ pub trait MessageBus: Send + Sync + 'static { async fn close(&self) -> Result<(), MessageBusError>; } +pub use nats::NatsMessageBus; pub use redis::RedisMessageBus; -pub use nats::NatsMessageBus; \ No newline at end of file diff --git a/socket/message_bus/nats.rs b/socket/message_bus/nats.rs index 08a0ee9..92f77d8 100644 --- a/socket/message_bus/nats.rs +++ b/socket/message_bus/nats.rs @@ -34,7 +34,8 @@ impl MessageBus for NatsMessageBus { async fn subscribe(&self, channel: &str) -> Result>, MessageBusError> { let (tx, rx) = mpsc::channel::>(256); - let mut subscriber = self.client + let mut subscriber = self + .client .subscribe(channel.to_string()) .await .map_err(|e| MessageBusError::Nats(e.to_string()))?; @@ -85,4 +86,4 @@ impl MessageBus for NatsMessageBus { self.shutdowns.clear(); Ok(()) } -} \ No newline at end of file +} diff --git a/socket/message_bus/redis.rs b/socket/message_bus/redis.rs index c7e8c5e..9fb14e8 100644 --- a/socket/message_bus/redis.rs +++ b/socket/message_bus/redis.rs @@ -13,8 +13,8 @@ pub struct RedisMessageBus { impl RedisMessageBus { pub async fn new(redis_url: &str) -> Result { - let config = Config::from_url(redis_url) - .map_err(|e| MessageBusError::Redis(e.to_string()))?; + let config = + Config::from_url(redis_url).map_err(|e| MessageBusError::Redis(e.to_string()))?; let client = Client::new(config.clone(), None, None, None); let subscriber = SubscriberClient::new(config, None, None, None); @@ -64,9 +64,8 @@ impl MessageBus for RedisMessageBus { tokio::spawn(async move { while let Ok(message) = message_rx.recv().await { - if &message.channel == &channel_owned { - let data: Vec = FromValue::from_value(message.value) - .unwrap_or_default(); + if message.channel == channel_owned { + let data: Vec = FromValue::from_value(message.value).unwrap_or_default(); if tx.send(data).await.is_err() { break; } @@ -96,4 +95,4 @@ impl MessageBus for RedisMessageBus { .map_err(|e| MessageBusError::Redis(e.to_string()))?; Ok(()) } -} \ No newline at end of file +} diff --git a/socket/mod.rs b/socket/mod.rs index 8a395dc..ab2af03 100644 --- a/socket/mod.rs +++ b/socket/mod.rs @@ -5,12 +5,18 @@ pub mod packet; pub mod parser; pub mod server; pub mod session_store; +#[allow(clippy::module_inception)] pub mod socket; -pub use adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, RedisAdapter, NatsAdapter, SocketInfo}; -pub use message_bus::{MessageBus, MessageBusError, RedisMessageBus, NatsMessageBus}; -pub use namespace::{is_valid_namespace, Namespace, NamespaceManager}; +pub use adapter::{ + Adapter, AdapterError, BroadcastFlags, BroadcastOptions, BusMessage, LocalAdapter, NatsAdapter, + RedisAdapter, SocketInfo, +}; +pub use message_bus::{MessageBus, MessageBusError, NatsMessageBus, RedisMessageBus}; +pub use namespace::{Namespace, NamespaceManager, is_valid_namespace}; pub use packet::{Packet, PacketType}; pub use server::{SocketServer, SocketServerBuilder}; -pub use session_store::{InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait}; -pub use socket::Socket; \ No newline at end of file +pub use session_store::{ + InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait, +}; +pub use socket::Socket; diff --git a/socket/namespace.rs b/socket/namespace.rs index 5dc437d..71e2435 100644 --- a/socket/namespace.rs +++ b/socket/namespace.rs @@ -4,12 +4,13 @@ use std::sync::Arc; use dashmap::DashMap; use tokio::sync::RwLock; -use crate::socket::adapter::{Adapter, BroadcastOptions, BroadcastFlags}; +use crate::socket::adapter::{Adapter, BroadcastFlags, BroadcastOptions}; use crate::socket::packet::Packet; use crate::socket::socket::Socket; -pub type EventHandler = Arc; -type ConnectHandler = Arc) -> Result<(), String> + Send + Sync>; +pub type EventHandler = Arc, &serde_json::Value) + Send + Sync>; +type ConnectHandler = + Arc) -> Result<(), String> + Send + Sync>; pub struct Namespace { pub path: String, @@ -19,6 +20,8 @@ pub struct Namespace { engine_to_socket: DashMap, handlers: RwLock>>, connect_handler: RwLock>, + rooms: DashMap>, + socket_rooms: DashMap>, pub(crate) adapter: RwLock>>, } @@ -30,6 +33,8 @@ impl Namespace { engine_to_socket: DashMap::new(), handlers: RwLock::new(HashMap::new()), connect_handler: RwLock::new(None), + rooms: DashMap::new(), + socket_rooms: DashMap::new(), adapter: RwLock::new(None), } } @@ -40,11 +45,15 @@ impl Namespace { } /// Add a socket to this namespace. Returns Err if the connect handler rejects. - pub async fn add_socket(&self, socket: Arc) -> Result<(), String> { + pub async fn add_socket( + &self, + socket: Arc, + auth_data: Option<&serde_json::Value>, + ) -> Result<(), String> { // Run connect handler before adding to storage let handler = self.connect_handler.read().await; if let Some(ref h) = *handler { - h(&socket, None)?; + h(&socket, auth_data)?; } drop(handler); @@ -53,10 +62,10 @@ impl Namespace { // Register with adapter (socket_sid → engine_sid mapping) let adapter = self.adapter.read().await; - if let Some(ref adapter) = *adapter { - if let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await { - tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e); - } + if let Some(ref adapter) = *adapter + && let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await + { + tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e); } // Store socket by socket_sid, plus reverse index @@ -69,12 +78,13 @@ impl Namespace { pub async fn remove_socket_by_sid(&self, socket_sid: &str) { if let Some((_, socket)) = self.sockets.remove(socket_sid) { self.engine_to_socket.remove(&socket.engine_sid); + self.remove_socket_from_local_rooms(socket_sid); let adapter = self.adapter.read().await; - if let Some(ref adapter) = *adapter { - if let Err(e) = adapter.del_all(socket_sid, &self.path).await { - tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e); - } + if let Some(ref adapter) = *adapter + && let Err(e) = adapter.del_all(socket_sid, &self.path).await + { + tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e); } } } @@ -130,7 +140,12 @@ impl Namespace { } } - pub async fn emit_to_room(&self, room: &str, event: impl Into, data: serde_json::Value) { + pub async fn emit_to_room( + &self, + room: &str, + event: impl Into, + data: serde_json::Value, + ) { let event_name = event.into(); let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None); @@ -145,20 +160,64 @@ impl Namespace { tracing::warn!("Adapter broadcast to room error: {}", e); } } else { - self.emit_local(&packet); + self.emit_local_to_room(&packet, room, &HashSet::new()); } } pub fn emit_local(&self, packet: &Packet) { for entry in self.sockets.iter() { - let socket = entry.value(); - if socket.send_packet(packet).is_err() { - tracing::warn!("Failed to send event to socket {}", socket.sid); + self.send_local_packet(entry.value(), packet); + } + } + + pub fn emit_local_filtered(&self, packet: &Packet, opts: &BroadcastOptions) { + if opts.rooms.is_empty() { + for entry in self.sockets.iter() { + if !opts.except.contains(entry.key()) { + self.send_local_packet(entry.value(), packet); + } + } + return; + } + + let mut target_sids = HashSet::new(); + for room in &opts.rooms { + if let Some(room_sids) = self.rooms.get(room) { + target_sids.extend(room_sids.value().iter().cloned()); + } + } + + for sid in target_sids { + if opts.except.contains(&sid) { + continue; + } + if let Some(socket) = self.get_socket(&sid) { + self.send_local_packet(&socket, packet); } } } - pub async fn emit_to(&self, socket_sid: &str, event: impl Into, data: serde_json::Value) { + fn emit_local_to_room(&self, packet: &Packet, room: &str, except: &HashSet) { + let opts = BroadcastOptions { + rooms: HashSet::from([room.to_string()]), + except: except.clone(), + flags: BroadcastFlags::default(), + }; + self.emit_local_filtered(packet, &opts); + } + + fn send_local_packet(&self, socket: &Socket, packet: &Packet) { + if socket.send_packet(packet).is_err() { + tracing::warn!("Failed to send event to socket {}", socket.sid); + } + } + + pub async fn emit_to( + &self, + socket_sid: &str, + event: impl Into, + data: serde_json::Value, + ) { if let Some(socket) = self.get_socket(socket_sid) { let event_name = event.into(); let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None); @@ -168,11 +227,79 @@ impl Namespace { } } - pub async fn handle_event(&self, socket: &Socket, event: &str, data: &serde_json::Value) { + pub async fn handle_event(&self, socket: Arc, event: &str, data: &serde_json::Value) { let handlers = self.handlers.read().await; if let Some(event_handlers) = handlers.get(event) { for handler in event_handlers { - handler(socket, data); + handler(Arc::clone(&socket), data); + } + } + } + + pub async fn join_room(&self, socket_sid: &str, room: &str) -> crate::ImksResult<()> { + if !self.sockets.contains_key(socket_sid) { + return Err(crate::ImksError::SocketNotFound(socket_sid.to_string())); + } + + self.rooms + .entry(room.to_string()) + .or_default() + .value_mut() + .insert(socket_sid.to_string()); + self.socket_rooms + .entry(socket_sid.to_string()) + .or_default() + .value_mut() + .insert(room.to_string()); + + let adapter = self.adapter.read().await; + if let Some(ref adapter) = *adapter + && let Err(e) = adapter.add(socket_sid, room, &self.path).await + { + self.remove_local_room(socket_sid, room); + return Err(e.into()); + } + Ok(()) + } + + pub async fn leave_room(&self, socket_sid: &str, room: &str) -> crate::ImksResult<()> { + let adapter = self.adapter.read().await; + if let Some(ref adapter) = *adapter { + adapter.del(socket_sid, room, &self.path).await?; + } + + self.remove_local_room(socket_sid, room); + Ok(()) + } + + fn remove_local_room(&self, socket_sid: &str, room: &str) { + if let Some(mut sids) = self.rooms.get_mut(room) { + sids.value_mut().remove(socket_sid); + if sids.value().is_empty() { + drop(sids); + self.rooms.remove(room); + } + } + + if let Some(mut rooms) = self.socket_rooms.get_mut(socket_sid) { + rooms.value_mut().remove(room); + if rooms.value().is_empty() { + drop(rooms); + self.socket_rooms.remove(socket_sid); + } + } + } + + fn remove_socket_from_local_rooms(&self, socket_sid: &str) { + if let Some((_, rooms)) = self.socket_rooms.remove(socket_sid) { + for room in rooms { + if let Some(mut sids) = self.rooms.get_mut(&room) { + sids.value_mut().remove(socket_sid); + if sids.value().is_empty() { + drop(sids); + self.rooms.remove(&room); + } + } } } } diff --git a/socket/parser.rs b/socket/parser.rs index 1ea5e4d..212b0f1 100644 --- a/socket/parser.rs +++ b/socket/parser.rs @@ -24,19 +24,18 @@ pub fn encode(packet: &Packet) -> String { if let Some(ref data) = packet.data { if packet.has_binary() { - let data_with_placeholders = replace_binary_with_placeholders(data, packet.attachment_count()); - let encoded_data = serde_json::to_string(&data_with_placeholders) - .unwrap_or_else(|e| { - tracing::error!("Failed to serialize socket packet data: {}", e); - "null".to_string() - }); + let data_with_placeholders = + replace_binary_with_placeholders(data, packet.attachment_count()); + let encoded_data = serde_json::to_string(&data_with_placeholders).unwrap_or_else(|e| { + tracing::error!("Failed to serialize socket packet data: {}", e); + "null".to_string() + }); result.push_str(&encoded_data); } else { - let encoded_data = serde_json::to_string(data) - .unwrap_or_else(|e| { - tracing::error!("Failed to serialize socket packet data: {}", e); - "null".to_string() - }); + let encoded_data = serde_json::to_string(data).unwrap_or_else(|e| { + tracing::error!("Failed to serialize socket packet data: {}", e); + "null".to_string() + }); result.push_str(&encoded_data); } } @@ -67,7 +66,8 @@ pub fn decode(input: &str) -> Result { let type_char = chars.next().ok_or(PacketError::Empty)?; let packet_type = PacketType::try_from(type_char)?; - let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck) { + let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck) + { let mut count_str = String::new(); while let Some(&c) = chars.peek() { if c == '-' { @@ -126,7 +126,11 @@ pub fn decode(input: &str) -> Result { id, attachments: Vec::new(), // Store attachment_count for binary packets; actual attachments come via decode_with_attachments - expected_attachments: if attachment_count > 0 { Some(attachment_count) } else { None }, + expected_attachments: if attachment_count > 0 { + Some(attachment_count) + } else { + None + }, }) } @@ -144,10 +148,10 @@ pub fn decode_with_attachments( packet.attachments = attachments; packet.expected_attachments = None; - if packet.has_binary() { - if let Some(ref data) = packet.data { - packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments)); - } + if packet.has_binary() + && let Some(ref data) = packet.data + { + packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments)); } Ok(packet) @@ -204,12 +208,12 @@ fn replace_binary_with_placeholders(value: &Value, total_attachments: usize) -> } } -fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut usize) -> Value { +fn replace_binary_with_placeholders_inner(value: &Value, _placeholder_idx: &mut usize) -> Value { match value { Value::Array(arr) => { let new_arr: Vec = arr .iter() - .map(|v| replace_binary_with_placeholders_inner(v, placeholder_idx)) + .map(|v| replace_binary_with_placeholders_inner(v, _placeholder_idx)) .collect(); Value::Array(new_arr) } @@ -218,7 +222,7 @@ fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut u for (k, v) in map { new_map.insert( k.clone(), - replace_binary_with_placeholders_inner(v, placeholder_idx), + replace_binary_with_placeholders_inner(v, _placeholder_idx), ); } Value::Object(new_map) @@ -236,15 +240,13 @@ fn replace_placeholders_with_binary(value: &Value, attachments: &[Vec]) -> V // Check if this is a placeholder object: { "_placeholder": true, "num": N } if let (Some(Value::Bool(true)), Some(Value::Number(num))) = (map.get("_placeholder"), map.get("num")) + && let Some(idx) = num.as_u64() + && let Some(attachment) = attachments.get(idx as usize) { - if let Some(idx) = num.as_u64() { - if let Some(attachment) = attachments.get(idx as usize) { - return Value::String(base64::Engine::encode( - &base64::engine::general_purpose::STANDARD, - attachment, - )); - } - } + return Value::String(base64::Engine::encode( + &base64::engine::general_purpose::STANDARD, + attachment, + )); } let mut new_map = serde_json::Map::new(); @@ -389,4 +391,4 @@ mod tests { assert_eq!(packet.expected_attachments, Some(1)); assert_eq!(packet.namespace, "/"); } -} \ No newline at end of file +} diff --git a/socket/server.rs b/socket/server.rs index 84de50a..fcb53f7 100644 --- a/socket/server.rs +++ b/socket/server.rs @@ -103,8 +103,14 @@ impl SocketServerBuilder { let adapter = adapter_clone.clone(); tokio::spawn(async move { handle_engine_message( - sid, engine_packet, &namespaces, &socket_txs, &engine_store, &adapter, - ).await; + sid, + engine_packet, + &namespaces, + &socket_txs, + &engine_store, + &adapter, + ) + .await; }); }, )); @@ -136,10 +142,18 @@ async fn handle_engine_message( adapter: &Arc, ) { if let EnginePacketData::Text(ref text) = engine_packet.data { - if let Ok(socket_packet) = parser::decode(text) { - match socket_packet.packet_type { + match parser::decode(text) { + Ok(socket_packet) => match socket_packet.packet_type { PacketType::Connect => { - handle_connect(&engine_sid, &socket_packet, namespaces, socket_txs, engine_store, adapter).await; + handle_connect( + &engine_sid, + &socket_packet, + namespaces, + socket_txs, + engine_store, + adapter, + ) + .await; } PacketType::Disconnect => { handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs); @@ -151,6 +165,9 @@ async fn handle_engine_message( handle_ack(&engine_sid, &socket_packet); } _ => {} + }, + Err(e) => { + tracing::warn!(engine_sid = %engine_sid, error = %e, "Invalid Socket.IO packet"); } } } @@ -166,22 +183,21 @@ async fn handle_connect( ) { // Validate namespace path to prevent DoS via arbitrary namespace creation if !crate::socket::namespace::is_valid_namespace(&packet.namespace) { - tracing::warn!("Rejected connect with invalid namespace: {}", packet.namespace); + tracing::warn!( + "Rejected connect with invalid namespace: {}", + packet.namespace + ); return; } let namespace = namespaces.get_or_create_namespace(&packet.namespace); - // Ensure newly created namespaces get the shared adapter + // Ensure newly created namespaces get the shared adapter before registration. { let ns_adapter = namespace.adapter.read().await; if ns_adapter.is_none() { drop(ns_adapter); - let adapter_ref = adapter.clone(); - let ns_clone = namespace.clone(); - tokio::spawn(async move { - ns_clone.set_adapter(adapter_ref).await; - }); + namespace.set_adapter(adapter.clone()).await; } } @@ -198,7 +214,10 @@ async fn handle_connect( // Run connect handler and add to namespace. // If the handler rejects, clean up and do NOT send a Connect response. - if let Err(msg) = namespace.add_socket(socket.clone()).await { + if let Err(msg) = namespace + .add_socket(socket.clone(), packet.data.as_ref()) + .await + { tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg); socket_txs.remove(&socket_sid); return; @@ -227,7 +246,9 @@ async fn handle_connect( } // Forwarding task ended — ensure socket is cleaned up from namespace socket_txs_clone.remove(&socket_sid_clone); - namespace_clone.remove_socket_by_sid(&socket_sid_clone).await; + namespace_clone + .remove_socket_by_sid(&socket_sid_clone) + .await; }); // Send Connect response (only after handler passed) @@ -260,34 +281,27 @@ fn handle_disconnect( } } -fn handle_event( - engine_sid: &str, - packet: &Packet, - namespaces: &Arc, -) { - if let Some(namespace) = namespaces.get_namespace(&packet.namespace) { - if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) { - if let Some(ref data) = packet.data { - if let Some(arr) = data.as_array() { - if let Some(event) = arr.first().and_then(|v| v.as_str()) { - let event_data = if arr.len() > 1 { - serde_json::Value::Array(arr[1..].to_vec()) - } else { - serde_json::Value::Null - }; +fn handle_event(engine_sid: &str, packet: &Packet, namespaces: &Arc) { + if let Some(namespace) = namespaces.get_namespace(&packet.namespace) + && let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) + && let Some(ref data) = packet.data + && let Some(arr) = data.as_array() + && let Some(event) = arr.first().and_then(|v| v.as_str()) + { + let event_data = if arr.len() > 1 { + serde_json::Value::Array(arr[1..].to_vec()) + } else { + serde_json::Value::Null + }; - let namespace_clone = namespace.clone(); - let event = event.to_string(); - let socket_clone = socket.clone(); - tokio::spawn(async move { - namespace_clone - .handle_event(&socket_clone, &event, &event_data) - .await; - }); - } - } - } - } + let namespace_clone = namespace.clone(); + let event = event.to_string(); + let socket_clone = socket.clone(); + tokio::spawn(async move { + namespace_clone + .handle_event(socket_clone, &event, &event_data) + .await; + }); } } diff --git a/socket/session_store/memory.rs b/socket/session_store/memory.rs index 11a2a81..8e71ef7 100644 --- a/socket/session_store/memory.rs +++ b/socket/session_store/memory.rs @@ -33,7 +33,12 @@ fn now_millis() -> u64 { #[async_trait] impl SessionStoreTrait for InMemorySessionStore { - async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> { + async fn create( + &self, + sid: &str, + transport: &str, + server_id: &str, + ) -> Result<(), SessionError> { let info = SessionInfo { sid: sid.to_string(), transport: transport.to_string(), @@ -85,4 +90,4 @@ impl SessionStoreTrait for InMemorySessionStore { async fn exists(&self, sid: &str) -> Result { Ok(self.sessions.contains_key(sid)) } -} \ No newline at end of file +} diff --git a/socket/session_store/mod.rs b/socket/session_store/mod.rs index 3c4934d..68cb3cc 100644 --- a/socket/session_store/mod.rs +++ b/socket/session_store/mod.rs @@ -28,7 +28,8 @@ pub struct SessionInfo { #[async_trait] pub trait SessionStoreTrait: Send + Sync + 'static { - async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError>; + async fn create(&self, sid: &str, transport: &str, server_id: &str) + -> Result<(), SessionError>; async fn get(&self, sid: &str) -> Result, SessionError>; async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError>; async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError>; @@ -38,4 +39,4 @@ pub trait SessionStoreTrait: Send + Sync + 'static { } pub use memory::InMemorySessionStore; -pub use redis::RedisSessionStore; \ No newline at end of file +pub use redis::RedisSessionStore; diff --git a/socket/session_store/redis.rs b/socket/session_store/redis.rs index 352d7f9..84f5c9b 100644 --- a/socket/session_store/redis.rs +++ b/socket/session_store/redis.rs @@ -36,7 +36,12 @@ impl RedisSessionStore { #[async_trait] impl SessionStoreTrait for RedisSessionStore { - async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> { + async fn create( + &self, + sid: &str, + transport: &str, + server_id: &str, + ) -> Result<(), SessionError> { let key = self.key(sid); let now = now_millis(); @@ -67,7 +72,8 @@ impl SessionStoreTrait for RedisSessionStore { // Use hgetall directly — if the key doesn't exist Redis returns an empty map. // This avoids the TOCTOU race between EXISTS and HGETALL. - let values: std::collections::HashMap = self.client + let values: std::collections::HashMap = self + .client .hgetall::, _>(&key) .await .map_err(|e| SessionError::Redis(e.to_string()))?; @@ -81,8 +87,14 @@ impl SessionStoreTrait for RedisSessionStore { transport: values.get("transport").cloned().unwrap_or_default(), state: values.get("state").cloned().unwrap_or_default(), server_id: values.get("server_id").cloned().unwrap_or_default(), - created_at: values.get("created_at").and_then(|v| v.parse::().ok()).unwrap_or(0), - last_ping: values.get("last_ping").and_then(|v| v.parse::().ok()).unwrap_or(0), + created_at: values + .get("created_at") + .and_then(|v| v.parse::().ok()) + .unwrap_or(0), + last_ping: values + .get("last_ping") + .and_then(|v| v.parse::().ok()) + .unwrap_or(0), }; Ok(Some(info)) @@ -154,11 +166,12 @@ impl SessionStoreTrait for RedisSessionStore { async fn exists(&self, sid: &str) -> Result { let key = self.key(sid); - let exists: bool = self.client + let exists: bool = self + .client .exists::(&key) .await .map_err(|e| SessionError::Redis(e.to_string()))?; Ok(exists) } -} \ No newline at end of file +} diff --git a/socket/socket.rs b/socket/socket.rs index 26bfa3b..8ed47cd 100644 --- a/socket/socket.rs +++ b/socket/socket.rs @@ -1,6 +1,8 @@ +use std::sync::OnceLock; use std::sync::atomic::{AtomicU64, Ordering}; use tokio::sync::mpsc; +use uuid::Uuid; use crate::socket::packet::Packet; @@ -8,10 +10,13 @@ pub struct Socket { pub sid: String, pub namespace: String, pub engine_sid: String, + /// Authenticated user ID, set once during `on_connect`. + user_id: OnceLock, ack_id: AtomicU64, tx: mpsc::Sender, } +#[allow(clippy::result_large_err)] impl Socket { pub fn new( sid: String, @@ -24,10 +29,22 @@ impl Socket { namespace, engine_sid, ack_id: AtomicU64::new(0), + user_id: OnceLock::new(), tx, } } + /// Set the authenticated user ID after JWT verification. + /// Safe to call once; subsequent calls are ignored. + pub fn set_user_id(&self, id: Uuid) { + let _ = self.user_id.set(id); + } + + /// Get the authenticated user ID, if set. + pub fn user_id(&self) -> Option { + self.user_id.get().copied() + } + pub fn next_ack_id(&self) -> u64 { self.ack_id.fetch_add(1, Ordering::SeqCst) } @@ -36,7 +53,11 @@ impl Socket { self.tx.try_send(packet.clone()) } - pub fn emit(&self, event: impl Into, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError> { + pub fn emit( + &self, + event: impl Into, + data: serde_json::Value, + ) -> Result<(), mpsc::error::TrySendError> { let packet = Packet::event( &self.namespace, serde_json::json!([event.into(), data]), @@ -65,7 +86,11 @@ impl Socket { self.send_packet(&packet) } - pub fn send_ack(&self, id: u64, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError> { + pub fn send_ack( + &self, + id: u64, + data: serde_json::Value, + ) -> Result<(), mpsc::error::TrySendError> { let packet = Packet::ack(&self.namespace, data, id); self.send_packet(&packet) } diff --git a/svc/article.rs b/svc/article.rs new file mode 100644 index 0000000..61efe6a --- /dev/null +++ b/svc/article.rs @@ -0,0 +1,221 @@ +//! Forum article event handlers on `MessageService`. +//! +//! Articles are long-form posts in forum channels. Creating an article +//! creates both a `message` (with `message_type = "article"`) and a +//! `message_article` row linked to it. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::models::message::MessageType; +use crate::models::message_article::ArticleSort; +use crate::repo::CreateMessageInput; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `article:create` — create a forum article (message + article metadata). + pub async fn create_article( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + let title: String = Self::parse_field(arr, "title")?; + let body: String = Self::parse_field(arr, "body")?; + let summary: Option = Self::parse_optional(arr, "summary")?; + let cover_url: Option = Self::parse_optional(arr, "cover_url")?; + let tags: Option = Self::parse_optional(arr, "tags")?; + + self.validate_channel_write(&channel_id.to_string(), &user_id.to_string()) + .await?; + + // Create the message first (with article type) + let input = CreateMessageInput { + channel_id, + author_id: user_id, + thread_id: None, + reply_to_message_id: None, + message_type: MessageType::Article.as_str().into(), + body, + metadata: None, + system: false, + }; + + let message = self.repo.create(&input).await?; + + // Create the article record + let article = self + .repo + .create_article( + message.id, + &title, + summary.as_deref(), + cover_url.as_deref(), + None, + None, + None, + tags.as_ref(), + ) + .await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "article:created", + serde_json::json!({ + "message": message, + "article": article, + }), + ) + .await; + } + + tracing::info!(article_id = %article.id, %channel_id, %user_id, "Article created"); + Ok(()) + } + + /// Handle `article:update` — update article title, summary, cover, tags. + pub async fn update_article( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let message_id: Uuid = Self::parse_field(arr, "message_id")?; + let title: Option = Self::parse_optional(arr, "title")?; + let summary: Option = Self::parse_optional(arr, "summary")?; + let cover_url: Option = Self::parse_optional(arr, "cover_url")?; + let cover_color: Option = Self::parse_optional(arr, "cover_color")?; + let tags: Option = Self::parse_optional(arr, "tags")?; + + let existing = self + .repo + .get(message_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?; + self.ensure_author_or_mod( + existing.author_id, + &existing.channel_id.to_string(), + user_id, + ) + .await?; + + // Update article body if provided + if let Ok(new_body) = Self::parse_field::(arr, "body") + && !new_body.is_empty() + { + let old_body = existing.body.clone(); + self.repo.update_body(message_id, &new_body).await?; + self.repo + .record_edit(message_id, user_id, &old_body, &new_body) + .await?; + } + + if let Some(updated) = self + .repo + .update_article( + message_id, + title.as_deref(), + summary.as_deref(), + cover_url.as_deref(), + cover_color.as_deref(), + tags.as_ref(), + ) + .await? + && let Some(ns) = self.namespaces.get_namespace(&socket.namespace) + { + ns.emit_to_room( + &existing.channel_id.to_string(), + "article:updated", + serde_json::to_value(&updated).unwrap_or_default(), + ) + .await; + } + + Ok(()) + } + + /// Handle `article:list` — list articles in a forum channel. + pub async fn list_articles( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + self.ensure_readable(&channel_id.to_string(), &user_id.to_string()) + .await?; + self.ensure_member(&channel_id.to_string(), &user_id.to_string()) + .await?; + let before: Option<(i64, Uuid)> = None; + let limit: Option = Self::parse_optional(arr, "limit")?; + + let page = self + .repo + .list_articles(channel_id, ArticleSort::LatestActivity, before, limit) + .await?; + let _ = socket.emit( + "article:loaded", + serde_json::to_value(&page).unwrap_or_default(), + ); + Ok(()) + } + + /// Handle `article:delete` — soft-delete an article. + pub async fn delete_article( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let message_id: Uuid = Self::parse_field(arr, "message_id")?; + let _channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + + let existing = self + .repo + .get(message_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?; + self.ensure_author_or_mod( + existing.author_id, + &existing.channel_id.to_string(), + user_id, + ) + .await?; + + self.repo.soft_delete(message_id).await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room(&existing.channel_id.to_string(), "article:deleted", + serde_json::json!({"id": message_id.to_string(), "channel_id": existing.channel_id.to_string()}), + ).await; + } + + Ok(()) + } +} diff --git a/svc/bookmark.rs b/svc/bookmark.rs new file mode 100644 index 0000000..e31ab7a --- /dev/null +++ b/svc/bookmark.rs @@ -0,0 +1,75 @@ +//! Bookmark event handlers on `MessageService`. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `bookmark:add` — toggle (add/update) a bookmark. + pub async fn add_bookmark( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let message_id: Uuid = Self::parse_field(arr, "message_id")?; + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + let note: Option = Self::parse_optional(arr, "note")?; + + self.repo + .add_bookmark(message_id, channel_id, user_id, note.as_deref()) + .await?; + Ok(()) + } + + /// Handle `bookmark:remove` — remove a bookmark. + pub async fn remove_bookmark( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let message_id: Uuid = Self::parse_field(arr, "message_id")?; + + self.repo.remove_bookmark(message_id, user_id).await?; + Ok(()) + } + + /// Handle `bookmark:list` — list a user's bookmarks. + pub async fn list_bookmarks( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let before: Option = Self::parse_optional(arr, "before")?; + let limit: Option = Self::parse_optional(arr, "limit")?; + + let page = self.repo.list_bookmarks(user_id, before, limit).await?; + let _ = socket.emit( + "bookmark:loaded", + serde_json::to_value(&page).unwrap_or_default(), + ); + Ok(()) + } +} diff --git a/svc/component.rs b/svc/component.rs new file mode 100644 index 0000000..a582ae6 --- /dev/null +++ b/svc/component.rs @@ -0,0 +1,98 @@ +//! Interactive component event handlers on `MessageService`. +//! +//! Handles button clicks and select menu interactions on message components. +//! When a user clicks a button, the server updates the component state and +//! broadcasts the interaction to the channel. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `component:interact` — a user clicked a button or selected from a menu. + pub async fn interact_component( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let component_id: Uuid = Self::parse_field(arr, "component_id")?; + let custom_id: String = Self::parse_field(arr, "custom_id")?; + let message_id: Uuid = Self::parse_field(arr, "message_id")?; + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + + // Get current components to verify the interaction is valid + let components = self.repo.get_components(message_id).await?; + let component = components.iter().find(|c| c.id == component_id); + + if component.is_none() { + return Err(ImksError::NotFound(format!("component {component_id}"))); + } + + // Broadcast the interaction event to all clients in the channel. + // The actual action (e.g., approve/deny) is handled by the bot/webhook + // that listens for this event. The server just relays and disables the + // component to prevent double-clicks. + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "component:interaction", + serde_json::json!({ + "component_id": component_id.to_string(), + "custom_id": custom_id, + "message_id": message_id.to_string(), + "user_id": user_id.to_string(), + "channel_id": channel_id.to_string(), + }), + ) + .await; + } + + tracing::info!(%component_id, %user_id, %custom_id, "Component interaction"); + + Ok(()) + } + + /// Handle `component:update` — update a component's state (e.g., disable after interaction). + #[allow(dead_code)] + pub async fn update_component( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let component_id: Uuid = Self::parse_field(arr, "component_id")?; + let label: Option = Self::parse_optional(arr, "label")?; + let disabled: bool = Self::parse_optional(arr, "disabled")?.unwrap_or(true); + + if let Some(updated) = self + .repo + .update_component(component_id, label.as_deref(), disabled) + .await? + && let Some(ns) = self.namespaces.get_namespace(&socket.namespace) + { + ns.emit_to_room( + &updated.message_id.to_string(), + "component:updated", + serde_json::to_value(&updated).unwrap_or_default(), + ) + .await; + } + + Ok(()) + } +} diff --git a/svc/deploy.rs b/svc/deploy.rs new file mode 100644 index 0000000..47e5c71 --- /dev/null +++ b/svc/deploy.rs @@ -0,0 +1,62 @@ +//! Server deployment configuration. +//! +//! Reads from environment variables to select adapter (local/redis/nats) +//! and WebTransport settings. + +use std::env; + +/// Adapter + message bus configuration for multi-node scale-out. +#[derive(Debug, Clone)] +pub struct DeployConfig { + /// "local" | "redis" | "nats" + pub adapter_mode: String, + /// Redis connection URL (used when adapter_mode = "redis"). + pub redis_url: String, + /// NATS connection URL (used when adapter_mode = "nats"). + pub nats_url: String, + /// Unique server ID for this node. + pub server_id: String, + /// Enable WebTransport server. + pub webtransport_enabled: bool, + /// WebTransport listen port. + pub webtransport_port: u16, + /// TLS certificate path (required for WebTransport). + pub cert_path: String, + /// TLS key path (required for WebTransport). + pub key_path: String, +} + +impl DeployConfig { + pub fn from_env() -> Self { + let server_id = env::var("IMKS_SERVER_ID").unwrap_or_else(|_| hostname()); + + Self { + adapter_mode: env::var("IMKS_ADAPTER").unwrap_or_else(|_| "local".into()), + redis_url: env::var("IMKS_REDIS_URL") + .unwrap_or_else(|_| "redis://localhost:6379".into()), + nats_url: env::var("IMKS_NATS_URL").unwrap_or_else(|_| "nats://localhost:4222".into()), + server_id, + webtransport_enabled: env::var("IMKS_WT_ENABLED") + .map(|v| v == "true" || v == "1") + .unwrap_or(false), + webtransport_port: env::var("IMKS_WT_PORT") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(3001), + cert_path: env::var("IMKS_WT_CERT_PATH").unwrap_or_default(), + key_path: env::var("IMKS_WT_KEY_PATH").unwrap_or_default(), + } + } +} + +impl Default for DeployConfig { + fn default() -> Self { + Self::from_env() + } +} + +fn hostname() -> String { + env::var("HOSTNAME") + .or_else(|_| env::var("HOST")) + .unwrap_or_else(|_| "imks-node-1".into()) +} diff --git a/svc/draft.rs b/svc/draft.rs new file mode 100644 index 0000000..bc2c052 --- /dev/null +++ b/svc/draft.rs @@ -0,0 +1,105 @@ +//! Draft event handlers on `MessageService`. +//! +//! Drafts are per-user private data — no permission checks beyond auth needed. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `draft:save` — upsert a draft. + pub async fn save_draft( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + let body: String = Self::parse_field(arr, "body")?; + let thread_id: Option = Self::parse_optional(arr, "thread_id")?; + let reply_to_message_id: Option = Self::parse_optional(arr, "reply_to_message_id")?; + let metadata: Option = Self::parse_optional(arr, "metadata")?; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.validate_channel_write(&channel_id_str, &user_id_str) + .await?; + + self.repo + .upsert_draft( + channel_id, + user_id, + thread_id, + &body, + reply_to_message_id, + metadata, + ) + .await?; + Ok(()) + } + + /// Handle `draft:get` — retrieve a draft and send it back to the requesting socket. + pub async fn get_draft( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + let thread_id: Option = Self::parse_optional(arr, "thread_id")?; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + let draft = self.repo.get_draft(channel_id, user_id, thread_id).await?; + if let Some(d) = draft { + let _ = socket.emit("draft:loaded", serde_json::to_value(&d).unwrap_or_default()); + } else { + let _ = socket.emit("draft:loaded", serde_json::json!(null)); + } + Ok(()) + } + + /// Handle `draft:delete` — delete a draft after sending. + pub async fn delete_draft( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + let thread_id: Option = Self::parse_optional(arr, "thread_id")?; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + self.repo + .delete_draft(channel_id, user_id, thread_id) + .await?; + Ok(()) + } +} diff --git a/svc/message.rs b/svc/message.rs new file mode 100644 index 0000000..5a9bb93 --- /dev/null +++ b/svc/message.rs @@ -0,0 +1,970 @@ +//! Message service — the business logic layer connecting auth, permissions, +//! persistence, and real-time broadcast. +//! +//! Validates every operation through the gRPC permission chain before +//! touching the database or broadcasting to the room. + +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use chrono::Utc; +use dashmap::DashMap; +use tracing; +use uuid::Uuid; + +use crate::auth::Authenticator; +use crate::models::message::Message; +use crate::pb::im::{ + CheckPermissionRequest, EnsureReadableRequest, ImPermission, IsMemberRequest, + ResolveChannelRequest, +}; +use crate::repo::message_repo::MessageRepo; +use crate::rpc::AppksClients; +use crate::socket::namespace::NamespaceManager; +use crate::socket::socket::Socket; +use crate::{ImksError, ImksResult}; + +/// Central business-logic service for message operations. +/// +/// Every mutating operation performs the following checks in order: +/// 1. JWT authentication (via `Authenticator`) +/// 2. Nonce deduplication (prevent duplicate sends) +/// 3. Rate limiting (per-user, per-channel sliding window) +/// 4. Message body size validation (max 100 KB) +/// 5. Channel readability (`PermissionService.EnsureReadable`) +/// 6. Channel status (`ResolveChannel` → read_only / archived) +/// 7. Channel membership (`MemberService.IsMember`) +/// 8. Operation-specific permission (`PermissionService.CheckPermission`) +/// 9. Ownership validation (for edit/delete of others' messages) +/// +/// Only after all gates pass does the operation reach the database +/// and the broadcast adapter. +#[derive(Clone)] +pub struct MessageService { + pub(crate) repo: MessageRepo, + pub(crate) auth: Arc, + pub(crate) clients: AppksClients, + pub(crate) namespaces: Arc, + /// Rate limiter: stores timestamps of recent sends per (user, channel). + rate_limits: Arc>>, + /// Nonce dedup cache: nonce → first-seen timestamp. Uses TTL eviction. + nonces: Arc>, + /// Max message body length in bytes. + max_body_size: usize, + /// Per-user, per-channel messages allowed in the rate window. + rate_limit_count: usize, + /// Rate-limiting window duration. + rate_window: Duration, + /// Nonce TTL before reuse is allowed again. + nonce_ttl: Duration, +} + +impl MessageService { + /// Create a new message service. + /// + /// Internally initializes the JWT `Authenticator` from the token client + /// so that `authenticate_socket()` can verify JWT tokens during `on_connect`. + pub async fn new( + repo: MessageRepo, + clients: AppksClients, + namespaces: Arc, + ) -> ImksResult { + let auth = Arc::new(Authenticator::new(clients.token.clone()).await?); + Ok(Self { + repo, + auth, + clients, + namespaces, + rate_limits: Arc::new(DashMap::new()), + nonces: Arc::new(DashMap::new()), + max_body_size: 100_000, + rate_limit_count: 10, + rate_window: Duration::from_secs(10), + nonce_ttl: Duration::from_secs(300), + }) + } + + pub fn namespaces(&self) -> &NamespaceManager { + &self.namespaces + } + + /// Verify a JWT token and attach the authenticated `user_id` to the socket. + /// Call this from `on_connect` before registering any event handlers. + pub fn authenticate_socket( + &self, + socket: &Socket, + auth_data: Option<&serde_json::Value>, + ) -> ImksResult<()> { + let token = auth_data + .and_then(|v| v.get("token")) + .and_then(|v| v.as_str()) + .ok_or_else(|| ImksError::Auth("Missing auth token".into()))?; + + let claims = self.auth.verify_local(token)?; + if !claims.has_scope("im:read") && !claims.has_scope("im:write") { + return Err(ImksError::Auth("Token lacks im scope".into())); + } + + let user_id = Uuid::parse_str(&claims.sub) + .map_err(|_| ImksError::Auth(format!("Invalid user ID in token: {}", claims.sub)))?; + + socket.set_user_id(user_id); + tracing::info!(user_id = %user_id, socket_sid = %socket.sid, "Socket authenticated"); + Ok(()) + } + + /// Handle `channel:join` event by adding the socket to the channel room. + pub async fn join_channel( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> ImksResult<()> { + let user_id = self.user_id(&socket)?; + let payload = Self::first_payload(data)?; + let channel_id: Uuid = Self::parse_field(payload, "channel_id")?; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + let namespace = self + .namespaces + .get_namespace(&socket.namespace) + .ok_or_else(|| { + ImksError::Namespace(format!("namespace {} not found", socket.namespace)) + })?; + namespace.join_room(&socket.sid, &channel_id_str).await?; + tracing::info!(socket_sid = %socket.sid, %channel_id, %user_id, "Socket joined channel room"); + Ok(()) + } + + /// Handle `channel:leave` event by removing the socket from the channel room. + pub async fn leave_channel( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> ImksResult<()> { + let _user_id = self.user_id(&socket)?; + let payload = Self::first_payload(data)?; + let channel_id: Uuid = Self::parse_field(payload, "channel_id")?; + let channel_id_str = channel_id.to_string(); + + if let Some(namespace) = self.namespaces.get_namespace(&socket.namespace) { + namespace.leave_room(&socket.sid, &channel_id_str).await?; + } + tracing::info!(socket_sid = %socket.sid, %channel_id, "Socket left channel room"); + Ok(()) + } + + /// Handle `message:send` event from a connected socket. + /// + /// Expected `data` shape (Socket.IO event array tail): + /// ```json + /// [{ + /// "channel_id": "...", + /// "body": "hello world", + /// "thread_id": null, + /// "reply_to_message_id": null, + /// "nonce": "optional-client-idempotency-key" + /// }] + /// ``` + pub async fn send_message( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> ImksResult { + let user_id = socket + .user_id() + .ok_or_else(|| ImksError::Auth("Socket not authenticated".into()))?; + + let payload = self.parse_send_payload(data)?; + let channel_id = payload.channel_id; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.validate_send_payload(&payload.body, &payload.nonce, user_id, channel_id)?; + self.validate_channel_write(&channel_id_str, &user_id_str) + .await?; + + let nonce_key = match payload.nonce.as_deref() { + Some(nonce) => Some(self.reserve_nonce(nonce, user_id, channel_id)?), + None => None, + }; + + let result = self + .create_and_dispatch(payload, user_id, &socket.namespace) + .await; + if result.is_err() { + self.release_nonce(nonce_key); + } + + result + } + + /// Handle `message:edit` event. + /// + /// The author can edit their own message. Users with `MANAGE_MESSAGES` + /// permission can edit any message in the channel. + pub async fn edit_message( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> ImksResult { + let user_id = socket + .user_id() + .ok_or_else(|| ImksError::Auth("Socket not authenticated".into()))?; + + let payload = Self::first_payload(data)?; + let message_id: Uuid = Self::parse_field(payload, "message_id")?; + let new_body: String = Self::parse_field(payload, "body")?; + + let existing = self + .repo + .get(message_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?; + + self.ensure_author_or_mod( + existing.author_id, + &existing.channel_id.to_string(), + user_id, + ) + .await?; + + let old_body = existing.body.clone(); + let updated = self.repo.update_body(message_id, &new_body).await?; + self.repo + .record_edit(message_id, user_id, &old_body, &new_body) + .await?; + + let namespace = self.namespaces.get_namespace(&socket.namespace); + if let Some(ns) = namespace { + ns.emit_to_room( + &existing.channel_id.to_string(), + "message:updated", + serde_json::to_value(&updated).unwrap_or_default(), + ) + .await; + } + + tracing::info!(message_id = %message_id, user_id = %user_id, "Message edited"); + Ok(updated) + } + + /// Handle `message:delete` event (soft-delete). + /// + /// Same ownership / moderator rules as edit. + pub async fn delete_message( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> ImksResult<()> { + let user_id = socket + .user_id() + .ok_or_else(|| ImksError::Auth("Socket not authenticated".into()))?; + + let payload = Self::first_payload(data)?; + let message_id: Uuid = Self::parse_field(payload, "message_id")?; + + let existing = self + .repo + .get(message_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?; + + self.ensure_author_or_mod( + existing.author_id, + &existing.channel_id.to_string(), + user_id, + ) + .await?; + + self.repo.soft_delete(message_id).await?; + + let namespace = self.namespaces.get_namespace(&socket.namespace); + if let Some(ns) = namespace { + ns.emit_to_room( + &existing.channel_id.to_string(), + "message:deleted", + serde_json::json!({ "id": message_id.to_string(), "channel_id": existing.channel_id.to_string() }), + ) + .await; + } + + tracing::info!(message_id = %message_id, user_id = %user_id, "Message deleted"); + Ok(()) + } + + // Permission validation helpers + + /// Full write-access gate: resolve channel + readability + membership + SEND_MESSAGE. + pub(crate) async fn validate_channel_write( + &self, + channel_id: &str, + user_id: &str, + ) -> ImksResult<()> { + // Check read-only / archived first (fast gRPC, returns early) + let channel_info = self.resolve_channel(channel_id).await?; + if channel_info.read_only { + return Err(ImksError::Auth(format!( + "Channel {channel_id} is read-only" + ))); + } + if channel_info.archived { + return Err(ImksError::Auth(format!("Channel {channel_id} is archived"))); + } + + self.ensure_readable(channel_id, user_id).await?; + self.ensure_member(channel_id, user_id).await?; + self.ensure_permission(channel_id, user_id, ImPermission::SendMessage) + .await?; + Ok(()) + } + + /// Verify the user can read this channel at all. + pub(crate) async fn ensure_readable(&self, channel_id: &str, user_id: &str) -> ImksResult<()> { + let mut client = self.clients.permission.clone(); + let resp = client + .ensure_readable(EnsureReadableRequest { + channel_id: channel_id.to_string(), + user_id: user_id.to_string(), + }) + .await?; + + let inner = resp.into_inner(); + if !inner.allowed { + return Err(ImksError::Auth(format!( + "User {user_id} cannot read channel {channel_id}" + ))); + } + Ok(()) + } + + pub(crate) fn user_id(&self, socket: &Socket) -> crate::ImksResult { + socket + .user_id() + .ok_or_else(|| ImksError::Auth("Socket not authenticated".into())) + } + + pub(crate) async fn ensure_member( + &self, + channel_id: &str, + user_id: &str, + ) -> crate::ImksResult<()> { + let mut client = self.clients.member.clone(); + let resp = client + .is_member(IsMemberRequest { + channel_id: channel_id.to_string(), + user_id: user_id.to_string(), + }) + .await?; + + let inner = resp.into_inner(); + if !inner.is_member { + return Err(ImksError::Auth(format!( + "User {user_id} is not a member of channel {channel_id}" + ))); + } + Ok(()) + } + + /// Verify a specific permission. + async fn ensure_permission( + &self, + channel_id: &str, + user_id: &str, + permission: ImPermission, + ) -> ImksResult<()> { + let allowed = self + .check_permission(channel_id, user_id, permission) + .await?; + + if !allowed { + return Err(ImksError::Auth(format!( + "User {user_id} lacks permission {permission:?} in channel {channel_id}" + ))); + } + Ok(()) + } + + /// Verify the user is the message author or has ManageMessages permission. + pub(crate) async fn ensure_author_or_mod( + &self, + message_author_id: Uuid, + channel_id: &str, + user_id: Uuid, + ) -> ImksResult<()> { + if message_author_id == user_id { + return Ok(()); + } + let allowed = self + .check_permission( + channel_id, + &user_id.to_string(), + ImPermission::ManageMessages, + ) + .await?; + if !allowed { + return Err(ImksError::Auth( + "Only the author or a moderator can modify this message".into(), + )); + } + Ok(()) + } + + /// Low-level permission check returning a boolean. + pub(crate) async fn check_permission( + &self, + channel_id: &str, + user_id: &str, + permission: ImPermission, + ) -> ImksResult { + let mut client = self.clients.permission.clone(); + let resp = client + .check_permission(CheckPermissionRequest { + channel_id: channel_id.to_string(), + user_id: user_id.to_string(), + permission: permission as i32, + }) + .await?; + + Ok(resp.into_inner().allowed) + } + + /// Resolve channel metadata via gRPC to check read_only / archived status. + async fn resolve_channel( + &self, + channel_id: &str, + ) -> ImksResult { + let mut client = self.clients.permission.clone(); + let resp = client + .resolve_channel(ResolveChannelRequest { + channel_id: channel_id.to_string(), + }) + .await?; + Ok(resp.into_inner()) + } + + // Rate limiting & dedup + + /// Reserve a nonce atomically. Nonces expire after `nonce_ttl`. + fn reserve_nonce(&self, nonce: &str, user_id: Uuid, channel_id: Uuid) -> ImksResult { + let now = Instant::now(); + let key = format!("{user_id}:{channel_id}:{nonce}"); + + // Cleanup expired nonces periodically (probabilistic, ~1/64 chance per check) + if rand::random::().is_multiple_of(64) { + self.nonces + .retain(|_, t| now.duration_since(*t) < self.nonce_ttl); + } + + match self.nonces.entry(key.clone()) { + dashmap::mapref::entry::Entry::Occupied(mut entry) => { + if now.duration_since(*entry.get()) < self.nonce_ttl { + return Err(ImksError::InvalidInput( + "Duplicate message: this nonce was already used".into(), + )); + } + entry.insert(now); + } + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(now); + } + } + Ok(key) + } + + fn release_nonce(&self, nonce_key: Option) { + if let Some(key) = nonce_key { + self.nonces.remove(&key); + } + } + + fn validate_body_size(&self, body: &str) -> ImksResult<()> { + if body.len() > self.max_body_size { + return Err(ImksError::InvalidInput(format!( + "Message body exceeds max size of {} bytes (got {})", + self.max_body_size, + body.len() + ))); + } + Ok(()) + } + + /// Check per-user, per-channel rate limit using a sliding window. + fn check_rate_limit(&self, user_id: Uuid, channel_id: Uuid) -> ImksResult<()> { + let key = (user_id, channel_id); + let now = Instant::now(); + + let mut entry = self.rate_limits.entry(key).or_default(); + + // Evict timestamps outside the window + entry.retain(|t| now.duration_since(*t) < self.rate_window); + + if entry.len() >= self.rate_limit_count { + return Err(ImksError::InvalidInput(format!( + "Rate limit exceeded: max {} messages per {}s", + self.rate_limit_count, + self.rate_window.as_secs() + ))); + } + + entry.push(now); + Ok(()) + } + + /// Combined safety checks for message sending: body size and rate limit. + fn validate_send_payload( + &self, + body: &str, + _nonce: &Option, + user_id: Uuid, + channel_id: Uuid, + ) -> ImksResult<()> { + self.validate_body_size(body)?; + self.check_rate_limit(user_id, channel_id)?; + Ok(()) + } + + // Payload parsers + + /// Parse attachment inputs from a JSON payload. + fn parse_attachments(value: &serde_json::Value) -> Vec { + value + .get("attachments") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|a| { + Some(AttachmentInput { + filename: a.get("filename")?.as_str()?.to_string(), + url: a.get("url")?.as_str()?.to_string(), + size: a.get("size")?.as_i64()?, + content_type: a + .get("content_type") + .and_then(|v| v.as_str()) + .map(String::from), + }) + }) + .collect() + }) + .unwrap_or_default() + } + + /// Parse embed inputs from a JSON payload. + fn parse_embeds(value: &serde_json::Value) -> Vec { + value + .get("embeds") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|e| { + let fields: Vec<(String, String, bool)> = e + .get("fields") + .and_then(|v| v.as_array()) + .map(|fa| { + fa.iter() + .filter_map(|f| { + Some(( + f.get("name")?.as_str()?.to_string(), + f.get("value")?.as_str()?.to_string(), + f.get("inline") + .and_then(|v| v.as_bool()) + .unwrap_or(false), + )) + }) + .collect() + }) + .unwrap_or_default(); + + Some(EmbedInput { + embed_type: e.get("embed_type")?.as_str()?.to_string(), + title: e.get("title").and_then(|v| v.as_str()).map(String::from), + description: e + .get("description") + .and_then(|v| v.as_str()) + .map(String::from), + url: e.get("url").and_then(|v| v.as_str()).map(String::from), + image_url: e + .get("image_url") + .and_then(|v| v.as_str()) + .map(String::from), + fields, + }) + }) + .collect() + }) + .unwrap_or_default() + } + + /// Parse a sticker input from a JSON payload (single object or null). + fn parse_sticker(value: &serde_json::Value) -> Option { + value.get("sticker").and_then(|s| { + Some(StickerInput { + sticker_id: MessageService::parse_field(s, "sticker_id").ok()?, + name: MessageService::parse_field(s, "name").ok()?, + image_url: MessageService::parse_field(s, "image_url").ok()?, + format_type: s + .get("format_type") + .and_then(|v| v.as_str()) + .unwrap_or("png") + .to_string(), + }) + }) + } + + /// Parse a forward input from a JSON payload (single object or null). + fn parse_forward(value: &serde_json::Value) -> Option { + value.get("forward").and_then(|f| { + Some(ForwardInput { + source_message_id: MessageService::parse_field(f, "source_message_id").ok()?, + source_channel_id: MessageService::parse_field(f, "source_channel_id").ok()?, + }) + }) + } + + /// Parse the `message:send` event payload including optional rich content. + fn parse_send_payload(&self, data: &serde_json::Value) -> ImksResult { + let arr = data + .as_array() + .ok_or_else(|| ImksError::InvalidInput("Event data must be a JSON array".into()))?; + + if arr.is_empty() { + return Err(ImksError::InvalidInput("Empty event data".into())); + } + + let payload = &arr[0]; + + let channel_id: Uuid = Self::parse_field(payload, "channel_id")?; + let body: String = Self::parse_field(payload, "body")?; + let thread_id = Self::parse_optional(payload, "thread_id")?; + let reply_to_message_id = Self::parse_optional(payload, "reply_to_message_id")?; + let nonce: Option = Self::parse_optional(payload, "nonce")?; + let mentioned_user_ids: Vec = + Self::parse_optional(payload, "mentioned_user_ids")?.unwrap_or_default(); + + let attachments = Self::parse_attachments(payload); + let embeds = Self::parse_embeds(payload); + let sticker = Self::parse_sticker(payload); + let forward = Self::parse_forward(payload); + + Ok(SendPayload { + channel_id, + body, + thread_id, + reply_to_message_id, + nonce, + mentioned_user_ids, + attachments, + embeds, + sticker, + forward, + }) + } + + pub(crate) fn first_payload(data: &serde_json::Value) -> ImksResult<&serde_json::Value> { + data.as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into())) + } + + pub(crate) fn parse_field( + value: &serde_json::Value, + field: &str, + ) -> crate::ImksResult { + let field_value = value + .get(field) + .ok_or_else(|| ImksError::InvalidInput(format!("Missing required field: {field}")))?; + + serde_json::from_value(field_value.clone()) + .map_err(|e| ImksError::InvalidInput(format!("Invalid field {field}: {e}"))) + } + + pub(crate) fn parse_optional( + value: &serde_json::Value, + field: &str, + ) -> ImksResult> { + match value.get(field) { + Some(v) if v.is_null() => Ok(None), + Some(v) => serde_json::from_value(v.clone()) + .map(Some) + .map_err(|e| ImksError::InvalidInput(format!("Invalid field {field}: {e}"))), + None => Ok(None), + } + } + + /// Create the message and all rich content in one database transaction, + /// then broadcast to the channel room after commit. + async fn create_and_dispatch( + &self, + payload: SendPayload, + user_id: Uuid, + namespace_path: &str, + ) -> ImksResult { + let mut tx = self.repo.pool().begin().await?; + let message_id = Uuid::now_v7(); + let now = Utc::now(); + + let message = sqlx::query_as::<_, Message>( + r#" + INSERT INTO message ( + id, channel_id, author_id, thread_id, reply_to_message_id, + message_type, body, metadata, pinned, system, + edited_at, deleted_at, created_at, updated_at + ) VALUES ( + $1, $2, $3, $4, $5, + 'text', $6, NULL, FALSE, FALSE, + NULL, NULL, $7, $7 + ) + RETURNING * + "#, + ) + .bind(message_id) + .bind(payload.channel_id) + .bind(user_id) + .bind(payload.thread_id) + .bind(payload.reply_to_message_id) + .bind(&payload.body) + .bind(now) + .fetch_one(&mut *tx) + .await?; + + if let Some(thread_id) = payload.thread_id { + sqlx::query( + r#" + UPDATE message_thread + SET replies_count = replies_count + 1, + last_reply_message_id = $1, + last_reply_at = $2, + updated_at = $2 + WHERE id = $3 + "#, + ) + .bind(message_id) + .bind(now) + .bind(thread_id) + .execute(&mut *tx) + .await?; + + sqlx::query( + r#" + INSERT INTO message_thread_participant (id, thread_id, user_id, joined_reason, joined_at) + VALUES ($1, $2, $3, 'reply', $4) + ON CONFLICT (thread_id, user_id) DO UPDATE SET joined_reason = EXCLUDED.joined_reason + "#, + ) + .bind(Uuid::now_v7()) + .bind(thread_id) + .bind(user_id) + .bind(now) + .execute(&mut *tx) + .await?; + + sqlx::query( + "UPDATE message_thread SET participants_count = (SELECT COUNT(*) FROM message_thread_participant WHERE thread_id = $1) WHERE id = $1", + ) + .bind(thread_id) + .execute(&mut *tx) + .await?; + } + + for att in &payload.attachments { + sqlx::query( + r#" + INSERT INTO message_attachment ( + id, message_id, filename, content_type, size, url, + storage_key, width, height, spoiler + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, FALSE) + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(&att.filename) + .bind(att.content_type.as_deref()) + .bind(att.size) + .bind(&att.url) + .execute(&mut *tx) + .await?; + } + + for emb in &payload.embeds { + let embed_id = Uuid::now_v7(); + sqlx::query( + r#" + INSERT INTO message_embed ( + id, message_id, embed_type, title, description, url, color, + image_url, author_name, author_url, footer_text, provider_name + ) VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, NULL, NULL, NULL, NULL) + "#, + ) + .bind(embed_id) + .bind(message_id) + .bind(&emb.embed_type) + .bind(emb.title.as_deref()) + .bind(emb.description.as_deref()) + .bind(emb.url.as_deref()) + .bind(emb.image_url.as_deref()) + .execute(&mut *tx) + .await?; + + for (position, (name, value, inline)) in emb.fields.iter().enumerate() { + sqlx::query( + r#" + INSERT INTO message_embed_field (id, embed_id, name, value, inline, position) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + ) + .bind(Uuid::now_v7()) + .bind(embed_id) + .bind(name) + .bind(value) + .bind(*inline) + .bind(position as i32) + .execute(&mut *tx) + .await?; + } + } + + if let Some(sticker) = &payload.sticker { + sqlx::query( + r#" + INSERT INTO message_sticker (id, message_id, sticker_id, name, image_url, format_type, pack_name, tags) + VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL) + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(sticker.sticker_id) + .bind(&sticker.name) + .bind(&sticker.image_url) + .bind(&sticker.format_type) + .execute(&mut *tx) + .await?; + } + + if let Some(forward) = &payload.forward { + sqlx::query( + r#" + INSERT INTO message_forward (id, message_id, source_message_id, source_channel_id, forwarded_by) + VALUES ($1, $2, $3, $4, $5) + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(forward.source_message_id) + .bind(forward.source_channel_id) + .bind(user_id) + .execute(&mut *tx) + .await?; + } + + for mentioned_id in &payload.mentioned_user_ids { + sqlx::query( + r#" + INSERT INTO message_mention (id, message_id, channel_id, mentioned_user_id, mentioned_by, created_at) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(payload.channel_id) + .bind(*mentioned_id) + .bind(user_id) + .bind(now) + .execute(&mut *tx) + .await?; + + sqlx::query( + r#" + INSERT INTO message_notification ( + id, message_id, channel_id, user_id, reason, status, delivery_channel, created_at + ) VALUES ($1, $2, $3, $4, 'mention', 'pending', NULL, $5) + "#, + ) + .bind(Uuid::now_v7()) + .bind(message_id) + .bind(payload.channel_id) + .bind(*mentioned_id) + .bind(now) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + self.broadcast_new_message(&message, payload.channel_id, user_id, namespace_path) + .await; + + Ok(message) + } + + /// Broadcast a newly created message to the channel room and log. + async fn broadcast_new_message( + &self, + message: &Message, + channel_id: Uuid, + user_id: Uuid, + namespace_path: &str, + ) { + let namespace = self.namespaces.get_namespace(namespace_path); + if let Some(ns) = namespace { + ns.emit_to_room( + &channel_id.to_string(), + "message:new", + serde_json::to_value(message).unwrap_or_default(), + ) + .await; + } + + tracing::info!( + message_id = %message.id, + channel_id = %channel_id, + user_id = %user_id, + "Message sent" + ); + } +} + +// Rich content input types for parsing + +pub(crate) struct SendPayload { + channel_id: Uuid, + body: String, + thread_id: Option, + reply_to_message_id: Option, + nonce: Option, + mentioned_user_ids: Vec, + attachments: Vec, + embeds: Vec, + sticker: Option, + forward: Option, +} + +pub(crate) struct AttachmentInput { + filename: String, + url: String, + size: i64, + content_type: Option, +} + +pub(crate) struct EmbedInput { + embed_type: String, + title: Option, + description: Option, + url: Option, + image_url: Option, + fields: Vec<(String, String, bool)>, +} + +pub(crate) struct StickerInput { + sticker_id: Uuid, + name: String, + image_url: String, + format_type: String, +} + +pub(crate) struct ForwardInput { + source_message_id: Uuid, + source_channel_id: Uuid, +} diff --git a/svc/mod.rs b/svc/mod.rs new file mode 100644 index 0000000..3eff3e6 --- /dev/null +++ b/svc/mod.rs @@ -0,0 +1,19 @@ +pub mod article; +pub mod bookmark; +pub mod component; +pub mod deploy; +pub mod draft; +pub mod message; +pub mod pin; +pub mod poll; +pub mod reaction; +pub mod read_state; +pub mod scheduled; +pub mod thread; +pub mod typing; + +#[cfg(test)] +mod tests; + +pub use deploy::DeployConfig; +pub use message::MessageService; diff --git a/svc/pin.rs b/svc/pin.rs new file mode 100644 index 0000000..c11621b --- /dev/null +++ b/svc/pin.rs @@ -0,0 +1,97 @@ +//! Pin event handlers on `MessageService`. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::ImksResult; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `pin:add` — pin a message, then broadcast to the channel room. + pub async fn pin_message( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> ImksResult<()> { + let user_id = self.user_id(&socket)?; + let (channel_id, message_id) = self.parse_pin_payload(data)?; + + self.ensure_member(&channel_id.to_string(), &user_id.to_string()) + .await?; + + self.repo + .pin_message(channel_id, message_id, user_id) + .await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + let ns = ns.clone(); + let cid = channel_id.to_string(); + let mid = message_id.to_string(); + tokio::spawn(async move { + ns.emit_to_room( + &cid, + "pin:added", + serde_json::json!({ + "channel_id": cid, + "message_id": mid, + "pinned_by": user_id.to_string(), + }), + ) + .await; + }); + } + + tracing::info!(%channel_id, %message_id, %user_id, "Message pinned"); + Ok(()) + } + + /// Handle `pin:remove` — unpin a message. + pub async fn unpin_message( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> ImksResult<()> { + let user_id = self.user_id(&socket)?; + let (channel_id, message_id) = self.parse_pin_payload(data)?; + + self.ensure_member(&channel_id.to_string(), &user_id.to_string()) + .await?; + + self.repo.unpin_message(channel_id, message_id).await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + let ns = ns.clone(); + let cid = channel_id.to_string(); + let mid = message_id.to_string(); + tokio::spawn(async move { + ns.emit_to_room( + &cid, + "pin:removed", + serde_json::json!({ + "channel_id": cid, + "message_id": mid, + }), + ) + .await; + }); + } + + tracing::info!(%channel_id, %message_id, %user_id, "Message unpinned"); + Ok(()) + } + + fn parse_pin_payload(&self, data: &serde_json::Value) -> ImksResult<(Uuid, Uuid)> { + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + Ok(( + Self::parse_field(arr, "channel_id")?, + Self::parse_field(arr, "message_id")?, + )) + } +} diff --git a/svc/poll.rs b/svc/poll.rs new file mode 100644 index 0000000..e31ee96 --- /dev/null +++ b/svc/poll.rs @@ -0,0 +1,97 @@ +//! Poll event handlers on `MessageService`. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `poll:vote` — cast a vote on a poll option. + pub async fn poll_vote( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let poll_id: Uuid = Self::parse_field(arr, "poll_id")?; + let option_id: Uuid = Self::parse_field(arr, "option_id")?; + let target = self.repo.get_poll_target(poll_id, option_id).await?; + let channel_id = target.channel_id; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + let target = self + .repo + .cast_vote_checked(poll_id, option_id, user_id) + .await?; + if let Some(result) = self + .repo + .get_poll_result(target.message_id, user_id) + .await? + && let Some(ns) = self.namespaces.get_namespace(&socket.namespace) + { + ns.emit_to_room( + &target.channel_id.to_string(), + "poll:updated", + serde_json::to_value(&result).unwrap_or_default(), + ) + .await; + } + + Ok(()) + } + + /// Handle `poll:vote:remove` — retract a vote. + pub async fn poll_remove_vote( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let poll_id: Uuid = Self::parse_field(arr, "poll_id")?; + let option_id: Uuid = Self::parse_field(arr, "option_id")?; + let target = self.repo.get_poll_target(poll_id, option_id).await?; + let channel_id_str = target.channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + if let Some(target) = self + .repo + .remove_vote_checked(poll_id, option_id, user_id) + .await? + && let Some(result) = self + .repo + .get_poll_result(target.message_id, user_id) + .await? + && let Some(ns) = self.namespaces.get_namespace(&socket.namespace) + { + ns.emit_to_room( + &target.channel_id.to_string(), + "poll:updated", + serde_json::to_value(&result).unwrap_or_default(), + ) + .await; + } + + Ok(()) + } +} diff --git a/svc/reaction.rs b/svc/reaction.rs new file mode 100644 index 0000000..9dabe52 --- /dev/null +++ b/svc/reaction.rs @@ -0,0 +1,127 @@ +//! Reaction event handlers on `MessageService`. +//! +//! Toggle semantics: sending the same reaction again removes it. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `reaction:add` — toggle (add or remove) a reaction, then broadcast. + pub async fn toggle_reaction( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let (message_id, content) = self.parse_reaction_payload(data)?; + let message = self + .repo + .get(message_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?; + let channel_id = message.channel_id; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + let action = if self + .repo + .add_reaction(message_id, channel_id, user_id, &content) + .await? + .is_some() + { + tracing::info!(%message_id, %user_id, %content, "Reaction added"); + "add" + } else { + self.repo + .remove_reaction(message_id, user_id, &content) + .await?; + tracing::info!(%message_id, %user_id, %content, "Reaction removed"); + "remove" + }; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "reaction:updated", + serde_json::json!({ + "message_id": message_id.to_string(), + "channel_id": channel_id.to_string(), + "user_id": user_id.to_string(), + "content": content, + "action": action, + }), + ) + .await; + } + Ok(()) + } + + /// Handle `reaction:remove` — explicitly remove a reaction. + pub async fn remove_reaction( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let (message_id, content) = self.parse_reaction_payload(data)?; + let message = self + .repo + .get(message_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?; + let channel_id = message.channel_id; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + self.repo + .remove_reaction(message_id, user_id, &content) + .await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "reaction:updated", + serde_json::json!({ + "message_id": message_id.to_string(), + "channel_id": channel_id.to_string(), + "user_id": user_id.to_string(), + "content": content, + "action": "remove", + }), + ) + .await; + } + Ok(()) + } + + fn parse_reaction_payload( + &self, + data: &serde_json::Value, + ) -> crate::ImksResult<(Uuid, String)> { + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let content: String = Self::parse_field(arr, "content")?; + if content.trim().is_empty() { + return Err(ImksError::InvalidInput( + "Reaction content cannot be empty".into(), + )); + } + + Ok((Self::parse_field(arr, "message_id")?, content)) + } +} diff --git a/svc/read_state.rs b/svc/read_state.rs new file mode 100644 index 0000000..b7dbfcc --- /dev/null +++ b/svc/read_state.rs @@ -0,0 +1,126 @@ +//! Read state and notification event handlers on `MessageService`. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + // Read state handlers + + /// Handle `read_state:mark` — mark a channel as read up to a message. + pub async fn mark_read( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + let message_id: Uuid = Self::parse_field(arr, "message_id")?; + + let state = self.repo.mark_read(channel_id, user_id, message_id).await?; + let _ = socket.emit( + "read_state:updated", + serde_json::to_value(&state).unwrap_or_default(), + ); + Ok(()) + } + + /// Handle `read_state:get` — get read state for a channel. + pub async fn get_read_state( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + + let state = self.repo.get_read_state(channel_id, user_id).await?; + let _ = socket.emit( + "read_state:loaded", + serde_json::to_value(&state).unwrap_or_default(), + ); + Ok(()) + } + + // Notification handlers + + /// Handle `notification:list` — list a user's notifications. + pub async fn list_notifications( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let before: Option = Self::parse_optional(arr, "before")?; + let limit: Option = Self::parse_optional(arr, "limit")?; + + let page = self.repo.list_notifications(user_id, before, limit).await?; + let _ = socket.emit( + "notification:loaded", + serde_json::to_value(&page).unwrap_or_default(), + ); + Ok(()) + } + + /// Handle `notification:mark_read` — mark one notification as read. + pub async fn mark_notification_read( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let notification_id: Uuid = Self::parse_field(arr, "notification_id")?; + + self.repo.mark_notification_read(notification_id).await?; + + let unread = self.repo.get_unread_notification_count(user_id).await?; + let _ = socket.emit( + "notification:unread_count", + serde_json::json!({ "count": unread }), + ); + Ok(()) + } + + /// Handle `notification:mark_all_read` — mark all as read. + pub async fn mark_all_notifications_read( + &self, + socket: Arc, + _data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + + let affected = self.repo.mark_all_notifications_read(user_id).await?; + tracing::info!(%user_id, %affected, "All notifications marked read"); + + let _ = socket.emit( + "notification:unread_count", + serde_json::json!({ "count": 0 }), + ); + Ok(()) + } +} diff --git a/svc/scheduled.rs b/svc/scheduled.rs new file mode 100644 index 0000000..4ed07a7 --- /dev/null +++ b/svc/scheduled.rs @@ -0,0 +1,80 @@ +//! Scheduled message dispatcher on `MessageService`. +//! +//! A background task that periodically scans for due scheduled messages +//! and sends them through the normal message path. + +use std::time::Duration; + +use crate::repo::CreateMessageInput; + +use super::message::MessageService; + +impl MessageService { + /// Start the background scheduled-message dispatcher. + /// Scans every 30 seconds for pending messages whose `scheduled_at` has passed. + pub fn start_scheduled_dispatcher(self: std::sync::Arc) { + tokio::spawn(async move { + tracing::info!("Scheduled message dispatcher started (interval: 30s)"); + loop { + tokio::time::sleep(Duration::from_secs(30)).await; + + match self.process_due_scheduled().await { + Ok(count) => { + if count > 0 { + tracing::info!(count, "Dispatched scheduled messages"); + } + } + Err(e) => { + tracing::error!(error = %e, "Scheduled message dispatch failed"); + } + } + } + }); + } + + /// Fetch and dispatch all due scheduled messages. + async fn process_due_scheduled(&self) -> crate::ImksResult { + let due = self.repo.get_due_scheduled().await?; + let mut dispatched = 0; + + for scheduled in due { + let input = CreateMessageInput { + channel_id: scheduled.channel_id, + author_id: scheduled.author_id, + thread_id: scheduled.thread_id, + reply_to_message_id: scheduled.reply_to_message_id, + message_type: "text".into(), + body: scheduled.body.clone(), + metadata: scheduled.metadata.clone(), + system: false, + }; + + match self.repo.create(&input).await { + Ok(message) => { + self.repo + .mark_scheduled_sent(scheduled.id, message.id) + .await?; + + // Broadcast to channel + if let Some(ns) = self.namespaces.get_namespace("/") { + ns.emit_to_room( + &scheduled.channel_id.to_string(), + "message:new", + serde_json::to_value(&message).unwrap_or_default(), + ) + .await; + } + dispatched += 1; + } + Err(e) => { + tracing::error!(scheduled_id = %scheduled.id, error = %e, "Failed to send scheduled message"); + self.repo + .mark_scheduled_failed(scheduled.id, &e.to_string()) + .await?; + } + } + } + + Ok(dispatched) + } +} diff --git a/svc/tests.rs b/svc/tests.rs new file mode 100644 index 0000000..501e916 --- /dev/null +++ b/svc/tests.rs @@ -0,0 +1,166 @@ +//! Service layer unit tests for `MessageService`. +//! +//! Tests parsing, nonce dedup, rate limiting — all in-memory without +//! requiring a real database or gRPC connection. + +#[cfg(test)] +#[allow(clippy::module_inception)] +mod tests { + use std::sync::Arc; + use std::time::{Duration, Instant}; + + use uuid::Uuid; + + use crate::svc::message::MessageService; + + #[test] + fn test_parse_field_valid() { + let json = serde_json::json!({"message_id": "01909abc-def0-7000-8000-000000000001"}); + let id: Uuid = MessageService::parse_field(&json, "message_id").unwrap(); + assert_eq!(id.to_string(), "01909abc-def0-7000-8000-000000000001"); + } + + #[test] + fn test_parse_field_missing() { + let json = serde_json::json!({}); + let result: crate::ImksResult = MessageService::parse_field(&json, "missing"); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Missing required field") + ); + } + + #[test] + fn test_parse_optional_present() { + let json = serde_json::json!({"name": "alice"}); + let val: Option = MessageService::parse_optional(&json, "name").unwrap(); + assert_eq!(val, Some("alice".into())); + } + + #[test] + fn test_parse_optional_null() { + let json = serde_json::json!({"name": null}); + let val: Option = MessageService::parse_optional(&json, "name").unwrap(); + assert_eq!(val, None); + } + + #[test] + fn test_parse_optional_missing() { + let json = serde_json::json!({}); + let val: Option = MessageService::parse_optional(&json, "name").unwrap(); + assert_eq!(val, None); + } + + #[test] + fn test_parse_send_payload_basic_shape() { + let json = serde_json::json!([{ + "channel_id": "01909abc-def0-7000-8000-000000000001", + "body": "hello world" + }]); + let arr = json.as_array().unwrap(); + let payload = &arr[0]; + assert_eq!(payload["body"], "hello world"); + } + + #[test] + fn test_parse_send_payload_with_rich_content() { + let json = serde_json::json!([{ + "channel_id": "01909abc-def0-7000-8000-000000000001", + "body": "hey @alice", + "thread_id": "01909def-abc0-7000-8000-000000000002", + "nonce": "nonce-001", + "mentioned_user_ids": ["01909abc-def0-7000-8000-000000000003"], + "attachments": [{"filename": "img.png", "url": "https://cdn/img.png", "size": 1024, "content_type": "image/png"}], + "embeds": [{"embed_type": "link", "title": "Example", "url": "https://example.com", "fields": [{"name": "k", "value": "v", "inline": true}]}], + "sticker": {"sticker_id": "01909abc-def0-7000-8000-000000000004", "name": "Hype!", "image_url": "https://cdn/sticker.png"}, + "forward": {"source_message_id": "01909abc-def0-7000-8000-000000000005", "source_channel_id": "01909abc-def0-7000-8000-000000000006"} + }]); + let arr = json.as_array().unwrap(); + let payload = &arr[0]; + assert_eq!(payload["body"], "hey @alice"); + assert!(payload["attachments"].is_array()); + assert!(payload["embeds"].is_array()); + assert!(payload["sticker"].is_object()); + assert!(payload["forward"].is_object()); + } + + #[test] + fn test_nonce_dedup_first_accepted() { + use dashmap::DashMap; + let nonces = Arc::new(DashMap::::new()); + assert!(!nonces.contains_key("nonce-1")); + nonces.insert("nonce-1".to_string(), Instant::now()); + assert!(nonces.contains_key("nonce-1")); + } + + #[test] + fn test_nonce_dedup_rejects_duplicate() { + use dashmap::DashMap; + let nonces = Arc::new(DashMap::::new()); + nonces.insert("nonce-1".to_string(), Instant::now()); + // After insert, it should exist + assert!(nonces.contains_key("nonce-1")); + } + + #[test] + fn test_rate_limit_within_window() { + use dashmap::DashMap; + let limits = Arc::new(DashMap::<(Uuid, Uuid), Vec>::new()); + let user = Uuid::now_v7(); + let channel = Uuid::now_v7(); + // Should be empty initially + assert!(limits.get(&(user, channel)).is_none()); + } + + #[test] + fn test_rate_limit_approaches_threshold() { + use dashmap::DashMap; + let limits = Arc::new(DashMap::<(Uuid, Uuid), Vec>::new()); + let now = Instant::now(); + let user = Uuid::now_v7(); + let channel = Uuid::now_v7(); + + let mut entry = limits.entry((user, channel)).or_default(); + for _ in 0..9 { + entry.push(now); + } + assert_eq!(entry.len(), 9); + } + + #[test] + fn test_rate_limit_exceeded() { + use dashmap::DashMap; + let limits = Arc::new(DashMap::<(Uuid, Uuid), Vec>::new()); + let now = Instant::now(); + let user = Uuid::now_v7(); + let channel = Uuid::now_v7(); + + let mut entry = limits.entry((user, channel)).or_default(); + for _ in 0..10 { + entry.push(now); + } + assert!(entry.len() >= 10); + } + + #[test] + fn test_rate_limit_window_expiry_eviction() { + use dashmap::DashMap; + let limits = Arc::new(DashMap::<(Uuid, Uuid), Vec>::new()); + let window = Duration::from_secs(10); + let user = Uuid::now_v7(); + let channel = Uuid::now_v7(); + + let old = Instant::now() - Duration::from_secs(15); + let now = Instant::now(); + + let mut entry = limits.entry((user, channel)).or_default(); + entry.push(old); + entry.push(now); + entry.retain(|t| now.duration_since(*t) < window); + + assert_eq!(entry.len(), 1); + } +} diff --git a/svc/thread.rs b/svc/thread.rs new file mode 100644 index 0000000..551883a --- /dev/null +++ b/svc/thread.rs @@ -0,0 +1,235 @@ +//! Thread event handlers on `MessageService`. +//! +//! Threads are anchored by a root message. Participants are added when they +//! reply, get mentioned, or explicitly join. Thread events broadcast to the +//! channel room so all clients see thread activity updates. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::models::message_thread_participant::JoinReason; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `thread:create` — create a new thread anchored on a root message. + pub async fn create_thread( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let root_message_id: Uuid = Self::parse_field(arr, "root_message_id")?; + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + + let root_message = self + .repo + .get(root_message_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("message {root_message_id}")))?; + if root_message.channel_id != channel_id { + return Err(ImksError::InvalidInput( + "Root message does not belong to channel".into(), + )); + } + + self.validate_channel_write(&channel_id.to_string(), &user_id.to_string()) + .await?; + + let thread = self + .repo + .create_thread(root_message_id, channel_id, user_id) + .await?; + + // Creator is automatically a participant + self.repo + .add_thread_participant(thread.id, user_id, JoinReason::Reply.as_str()) + .await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "thread:created", + serde_json::to_value(&thread).unwrap_or_default(), + ) + .await; + } + + tracing::info!(thread_id = %thread.id, %channel_id, %user_id, "Thread created"); + Ok(()) + } + + /// Handle `thread:resolve` — toggle the resolved state of a thread. + pub async fn resolve_thread( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let thread_id: Uuid = Self::parse_field(arr, "thread_id")?; + let resolved: bool = Self::parse_field(arr, "resolved")?; + let thread = self + .repo + .get_thread(thread_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("thread {thread_id}")))?; + let channel_id = thread.channel_id; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + self.ensure_author_or_mod(thread.created_by, &channel_id_str, user_id) + .await?; + + self.repo + .resolve_thread(thread_id, user_id, resolved) + .await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "thread:updated", + serde_json::json!({ + "thread_id": thread_id.to_string(), + "resolved": resolved, + "resolved_by": user_id.to_string(), + }), + ) + .await; + } + + tracing::info!(%thread_id, %resolved, %user_id, "Thread resolve toggled"); + Ok(()) + } + + /// Handle `thread:join` — explicitly join a thread. + pub async fn join_thread( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let thread_id: Uuid = Self::parse_field(arr, "thread_id")?; + let thread = self + .repo + .get_thread(thread_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("thread {thread_id}")))?; + let channel_id = thread.channel_id; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + self.repo + .add_thread_participant(thread_id, user_id, JoinReason::Joined.as_str()) + .await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "thread:participant_joined", + serde_json::json!({ + "thread_id": thread_id.to_string(), + "user_id": user_id.to_string(), + }), + ) + .await; + } + + tracing::info!(%thread_id, %user_id, "User joined thread"); + Ok(()) + } + + /// Handle `thread:leave` — leave a thread. + pub async fn leave_thread( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let thread_id: Uuid = Self::parse_field(arr, "thread_id")?; + let thread = self + .repo + .get_thread(thread_id) + .await? + .ok_or_else(|| ImksError::NotFound(format!("thread {thread_id}")))?; + let channel_id = thread.channel_id; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + self.repo + .remove_thread_participant(thread_id, user_id) + .await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "thread:participant_left", + serde_json::json!({ + "thread_id": thread_id.to_string(), + "user_id": user_id.to_string(), + }), + ) + .await; + } + + tracing::info!(%thread_id, %user_id, "User left thread"); + Ok(()) + } + + /// Handle `thread:list` — list threads in a channel. + pub async fn list_threads( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + + let user_id = self.user_id(&socket)?; + let channel_id: Uuid = Self::parse_field(arr, "channel_id")?; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + let threads = self.repo.list_threads(channel_id).await?; + let _ = socket.emit( + "thread:loaded", + serde_json::to_value(&threads).unwrap_or_default(), + ); + Ok(()) + } +} diff --git a/svc/typing.rs b/svc/typing.rs new file mode 100644 index 0000000..9d236cd --- /dev/null +++ b/svc/typing.rs @@ -0,0 +1,113 @@ +//! Typing indicator and presence event handlers on `MessageService`. +//! +//! These are pure broadcast events (no persistence). Typing indicators show +//! "user is typing…" in the channel. Presence indicates online/offline status. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::ImksError; +use crate::socket::socket::Socket; + +use super::message::MessageService; + +impl MessageService { + /// Handle `typing:start` — broadcast to the channel room. + pub async fn typing_start( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let channel_id: Uuid = self.parse_channel_id(data)?; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "typing", + serde_json::json!({ + "channel_id": channel_id.to_string(), + "user_id": user_id.to_string(), + "typing": true, + }), + ) + .await; + } + Ok(()) + } + + /// Handle `typing:stop` — broadcast to the channel room. + pub async fn typing_stop( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let channel_id: Uuid = self.parse_channel_id(data)?; + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "typing", + serde_json::json!({ + "channel_id": channel_id.to_string(), + "user_id": user_id.to_string(), + "typing": false, + }), + ) + .await; + } + Ok(()) + } + + /// Handle `presence:update` — broadcast online status to all shared channels. + /// In a full implementation this would track which channels a user is in + /// and broadcast to all of them. For now it broadcasts to the specified channel. + pub async fn presence_update( + &self, + socket: Arc, + data: &serde_json::Value, + ) -> crate::ImksResult<()> { + let user_id = self.user_id(&socket)?; + let channel_id: Uuid = self.parse_channel_id(data)?; + let online: bool = + Self::parse_optional(Self::first_payload(data)?, "online")?.unwrap_or(true); + let channel_id_str = channel_id.to_string(); + let user_id_str = user_id.to_string(); + + self.ensure_readable(&channel_id_str, &user_id_str).await?; + self.ensure_member(&channel_id_str, &user_id_str).await?; + + if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) { + ns.emit_to_room( + &channel_id.to_string(), + "presence:update", + serde_json::json!({ + "user_id": user_id.to_string(), + "online": online, + }), + ) + .await; + } + Ok(()) + } + + fn parse_channel_id(&self, data: &serde_json::Value) -> crate::ImksResult { + let arr = data + .as_array() + .and_then(|a| a.first()) + .ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?; + Self::parse_field(arr, "channel_id") + } +} diff --git a/tests/adapter_tests.rs b/tests/adapter_tests.rs index 6b71439..c09fe5c 100644 --- a/tests/adapter_tests.rs +++ b/tests/adapter_tests.rs @@ -1,7 +1,9 @@ use std::collections::HashSet; use std::sync::Arc; -use imks::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, SocketInfo}; +use imks::socket::adapter::{ + Adapter, AdapterError, BroadcastFlags, BroadcastOptions, BusMessage, LocalAdapter, SocketInfo, +}; use imks::socket::packet::Packet; use imks::socket::session_store::{InMemorySessionStore, SessionInfo, SessionStoreTrait}; @@ -10,7 +12,11 @@ async fn test_local_adapter_add_and_del() { let sent_packets: dashmap::DashMap> = dashmap::DashMap::new(); let sent_packets_clone = sent_packets.clone(); let send_fn = move |engine_sid: &str, packet: &Packet| { - sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone()); + sent_packets_clone + .entry(engine_sid.to_string()) + .or_default() + .value_mut() + .push(packet.clone()); Ok(()) }; @@ -47,10 +53,15 @@ async fn test_local_adapter_del_all() { #[tokio::test] async fn test_local_adapter_register_and_broadcast() { - let sent_packets: Arc>> = Arc::new(dashmap::DashMap::new()); + let sent_packets: Arc>> = + Arc::new(dashmap::DashMap::new()); let sent_packets_clone = sent_packets.clone(); let send_fn = move |engine_sid: &str, packet: &Packet| { - sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone()); + sent_packets_clone + .entry(engine_sid.to_string()) + .or_default() + .value_mut() + .push(packet.clone()); Ok(()) }; @@ -71,10 +82,15 @@ async fn test_local_adapter_register_and_broadcast() { #[tokio::test] async fn test_local_adapter_broadcast_to_room() { - let sent_packets: Arc>> = Arc::new(dashmap::DashMap::new()); + let sent_packets: Arc>> = + Arc::new(dashmap::DashMap::new()); let sent_packets_clone = sent_packets.clone(); let send_fn = move |engine_sid: &str, packet: &Packet| { - sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone()); + sent_packets_clone + .entry(engine_sid.to_string()) + .or_default() + .value_mut() + .push(packet.clone()); Ok(()) }; @@ -99,10 +115,15 @@ async fn test_local_adapter_broadcast_to_room() { #[tokio::test] async fn test_local_adapter_broadcast_except() { - let sent_packets: Arc>> = Arc::new(dashmap::DashMap::new()); + let sent_packets: Arc>> = + Arc::new(dashmap::DashMap::new()); let sent_packets_clone = sent_packets.clone(); let send_fn = move |engine_sid: &str, packet: &Packet| { - sent_packets_clone.entry(engine_sid.to_string()).or_insert_with(Vec::new).value_mut().push(packet.clone()); + sent_packets_clone + .entry(engine_sid.to_string()) + .or_default() + .value_mut() + .push(packet.clone()); Ok(()) }; @@ -340,5 +361,7 @@ fn test_is_valid_namespace() { assert!(!imks::socket::namespace::is_valid_namespace("")); assert!(!imks::socket::namespace::is_valid_namespace("admin")); - assert!(!imks::socket::namespace::is_valid_namespace(&"/".repeat(257))); + assert!(!imks::socket::namespace::is_valid_namespace( + &"/".repeat(257) + )); } diff --git a/tests/engine_io_tests.rs b/tests/engine_io_tests.rs index 67375b4..083233a 100644 --- a/tests/engine_io_tests.rs +++ b/tests/engine_io_tests.rs @@ -78,7 +78,10 @@ fn test_engine_io_binary_encoding() { let decoded = codec::decode_packet(&encoded).unwrap(); assert_eq!(decoded.packet_type, PacketType::Message); - assert_eq!(decoded.data, PacketData::Binary(vec![0x01, 0x02, 0x03, 0x04])); + assert_eq!( + decoded.data, + PacketData::Binary(vec![0x01, 0x02, 0x03, 0x04]) + ); } #[test] diff --git a/tests/session_tests.rs b/tests/session_tests.rs index ee39167..686683b 100644 --- a/tests/session_tests.rs +++ b/tests/session_tests.rs @@ -1,4 +1,4 @@ -use imks::engine::session::{generate_sid, SessionState, SessionStore, TransportType}; +use imks::engine::session::{SessionState, SessionStore, TransportType, generate_sid}; #[test] fn test_session_store_create_and_get() { @@ -56,7 +56,10 @@ fn test_generate_sid_format() { let sid = generate_sid(); assert_eq!(sid.len(), 20); - assert!(sid.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')); + assert!( + sid.chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + ); } #[tokio::test]