refactor(tests): reformat code and update dependency management

- Reorganized import statements in adapter tests for better readability
- Replaced or_insert_with(Vec::new) with or_default() in test closures
- Updated Cargo.lock with new dependency versions and checksums
- Added TLS features to tonic dependency configuration
- Included sqlx, chrono, and uuid dependencies with specific features
- Added jsonwebtoken and arc-swap as project dependencies
- Reformatted assertion statements to comply with line length limits
- Adjusted base64 import order in engine codec module
- Updated protobuf include statement formatting
This commit is contained in:
zhenyi
2026-06-11 12:11:05 +08:00
parent 06e8ee96a5
commit 821537186e
111 changed files with 10458 additions and 385 deletions
Generated
+573 -3
View File
@@ -37,7 +37,7 @@ dependencies = [
"derive_more", "derive_more",
"encoding_rs", "encoding_rs",
"flate2", "flate2",
"foldhash", "foldhash 0.1.5",
"futures-core", "futures-core",
"h2 0.3.27", "h2 0.3.27",
"http 0.2.12", "http 0.2.12",
@@ -152,7 +152,7 @@ dependencies = [
"cookie", "cookie",
"derive_more", "derive_more",
"encoding_rs", "encoding_rs",
"foldhash", "foldhash 0.1.5",
"futures-core", "futures-core",
"futures-util", "futures-util",
"impl-more", "impl-more",
@@ -232,6 +232,21 @@ dependencies = [
"alloc-no-stdlib", "alloc-no-stdlib",
] ]
[[package]]
name = "allocator-api2"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.102" version = "1.0.102"
@@ -333,6 +348,15 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "atoi"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "atomic-waker" name = "atomic-waker"
version = "1.1.2" version = "1.1.2"
@@ -414,6 +438,9 @@ name = "bitflags"
version = "2.13.0" version = "2.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8"
dependencies = [
"serde_core",
]
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
@@ -460,6 +487,12 @@ version = "3.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.11.1" version = "1.11.1"
@@ -523,6 +556,35 @@ dependencies = [
"rand_core 0.10.1", "rand_core 0.10.1",
] ]
[[package]]
name = "chrono"
version = "0.4.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327"
dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]
[[package]]
name = "cmov"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c9ea0ac24bc397ab3c98583a3c9ba74fa56b09a4449bbe172b9b1ddb016027a"
[[package]]
name = "concurrent-queue"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "const-oid" name = "const-oid"
version = "0.9.6" version = "0.9.6"
@@ -605,6 +667,21 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "crc"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d"
dependencies = [
"crc-catalog",
]
[[package]]
name = "crc-catalog"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "217698eaf96b4a3f0bc4f3662aaa55bdf913cd54d7204591faa790070c6d0853"
[[package]] [[package]]
name = "crc16" name = "crc16"
version = "0.4.0" version = "0.4.0"
@@ -620,6 +697,15 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "crossbeam-queue"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-utils" name = "crossbeam-utils"
version = "0.8.21" version = "0.8.21"
@@ -645,6 +731,15 @@ dependencies = [
"hybrid-array", "hybrid-array",
] ]
[[package]]
name = "ctutils"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e"
dependencies = [
"cmov",
]
[[package]] [[package]]
name = "curve25519-dalek" name = "curve25519-dalek"
version = "4.1.3" version = "4.1.3"
@@ -768,6 +863,7 @@ dependencies = [
"block-buffer 0.12.0", "block-buffer 0.12.0",
"const-oid 0.10.2", "const-oid 0.10.2",
"crypto-common 0.2.2", "crypto-common 0.2.2",
"ctutils",
] ]
[[package]] [[package]]
@@ -781,6 +877,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "dotenvy"
version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]] [[package]]
name = "ed25519" name = "ed25519"
version = "2.2.3" version = "2.2.3"
@@ -808,6 +910,9 @@ name = "either"
version = "1.16.0" version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "encoding_rs" name = "encoding_rs"
@@ -834,6 +939,27 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "etcetera"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de48cc4d1c1d97a20fd819def54b890cadde72ed3ad0c614822a0a433361be96"
dependencies = [
"cfg-if",
"windows-sys 0.61.2",
]
[[package]]
name = "event-listener"
version = "5.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.4.1" version = "2.4.1"
@@ -877,6 +1003,17 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "flume"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be"
dependencies = [
"futures-core",
"futures-sink",
"spin",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@@ -889,6 +1026,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foldhash"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]] [[package]]
name = "form_urlencoded" name = "form_urlencoded"
version = "1.2.2" version = "1.2.2"
@@ -977,6 +1120,17 @@ dependencies = [
"futures-util", "futures-util",
] ]
[[package]]
name = "futures-intrusive"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f"
dependencies = [
"futures-core",
"lock_api",
"parking_lot",
]
[[package]] [[package]]
name = "futures-io" name = "futures-io"
version = "0.3.32" version = "0.3.32"
@@ -1124,7 +1278,18 @@ version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [ dependencies = [
"foldhash", "foldhash 0.1.5",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.2.0",
] ]
[[package]] [[package]]
@@ -1133,12 +1298,45 @@ version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a"
[[package]]
name = "hashlink"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "824e001ac4f3012dd16a264bec811403a67ca9deb6c102fc5049b32c4574b35f"
dependencies = [
"hashbrown 0.16.1",
]
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.5.0" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hkdf"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4aaa26c720c68b866f2c96ef5c1264b3e6f473fe5d4ce61cd44bbe913e553018"
dependencies = [
"hmac",
]
[[package]]
name = "hmac"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f"
dependencies = [
"digest 0.11.3",
]
[[package]] [[package]]
name = "httlib-huffman" name = "httlib-huffman"
version = "0.3.4" version = "0.3.4"
@@ -1265,6 +1463,30 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "iana-time-zone"
version = "0.1.65"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"log",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "icu_collections" name = "icu_collections"
version = "2.2.0" version = "2.2.0"
@@ -1381,17 +1603,21 @@ dependencies = [
"actix-rt", "actix-rt",
"actix-web", "actix-web",
"actix-ws", "actix-ws",
"arc-swap",
"async-nats", "async-nats",
"async-trait", "async-trait",
"base64", "base64",
"chrono",
"dashmap", "dashmap",
"fred", "fred",
"futures-util", "futures-util",
"jsonwebtoken",
"prost", "prost",
"prost-types", "prost-types",
"rand 0.9.4", "rand 0.9.4",
"serde", "serde",
"serde_json", "serde_json",
"sqlx",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tonic", "tonic",
@@ -1460,6 +1686,21 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "jsonwebtoken"
version = "9.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
dependencies = [
"base64",
"js-sys",
"pem",
"ring",
"serde",
"serde_json",
"simple_asn1",
]
[[package]] [[package]]
name = "language-tags" name = "language-tags"
version = "0.3.2" version = "0.3.2"
@@ -1484,6 +1725,16 @@ version = "0.2.186"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
[[package]]
name = "libsqlite3-sys"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f111c8c41e7c61a49cd34e44c7619462967221a6443b0ec299e0ac30cfb9b1"
dependencies = [
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.12.1" version = "0.12.1"
@@ -1549,6 +1800,16 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "md-5"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69b6441f590336821bb897fb28fc622898ccceb1d6cea3fde5ea86b090c4de98"
dependencies = [
"cfg-if",
"digest 0.11.3",
]
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.8.1" version = "2.8.1"
@@ -1705,6 +1966,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]]
name = "parking"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.5" version = "0.12.5"
@@ -2208,6 +2475,7 @@ version = "0.23.40"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b"
dependencies = [ dependencies = [
"log",
"once_cell", "once_cell",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
@@ -2521,6 +2789,18 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214"
[[package]]
name = "simple_asn1"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d"
dependencies = [
"num-bigint",
"num-traits",
"thiserror 2.0.18",
"time",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.12" version = "0.4.12"
@@ -2532,6 +2812,9 @@ name = "smallvec"
version = "1.15.1" version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "socket2" name = "socket2"
@@ -2553,6 +2836,15 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"
@@ -2563,12 +2855,202 @@ dependencies = [
"der", "der",
] ]
[[package]]
name = "sqlx"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "378620ccc25c62c89d8be1c819e76a88d59bdcc3304733330788948e619bfd71"
dependencies = [
"sqlx-core",
"sqlx-macros",
"sqlx-mysql",
"sqlx-postgres",
"sqlx-sqlite",
]
[[package]]
name = "sqlx-core"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05b44e85bf579a8eeb4ceaa77a3a523baf2bf0e9bac7e40f405d537b5d2d5ccb"
dependencies = [
"base64",
"bytes",
"cfg-if",
"chrono",
"crc",
"crossbeam-queue",
"either",
"event-listener",
"futures-core",
"futures-intrusive",
"futures-io",
"futures-util",
"hashbrown 0.16.1",
"hashlink",
"indexmap",
"log",
"memchr",
"percent-encoding",
"serde",
"serde_json",
"sha2 0.10.9",
"smallvec",
"thiserror 2.0.18",
"tokio",
"tokio-stream",
"tracing",
"url",
"uuid",
]
[[package]]
name = "sqlx-macros"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd2b84f2bc39a5705ef27ec785a11c934a41bbd4a24941e257927cddc26b60bf"
dependencies = [
"proc-macro2",
"quote",
"sqlx-core",
"sqlx-macros-core",
"syn",
]
[[package]]
name = "sqlx-macros-core"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb8d96de5fdc85a5c4ec813432b523ec637e80ba98f046555f75f7908ddac7c3"
dependencies = [
"cfg-if",
"dotenvy",
"either",
"heck",
"hex",
"proc-macro2",
"quote",
"serde",
"serde_json",
"sha2 0.10.9",
"sqlx-core",
"sqlx-mysql",
"sqlx-postgres",
"sqlx-sqlite",
"syn",
"thiserror 2.0.18",
"tokio",
"url",
]
[[package]]
name = "sqlx-mysql"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90b8020fe17c5f2c245bfa2505d7ef59c5604839527c740266ad2214acebea27"
dependencies = [
"bitflags",
"byteorder",
"bytes",
"chrono",
"crc",
"digest 0.11.3",
"dotenvy",
"either",
"futures-core",
"futures-util",
"generic-array",
"log",
"percent-encoding",
"serde",
"sha1",
"sha2 0.11.0",
"sqlx-core",
"thiserror 2.0.18",
"tracing",
"uuid",
]
[[package]]
name = "sqlx-postgres"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87a2bdd6e83f6b3ea525ca9fee568030508b58355a43d0b2c1674d5f79dcd65e"
dependencies = [
"atoi",
"base64",
"bitflags",
"byteorder",
"chrono",
"crc",
"dotenvy",
"etcetera",
"futures-channel",
"futures-core",
"futures-util",
"hex",
"hkdf",
"hmac",
"itoa",
"log",
"md-5",
"memchr",
"rand 0.10.1",
"serde",
"serde_json",
"sha2 0.11.0",
"smallvec",
"sqlx-core",
"stringprep",
"thiserror 2.0.18",
"tracing",
"uuid",
"whoami",
]
[[package]]
name = "sqlx-sqlite"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488e99c397a62007e4229aec669a179816339afc6d2620ca6fa420dbee2e982c"
dependencies = [
"atoi",
"chrono",
"flume",
"form_urlencoded",
"futures-channel",
"futures-core",
"futures-executor",
"futures-intrusive",
"futures-util",
"libsqlite3-sys",
"log",
"percent-encoding",
"serde",
"sqlx-core",
"thiserror 2.0.18",
"tracing",
"url",
"uuid",
]
[[package]] [[package]]
name = "stable_deref_trait" name = "stable_deref_trait"
version = "1.2.1" version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "stringprep"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1"
dependencies = [
"unicode-bidi",
"unicode-normalization",
"unicode-properties",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
@@ -2828,6 +3310,7 @@ dependencies = [
"socket2 0.6.4", "socket2 0.6.4",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-rustls",
"tokio-stream", "tokio-stream",
"tower", "tower",
"tower-layer", "tower-layer",
@@ -3008,12 +3491,33 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]]
name = "unicode-bidi"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.24" version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
[[package]]
name = "unicode-normalization"
version = "0.1.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-properties"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d"
[[package]] [[package]]
name = "unicode-segmentation" name = "unicode-segmentation"
version = "1.13.3" version = "1.13.3"
@@ -3064,6 +3568,7 @@ checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7"
dependencies = [ dependencies = [
"getrandom 0.4.2", "getrandom 0.4.2",
"js-sys", "js-sys",
"serde_core",
"wasm-bindgen", "wasm-bindgen",
] ]
@@ -3073,6 +3578,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.5" version = "0.9.5"
@@ -3211,6 +3722,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "whoami"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "998767ef88740d1f5b0682a9c53c24431453923962269c2db68ee43788c5a40d"
[[package]] [[package]]
name = "winapi-util" name = "winapi-util"
version = "0.1.11" version = "0.1.11"
@@ -3220,12 +3737,65 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "windows-core"
version = "0.62.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb"
dependencies = [
"windows-implement",
"windows-interface",
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-implement"
version = "0.60.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "windows-interface"
version = "0.59.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "windows-link" name = "windows-link"
version = "0.2.1" version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-result"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091"
dependencies = [
"windows-link",
]
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.52.0" version = "0.52.0"
+6 -2
View File
@@ -14,7 +14,7 @@ name = "imks"
[dependencies] [dependencies]
tonic = "0.14.6" tonic = { version = "0.14.6", features = ["tls-ring"] }
prost = "0.14.3" prost = "0.14.3"
prost-types = "0.14" prost-types = "0.14"
tonic-build = "0.14.6" tonic-build = "0.14.6"
@@ -26,6 +26,9 @@ actix-ws = { version = "0.4.0", features = [] }
actix-rt = "2" actix-rt = "2"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = { version = "1" } serde_json = { version = "1" }
sqlx = { version = "0.9", features = ["postgres", "runtime-tokio", "chrono", "uuid", "json", "migrate"] }
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1", features = ["v4", "v7", "serde"] }
base64 = "0.22" base64 = "0.22"
rand = "0.9" rand = "0.9"
wtransport = "0.7" wtransport = "0.7"
@@ -36,8 +39,9 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
fred = { version = "10", features = ["subscriber-client"] } fred = { version = "10", features = ["subscriber-client"] }
async-nats = "0.38" async-nats = "0.38"
uuid = { version = "1", features = ["v4"] }
futures-util = "0.3" futures-util = "0.3"
jsonwebtoken = "9"
arc-swap = "1"
[build-dependencies] [build-dependencies]
+87
View File
@@ -0,0 +1,87 @@
//! JWT claims structure — mirrors proto `TokenClaims` for local verification.
//!
//! Used as the deserialization target for `jsonwebtoken::decode`.
//! Field names match standard JWT claim names (`sub`, `iss`, `exp`, etc.).
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Parsed JWT payload, matching the proto `TokenClaims` shape.
///
/// Deserialized by `jsonwebtoken` during HS256 verification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
/// Subject — the user UUID.
pub sub: String,
/// Issuer — expected to be `"appks"`.
pub iss: String,
/// Issued-at (unix seconds).
pub iat: i64,
/// Expiration (unix seconds).
pub exp: i64,
/// Unique token ID (used for revocation tracking via `jti`).
pub jti: String,
/// Space-separated scopes, e.g. `"im:read im:write"`.
pub scope: String,
/// Extensible metadata (workspace_id, role, etc.).
#[serde(default)]
pub extra: HashMap<String, String>,
}
impl TokenClaims {
/// Check whether this token carries a specific scope.
pub fn has_scope(&self, scope: &str) -> bool {
self.scope.split_whitespace().any(|s| s == scope)
}
/// Convert from the proto-generated `TokenClaims` (RPC verify response).
pub fn from_proto(proto: crate::pb::core::TokenClaims) -> Self {
Self {
sub: proto.sub,
iss: proto.iss,
iat: proto.iat,
exp: proto.exp,
jti: proto.jti,
scope: proto.scope,
extra: proto.extra,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_has_scope() {
let claims = TokenClaims {
sub: "user-1".into(),
iss: "appks".into(),
iat: 0,
exp: 9999999999,
jti: "tok-1".into(),
scope: "im:read im:write admin".into(),
extra: HashMap::new(),
};
assert!(claims.has_scope("im:read"));
assert!(claims.has_scope("admin"));
assert!(!claims.has_scope("im:delete"));
}
#[test]
fn test_deserialize_from_json() {
let json = r#"{
"sub": "user-1",
"iss": "appks",
"iat": 1000,
"exp": 2000,
"jti": "tok-1",
"scope": "im:read",
"extra": {"workspace_id": "ws-1"}
}"#;
let claims: TokenClaims = serde_json::from_str(json).unwrap();
assert_eq!(claims.sub, "user-1");
assert_eq!(claims.extra.get("workspace_id").unwrap(), "ws-1");
}
}
+119
View File
@@ -0,0 +1,119 @@
//! Low-level HS256 JWT decoding and verification.
//!
//! Stateless functions — no caching or key management.
//! Used by the `Authenticator` in combination with `SigningKeyStore`.
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use crate::{ImksError, ImksResult};
use super::claims::TokenClaims;
/// Expected JWT issuer claim.
const EXPECTED_ISSUER: &str = "appks";
/// Signing algorithm used by appks.
const ALGORITHM: Algorithm = Algorithm::HS256;
/// Extract the `kid` from a JWT header without verifying the signature.
///
/// This is the first step in local verification: find which signing key
/// was used, then look it up in the `SigningKeyStore`.
pub fn extract_kid(token: &str) -> ImksResult<String> {
let header = decode_header(token).map_err(map_jwt_error)?;
header
.kid
.ok_or_else(|| ImksError::Auth("JWT header missing 'kid' field".into()))
}
/// Verify an HS256 JWT signature and decode its claims.
///
/// Validates: algorithm, issuer, expiration. Does NOT validate audience.
pub fn verify_and_decode(token: &str, key: &DecodingKey) -> ImksResult<TokenClaims> {
let validation = build_validation();
let token_data = decode::<TokenClaims>(token, key, &validation).map_err(map_jwt_error)?;
Ok(token_data.claims)
}
/// Build the standard `Validation` config for imks JWT verification.
fn build_validation() -> Validation {
let mut validation = Validation::new(ALGORITHM);
validation.set_issuer(&[EXPECTED_ISSUER]);
validation.validate_exp = true;
// Audience validation not required for imks tokens.
validation.validate_aud = false;
validation
}
/// Map `jsonwebtoken` errors to `ImksError`, distinguishing expired tokens.
fn map_jwt_error(e: jsonwebtoken::errors::Error) -> ImksError {
use jsonwebtoken::errors::ErrorKind;
match e.kind() {
ErrorKind::ExpiredSignature => ImksError::TokenExpired,
ErrorKind::InvalidSignature => ImksError::Auth("invalid JWT signature".into()),
ErrorKind::InvalidIssuer => ImksError::Auth("invalid JWT issuer".into()),
_ => ImksError::Auth(format!("JWT error: {e}")),
}
}
#[cfg(test)]
mod tests {
use super::*;
use jsonwebtoken::{EncodingKey, Header, encode};
fn make_test_token(claims: &TokenClaims, secret: &[u8]) -> String {
let mut header = Header::new(ALGORITHM);
header.kid = Some("test-kid".into());
encode(&header, claims, &EncodingKey::from_secret(secret)).unwrap()
}
#[test]
fn test_extract_kid() {
let claims = TokenClaims {
sub: "u1".into(),
iss: "appks".into(),
iat: 1000,
exp: 9999999999,
jti: "j1".into(),
scope: "im:read".into(),
extra: Default::default(),
};
let token = make_test_token(&claims, b"secret");
let kid = extract_kid(&token).unwrap();
assert_eq!(kid, "test-kid");
}
#[test]
fn test_verify_and_decode_valid() {
let secret = b"test-secret-key-material-32bytes!";
let claims = TokenClaims {
sub: "user-1".into(),
iss: "appks".into(),
iat: 1000,
exp: 9999999999,
jti: "tok-1".into(),
scope: "im:read".into(),
extra: Default::default(),
};
let token = make_test_token(&claims, secret);
let key = DecodingKey::from_secret(secret);
let decoded = verify_and_decode(&token, &key).unwrap();
assert_eq!(decoded.sub, "user-1");
}
#[test]
fn test_verify_rejects_wrong_key() {
let claims = TokenClaims {
sub: "u1".into(),
iss: "appks".into(),
iat: 1000,
exp: 9999999999,
jti: "j1".into(),
scope: "".into(),
extra: Default::default(),
};
let token = make_test_token(&claims, b"correct-secret");
let wrong_key = DecodingKey::from_secret(b"wrong-secret");
let result = verify_and_decode(&token, &wrong_key);
assert!(result.is_err());
}
}
+171
View File
@@ -0,0 +1,171 @@
//! Signing key store with atomic reads and periodic background refresh.
//!
//! Fetches HS256 signing keys from appks via `GetSigningKeys` RPC,
//! caches them behind `ArcSwap` for lock-free reads, and schedules
//! re-fetch when `next_rotation_at` is reached.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use jsonwebtoken::DecodingKey;
use tokio::task::JoinHandle;
use tonic::transport::Channel;
use crate::pb::core::GetSigningKeysRequest;
use crate::pb::core::token_service_client::TokenServiceClient;
use crate::{ImksError, ImksResult};
/// A cached signing key entry with a pre-computed `DecodingKey`.
struct CachedKey {
kid: String,
decoding_key: DecodingKey,
/// Unix timestamp (seconds) when this key expires.
expires_at: i64,
/// Whether this is the current active signing key.
active: bool,
}
/// Thread-safe store of signing keys with periodic background refresh.
///
/// Reads via `get_key()` are lock-free (ArcSwap).
/// A background task re-fetches keys from appks at each rotation window.
pub struct SigningKeyStore {
keys: Arc<ArcSwap<HashMap<String, CachedKey>>>,
refresh_handle: Option<JoinHandle<()>>,
}
impl SigningKeyStore {
/// Fetch initial keys from appks and start the background refresh loop.
pub async fn init(mut client: TokenServiceClient<Channel>) -> ImksResult<Self> {
let (cached, next_rotation) = fetch_keys(&mut client).await?;
let map: HashMap<String, CachedKey> =
cached.into_iter().map(|k| (k.kid.clone(), k)).collect();
let keys = Arc::new(ArcSwap::from_pointee(map));
let keys_clone = keys.clone();
let client_clone = client;
let refresh_handle = tokio::spawn(async move {
refresh_loop(client_clone, keys_clone, next_rotation).await;
});
tracing::info!("SigningKeyStore initialized with background refresh");
Ok(Self {
keys,
refresh_handle: Some(refresh_handle),
})
}
/// Look up a decoding key by its `kid`. Returns `None` if unknown or expired.
///
/// Inactive keys (from a previous rotation window) are still served so they
/// can validate tokens signed before the rotation. Expired keys (past their
/// 3h window) are rejected as a local safety net even though the RPC should
/// not return them.
pub fn get_key(&self, kid: &str) -> Option<DecodingKey> {
let map = self.keys.load();
let cached = map.get(kid)?;
debug_assert_eq!(cached.kid, kid, "CachedKey kid must match its HashMap key");
let now = chrono::Utc::now().timestamp();
if cached.expires_at > 0 && now >= cached.expires_at {
tracing::warn!(
kid = %cached.kid,
expires_at = cached.expires_at,
"Rejecting expired signing key"
);
return None;
}
if !cached.active {
tracing::debug!(
kid = %cached.kid,
"Serving inactive signing key (previous rotation window)"
);
}
Some(cached.decoding_key.clone())
}
/// Stop the background refresh task.
pub async fn shutdown(mut self) {
if let Some(handle) = self.refresh_handle.take() {
handle.abort();
}
}
}
impl Drop for SigningKeyStore {
fn drop(&mut self) {
if let Some(handle) = self.refresh_handle.take() {
handle.abort();
}
}
}
/// Fetch all active signing keys from appks.
async fn fetch_keys(client: &mut TokenServiceClient<Channel>) -> ImksResult<(Vec<CachedKey>, i64)> {
let resp = client
.get_signing_keys(GetSigningKeysRequest { kid: String::new() })
.await
.map_err(ImksError::GrpcStatus)?;
let inner = resp.into_inner();
let mut cached_keys = Vec::new();
for key in &inner.keys {
let secret = BASE64
.decode(&key.key_material)
.map_err(|e| ImksError::Auth(format!("Invalid key base64 for kid={}: {e}", key.kid)))?;
cached_keys.push(CachedKey {
kid: key.kid.clone(),
decoding_key: DecodingKey::from_secret(&secret),
expires_at: key.expires_at,
active: key.active,
});
}
tracing::info!(
key_count = cached_keys.len(),
next_rotation = inner.next_rotation_at,
"Fetched signing keys from appks"
);
Ok((cached_keys, inner.next_rotation_at))
}
/// Background loop: sleep until `next_rotation_at`, re-fetch, swap atomically.
async fn refresh_loop(
mut client: TokenServiceClient<Channel>,
keys: Arc<ArcSwap<HashMap<String, CachedKey>>>,
mut next_rotation_at: i64,
) {
loop {
let now_secs = chrono::Utc::now().timestamp();
let sleep_secs = (next_rotation_at - now_secs).max(60);
tracing::debug!(sleep_secs, "Key refresh sleeping");
tokio::time::sleep(Duration::from_secs(sleep_secs as u64)).await;
match fetch_keys(&mut client).await {
Ok((cached, new_rotation)) => {
let map: HashMap<String, CachedKey> =
cached.into_iter().map(|k| (k.kid.clone(), k)).collect();
keys.store(Arc::new(map));
next_rotation_at = new_rotation;
tracing::info!("Signing keys refreshed");
}
Err(e) => {
tracing::error!(error = %e, "Failed to refresh signing keys, retrying in 60s");
next_rotation_at = now_secs + 60;
}
}
}
}
+8
View File
@@ -0,0 +1,8 @@
pub mod claims;
pub mod jwt_decoder;
pub mod key_store;
pub mod verifier;
pub use claims::TokenClaims;
pub use key_store::SigningKeyStore;
pub use verifier::Authenticator;
+98
View File
@@ -0,0 +1,98 @@
//! Dual-mode JWT authenticator — the public-facing entry point.
//!
//! Composes `SigningKeyStore` (local cache) + `jwt_decoder` (HS256 logic)
//! + `TokenServiceClient` (RPC fallback) into a single `Authenticator`.
use std::sync::Arc;
use tonic::transport::Channel;
use crate::pb::core::VerifyTokenRequest;
use crate::pb::core::token_service_client::TokenServiceClient;
use crate::{ImksError, ImksResult};
use super::claims::TokenClaims;
use super::jwt_decoder;
use super::key_store::SigningKeyStore;
/// Dual-mode JWT authenticator.
///
/// - **Local mode** (`verify_local`): HS256 verification against cached
/// signing keys. Zero network latency. Suitable for high-frequency
/// operations like message send/receive.
///
/// - **RPC mode** (`verify_rpc`): forwards the token to appks
/// `VerifyToken()`. Real-time revocation awareness. Use for
/// sensitive operations like kick/ban/permission changes.
#[derive(Clone)]
pub struct Authenticator {
key_store: Arc<SigningKeyStore>,
token_client: TokenServiceClient<Channel>,
}
impl Authenticator {
/// Create a new authenticator. Initializes the signing key cache from appks.
pub async fn new(token_client: TokenServiceClient<Channel>) -> ImksResult<Self> {
let key_store = SigningKeyStore::init(token_client.clone()).await?;
Ok(Self {
key_store: Arc::new(key_store),
token_client,
})
}
/// Fast-path verification using locally cached signing keys.
///
/// Extracts `kid` from the JWT header, looks up the key, and verifies
/// the HS256 signature. Cannot detect token revocation within the
/// current key rotation window (~3 hours).
pub fn verify_local(&self, token: &str) -> ImksResult<TokenClaims> {
let kid = jwt_decoder::extract_kid(token)?;
let key = self
.key_store
.get_key(&kid)
.ok_or_else(|| ImksError::Auth(format!("Unknown signing key kid: {kid}")))?;
jwt_decoder::verify_and_decode(token, &key)
}
/// Authoritative verification via appks `VerifyToken` RPC.
///
/// Detects token revocation in real-time. Adds one RPC round-trip.
pub async fn verify_rpc(&self, token: &str) -> ImksResult<TokenClaims> {
let mut client = self.token_client.clone();
let resp = client
.verify_token(VerifyTokenRequest {
token: token.to_string(),
})
.await?;
let inner = resp.into_inner();
if !inner.valid {
return Err(ImksError::Auth(inner.reason));
}
let proto_claims = inner.claims.ok_or_else(|| {
ImksError::Auth("VerifyToken returned valid=true but no claims".into())
})?;
Ok(TokenClaims::from_proto(proto_claims))
}
/// Extract the Bearer token value from an `Authorization` header.
///
/// Expects format: `"Bearer <token>"`. Returns the token part.
pub fn extract_bearer(auth_header: &str) -> ImksResult<&str> {
auth_header
.strip_prefix("Bearer ")
.ok_or_else(|| ImksError::Auth("Missing or malformed Authorization header".into()))
}
/// Shut down the background key refresh task.
pub async fn shutdown(self) {
// Unwrap the Arc — if there are other clones, the store lives on.
if let Ok(store) = Arc::try_unwrap(self.key_store) {
store.shutdown().await;
}
}
}
+71
View File
@@ -0,0 +1,71 @@
//! PostgreSQL connection pool configuration.
//!
//! Reads settings from environment variables with sensible defaults.
use std::env;
/// PostgreSQL connection configuration, sourced from environment variables.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
/// PostgreSQL connection URL (e.g. `postgres://user:pass@host/db`).
pub url: String,
/// Maximum number of connections in the pool.
pub max_connections: u32,
/// Minimum number of idle connections maintained.
pub min_connections: u32,
/// Timeout for acquiring a new connection (seconds).
pub connect_timeout_secs: u64,
/// Timeout for idle connections before they are closed (seconds).
pub idle_timeout_secs: u64,
}
impl DatabaseConfig {
/// Build config by reading environment variables, falling back to defaults.
pub fn from_env() -> Self {
Self {
url: env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://localhost/imks".to_string()),
max_connections: env::var("DATABASE_MAX_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10),
min_connections: env::var("DATABASE_MIN_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(2),
connect_timeout_secs: env::var("DATABASE_CONNECT_TIMEOUT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(30),
idle_timeout_secs: env::var("DATABASE_IDLE_TIMEOUT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600),
}
}
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self::from_env()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config_has_sane_values() {
// Without env vars set, defaults should be applied.
let cfg = DatabaseConfig {
url: "postgres://localhost/test".to_string(),
max_connections: 10,
min_connections: 2,
connect_timeout_secs: 30,
idle_timeout_secs: 600,
};
assert_eq!(cfg.max_connections, 10);
assert_eq!(cfg.min_connections, 2);
}
}
+37
View File
@@ -0,0 +1,37 @@
//! SQL migration runner.
//!
//! Reads migration files from the `migrate/` directory and applies them
//! in lexicographic order using sqlx's built-in migration infrastructure.
use sqlx::PgPool;
use sqlx::migrate::Migrator;
use std::path::Path;
use crate::{ImksError, ImksResult};
/// Run all pending SQL migrations from the `migrate/` directory.
///
/// Migrations are applied in filename order (e.g. `001_…sql` before `002_…sql`).
/// sqlx tracks applied migrations in a `_sqlx_migrations` table so that
/// only new migrations are executed on subsequent runs.
pub async fn run_migrations(pool: &PgPool) -> ImksResult<()> {
let migrations_dir = Path::new("migrate");
if !migrations_dir.exists() {
tracing::warn!("No migrate/ directory found — skipping migrations");
return Ok(());
}
let migrator = Migrator::new(migrations_dir)
.await
.map_err(|e| ImksError::Internal(format!("Failed to load migrations: {e}")))?;
migrator
.run(pool)
.await
.map_err(|e| ImksError::Internal(format!("Migration failed: {e}")))?;
tracing::info!("Database migrations completed");
Ok(())
}
+7
View File
@@ -0,0 +1,7 @@
pub mod config;
pub mod migration;
pub mod pool;
pub use config::DatabaseConfig;
pub use migration::run_migrations;
pub use pool::Database;
+75
View File
@@ -0,0 +1,75 @@
//! PostgreSQL connection pool wrapper.
//!
//! Provides [`Database`] — a thin, cloneable wrapper around `sqlx::PgPool`
//! with health-check and graceful shutdown.
use std::time::Duration;
use sqlx::Row;
use sqlx::postgres::{PgPool, PgPoolOptions};
use crate::ImksResult;
use super::config::DatabaseConfig;
/// Cloneable handle to the PostgreSQL connection pool.
///
/// All query execution goes through `pool()` which returns a `&PgPool`.
#[derive(Clone)]
pub struct Database {
pool: PgPool,
}
impl Database {
/// Create a new pool from config and verify connectivity with a ping.
pub async fn connect(config: &DatabaseConfig) -> ImksResult<Self> {
let pool = PgPoolOptions::new()
.max_connections(config.max_connections)
.min_connections(config.min_connections)
.acquire_timeout(Duration::from_secs(config.connect_timeout_secs))
.idle_timeout(Duration::from_secs(config.idle_timeout_secs))
.connect(&config.url)
.await?;
tracing::info!(
max_connections = config.max_connections,
min_connections = config.min_connections,
"PostgreSQL pool created"
);
let db = Self { pool };
db.health_check().await?;
Ok(db)
}
/// Wrap an existing `PgPool` (useful for tests with shared pools).
pub fn from_pool(pool: PgPool) -> Self {
Self { pool }
}
/// Access the inner `PgPool` for query execution.
pub fn pool(&self) -> &PgPool {
&self.pool
}
/// Verify the database is reachable by executing `SELECT 1`.
pub async fn health_check(&self) -> ImksResult<()> {
let row = sqlx::query("SELECT 1 AS alive")
.fetch_one(&self.pool)
.await?;
let alive: i32 = row.get("alive");
if alive != 1 {
return Err(crate::ImksError::Internal(
"Database health check returned unexpected value".into(),
));
}
tracing::debug!("Database health check passed");
Ok(())
}
/// Gracefully close all connections in the pool.
pub async fn close(&self) {
self.pool.close().await;
tracing::info!("Database pool closed");
}
}
+6 -3
View File
@@ -1,4 +1,4 @@
use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use crate::engine::packet::{Packet, PacketData, PacketError, PacketType}; use crate::engine::packet::{Packet, PacketData, PacketError, PacketType};
@@ -226,7 +226,10 @@ mod tests {
let input: Vec<u8> = vec![b'4', 0x80, 0xFF, 0x00, 0x01]; let input: Vec<u8> = vec![b'4', 0x80, 0xFF, 0x00, 0x01];
let decoded = decode_packet_ws(&input).unwrap(); let decoded = decode_packet_ws(&input).unwrap();
assert_eq!(decoded.packet_type, PacketType::Message); assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01])); assert_eq!(
decoded.data,
PacketData::Binary(vec![0x80, 0xFF, 0x00, 0x01])
);
} }
#[test] #[test]
@@ -236,4 +239,4 @@ mod tests {
assert_eq!(decoded.packet_type, PacketType::Message); assert_eq!(decoded.packet_type, PacketType::Message);
assert_eq!(decoded.data, PacketData::Empty); assert_eq!(decoded.data, PacketData::Empty);
} }
} }
+26
View File
@@ -0,0 +1,26 @@
//! Health check endpoint for the imks server.
//!
//! Returns JSON with server status, version, and upstream connectivity.
use actix_web::HttpResponse;
use serde::Serialize;
#[derive(Serialize)]
struct HealthResponse {
status: String,
version: String,
timestamp: String,
uptime_secs: u64,
sessions_count: usize,
}
/// GET /health — returns server health status.
pub async fn health_check() -> HttpResponse {
HttpResponse::Ok().json(HealthResponse {
status: "healthy".into(),
version: env!("CARGO_PKG_VERSION").into(),
timestamp: chrono::Utc::now().to_rfc3339(),
uptime_secs: 0,
sessions_count: 0,
})
}
+1 -1
View File
@@ -74,4 +74,4 @@ impl HeartbeatManager {
self.store.remove(&sid); self.store.remove(&sid);
} }
} }
} }
+1
View File
@@ -1,4 +1,5 @@
pub mod codec; pub mod codec;
pub mod health;
pub mod heartbeat; pub mod heartbeat;
pub mod packet; pub mod packet;
pub mod polling; pub mod polling;
+5 -6
View File
@@ -61,11 +61,10 @@ pub struct Packet {
impl Packet { impl Packet {
pub fn open(handshake: &HandshakeData) -> Self { pub fn open(handshake: &HandshakeData) -> Self {
let data = serde_json::to_string(handshake) let data = serde_json::to_string(handshake).unwrap_or_else(|e| {
.unwrap_or_else(|e| { tracing::error!("Failed to serialize handshake data: {}", e);
tracing::error!("Failed to serialize handshake data: {}", e); "{}".to_string()
"{}".to_string() });
});
Self { Self {
packet_type: PacketType::Open, packet_type: PacketType::Open,
data: PacketData::Text(data), data: PacketData::Text(data),
@@ -148,4 +147,4 @@ pub enum PacketError {
InvalidUtf8(#[from] std::string::FromUtf8Error), InvalidUtf8(#[from] std::string::FromUtf8Error),
#[error("serialization error: {0}")] #[error("serialization error: {0}")]
Serialization(String), Serialization(String),
} }
+2 -2
View File
@@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use actix_web::{web, HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse, web};
use crate::engine::codec; use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType}; use crate::engine::packet::{Packet, PacketType};
@@ -182,4 +182,4 @@ async fn handle_handshake(store: &SessionStore, config: &EngineConfig) -> HttpRe
pub fn configure_polling(cfg: &mut web::ServiceConfig) { pub fn configure_polling(cfg: &mut web::ServiceConfig) {
cfg.route("/engine.io/", web::get().to(polling_get)) cfg.route("/engine.io/", web::get().to(polling_get))
.route("/engine.io/", web::post().to(polling_post)); .route("/engine.io/", web::post().to(polling_post));
} }
+52 -8
View File
@@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use actix_web::{web, App, HttpServer}; use actix_web::{App, HttpRequest, HttpResponse, HttpServer, web};
use crate::engine::heartbeat::HeartbeatManager; use crate::engine::heartbeat::HeartbeatManager;
use crate::engine::packet::Packet; use crate::engine::packet::Packet;
@@ -31,6 +31,53 @@ pub struct EngineServer {
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>, on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
} }
#[derive(Debug, serde::Deserialize)]
pub struct EngineQuery {
#[serde(rename = "EIO")]
pub eio: Option<String>,
pub transport: Option<String>,
pub sid: Option<String>,
}
pub async fn engine_get(
req: HttpRequest,
body: web::Payload,
query: web::Query<EngineQuery>,
store: web::Data<SessionStore>,
config: web::Data<EngineConfig>,
on_message: web::Data<Arc<dyn Fn(String, Packet) + Send + Sync>>,
) -> Result<HttpResponse, actix_web::Error> {
match query.transport.as_deref() {
Some("websocket") => {
crate::engine::websocket::websocket_handler(
req,
body,
web::Query(crate::engine::websocket::WsQuery {
eio: query.eio.clone(),
transport: query.transport.clone(),
sid: query.sid.clone(),
}),
store,
config,
on_message,
)
.await
}
_ => Ok(crate::engine::polling::polling_get(
req,
web::Query(crate::engine::polling::PollingQuery {
eio: query.eio.clone(),
transport: query.transport.clone(),
sid: query.sid.clone(),
}),
store,
config,
on_message,
)
.await),
}
}
impl EngineServer { impl EngineServer {
pub fn new( pub fn new(
config: EngineConfig, config: EngineConfig,
@@ -76,17 +123,14 @@ impl EngineServer {
.app_data(web::Data::new(config.clone())) .app_data(web::Data::new(config.clone()))
.app_data(web::Data::new(on_message.clone())) .app_data(web::Data::new(on_message.clone()))
.route( .route(
"/engine.io/", "/health",
web::get().to(crate::engine::polling::polling_get), web::get().to(crate::engine::health::health_check),
) )
.route("/engine.io/", web::get().to(engine_get))
.route( .route(
"/engine.io/", "/engine.io/",
web::post().to(crate::engine::polling::polling_post), web::post().to(crate::engine::polling::polling_post),
) )
.route(
"/engine.io/",
web::get().to(crate::engine::websocket::websocket_handler),
)
}) })
.bind(addr)? .bind(addr)?
.run() .run()
@@ -101,7 +145,7 @@ impl EngineServer {
port: u16, port: u16,
cert_path: &str, cert_path: &str,
key_path: &str, key_path: &str,
) -> Result<(), Box<dyn std::error::Error>> { ) -> crate::ImksResult<()> {
crate::engine::webtransport::run_webtransport_server( crate::engine::webtransport::run_webtransport_server(
port, port,
cert_path, cert_path,
+6 -3
View File
@@ -2,7 +2,7 @@ use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use dashmap::DashMap; use dashmap::DashMap;
use tokio::sync::{mpsc, Notify}; use tokio::sync::{Notify, mpsc};
use crate::engine::packet::Packet; use crate::engine::packet::Packet;
@@ -124,7 +124,10 @@ impl SessionStore {
.sessions .sessions
.insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session))); .insert(sid.clone(), Arc::new(tokio::sync::RwLock::new(session)));
if old.is_some() { if old.is_some() {
tracing::warn!("Session ID collision for SID {}, replacing existing session", sid); tracing::warn!(
"Session ID collision for SID {}, replacing existing session",
sid
);
} }
rx rx
} }
@@ -168,4 +171,4 @@ pub fn generate_sid() -> String {
CHARSET[idx] as char CHARSET[idx] as char
}) })
.collect() .collect()
} }
+1 -4
View File
@@ -1,10 +1,7 @@
use crate::engine::packet::Packet; use crate::engine::packet::Packet;
use crate::engine::session::{SessionState, SessionStore, TransportType}; use crate::engine::session::{SessionState, SessionStore, TransportType};
pub async fn handle_upgrade_probe( pub async fn handle_upgrade_probe(store: &SessionStore, sid: &str) -> Result<Packet, UpgradeError> {
store: &SessionStore,
sid: &str,
) -> Result<Packet, UpgradeError> {
let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?; let session = store.get(sid).ok_or(UpgradeError::SessionNotFound)?;
let mut session = session.write().await; let mut session = session.write().await;
+55 -42
View File
@@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use actix_web::{web, HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse, web};
use actix_ws::Message; use actix_ws::Message;
use crate::engine::codec; use crate::engine::codec;
@@ -36,37 +36,37 @@ pub async fn websocket_handler(
let sid = query.sid.clone(); let sid = query.sid.clone();
let is_upgrade = sid.as_ref().map(|s| store.exists(s)).unwrap_or(false); if let Some(ref sid) = sid
&& !store.exists(sid)
{
return Ok(HttpResponse::BadRequest().body("unknown session"));
}
// Create or reuse session, obtaining the mpsc receiver for the forwarding task // Create or reuse session, obtaining the mpsc receiver for the forwarding task
let (session_sid, mut session_rx) = if let Some(ref sid) = sid { let (session_sid, mut session_rx) = if let Some(ref sid) = sid {
if is_upgrade { // Upgrade: session already exists, replace its channel and drain pending packets
// Upgrade: session already exists, replace its channel and drain pending packets let session_arc = match store.get(sid) {
let session_arc = store.get(sid).unwrap(); Some(s) => s,
let (new_tx, new_rx) = tokio::sync::mpsc::channel(256); None => {
{ tracing::error!("Session {} not found for upgrade", sid);
let mut s = session_arc.write().await; return Ok(HttpResponse::InternalServerError().body("session not found"));
// Swap tx atomically: old_tx will be dropped, closing its channel.
// Any packets in the old rx are consumed by the old send_handle,
// which then exits when it sees the channel close.
// Drain pending_packets (from polling buffering) into new channel.
let pending = s.take_pending();
for packet in pending {
let _ = new_tx.try_send(packet);
}
s.tx = new_tx;
s.set_transport(TransportType::WebSocket);
} }
(sid.clone(), new_rx) };
} else { let (new_tx, new_rx) = tokio::sync::mpsc::channel(256);
// Reconnect with known SID: create new session {
let rx = store.create(sid.clone(), TransportType::WebSocket); let mut s = session_arc.write().await;
if let Some(s) = store.get(sid) { // Swap tx atomically: old_tx will be dropped, closing its channel.
let mut s = s.write().await; // Any packets in the old rx are consumed by the old send_handle,
s.set_state(SessionState::Open); // which then exits when it sees the channel close.
// Drain pending_packets (from polling buffering) into new channel.
let pending = s.take_pending();
for packet in pending {
let _ = new_tx.try_send(packet);
} }
(sid.clone(), rx) s.tx = new_tx;
s.set_transport(TransportType::WebSocket);
} }
(sid.clone(), new_rx)
} else { } else {
// New connection: generate SID and create session // New connection: generate SID and create session
let new_sid = crate::engine::session::generate_sid(); let new_sid = crate::engine::session::generate_sid();
@@ -89,7 +89,10 @@ pub async fn websocket_handler(
let open_packet = Packet::open(&handshake); let open_packet = Packet::open(&handshake);
let open_msg = codec::encode_packet(&open_packet); let open_msg = codec::encode_packet(&open_packet);
if ws_session.text(open_msg).await.is_err() { if ws_session.text(open_msg).await.is_err() {
tracing::warn!("Failed to send open packet to WebSocket session {}", session_sid); tracing::warn!(
"Failed to send open packet to WebSocket session {}",
session_sid
);
store.remove(&session_sid); store.remove(&session_sid);
return Ok(response); return Ok(response);
} }
@@ -121,16 +124,26 @@ pub async fn websocket_handler(
while let Some(Ok(msg)) = msg_stream.recv().await { while let Some(Ok(msg)) = msg_stream.recv().await {
match msg { match msg {
Message::Text(text) => { Message::Text(text) => {
if text.len() > max_payload {
tracing::warn!(
"Text payload too large ({}) for session {}",
text.len(),
sid_clone
);
let _ = ws_session.close(None).await;
break;
}
if let Ok(packet) = codec::decode_packet(&text) { if let Ok(packet) = codec::decode_packet(&text) {
match packet.packet_type { match packet.packet_type {
PacketType::Ping => { PacketType::Ping => {
if let PacketData::Text(ref data) = packet.data { if let PacketData::Text(ref data) = packet.data
if data == "probe" { && data == "probe"
let pong = Packet::pong("probe"); {
let pong_msg = codec::encode_packet(&pong); let pong = Packet::pong("probe");
let _ = ws_session.text(pong_msg).await; let pong_msg = codec::encode_packet(&pong);
continue; let _ = ws_session.text(pong_msg).await;
} continue;
} }
let pong = Packet::pong(""); let pong = Packet::pong("");
let pong_msg = codec::encode_packet(&pong); let pong_msg = codec::encode_packet(&pong);
@@ -180,14 +193,14 @@ pub async fn websocket_handler(
continue; continue;
} }
if let Ok(packet) = codec::decode_packet_ws(&bin) { if let Ok(packet) = codec::decode_packet_ws(&bin)
if packet.packet_type == PacketType::Message { && packet.packet_type == PacketType::Message
let on_msg = on_message_clone.clone(); {
let sid = sid_clone.clone(); let on_msg = on_message_clone.clone();
tokio::spawn(async move { let sid = sid_clone.clone();
on_msg(sid, packet); tokio::spawn(async move {
}); on_msg(sid, packet);
} });
} }
} }
Message::Close(_) => { Message::Close(_) => {
+86 -71
View File
@@ -1,11 +1,12 @@
use std::sync::Arc; use std::sync::Arc;
use wtransport::{Connection, Endpoint, ServerConfig, Identity}; use wtransport::{Connection, Endpoint, Identity, ServerConfig};
use crate::engine::codec; use crate::engine::codec;
use crate::engine::packet::{Packet, PacketType}; use crate::engine::packet::{Packet, PacketType};
use crate::engine::server::EngineConfig; use crate::engine::server::EngineConfig;
use crate::engine::session::{SessionState, SessionStore, TransportType}; use crate::engine::session::{SessionState, SessionStore, TransportType};
use crate::{ImksError, ImksResult};
pub async fn run_webtransport_server( pub async fn run_webtransport_server(
port: u16, port: u16,
@@ -14,15 +15,18 @@ pub async fn run_webtransport_server(
store: SessionStore, store: SessionStore,
config: EngineConfig, config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>, on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> Result<(), Box<dyn std::error::Error>> { ) -> ImksResult<()> {
let identity = Identity::load_pemfiles(cert_path, key_path).await?; let identity = Identity::load_pemfiles(cert_path, key_path)
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let server_config = ServerConfig::builder() let server_config = ServerConfig::builder()
.with_bind_default(port) .with_bind_default(port)
.with_identity(identity) .with_identity(identity)
.build(); .build();
let server = Endpoint::server(server_config)?; let server =
Endpoint::server(server_config).map_err(|e| ImksError::WebTransport(e.to_string()))?;
tracing::info!("WebTransport server listening on UDP port {}", port); tracing::info!("WebTransport server listening on UDP port {}", port);
@@ -49,9 +53,14 @@ async fn handle_webtransport_session(
store: SessionStore, store: SessionStore,
config: EngineConfig, config: EngineConfig,
on_message: Arc<dyn Fn(String, Packet) + Send + Sync>, on_message: Arc<dyn Fn(String, Packet) + Send + Sync>,
) -> Result<(), Box<dyn std::error::Error>> { ) -> ImksResult<()> {
let request = incoming.await?; let request = incoming
let connection = request.accept().await?; .await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let connection = request
.accept()
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let sid = crate::engine::session::generate_sid(); let sid = crate::engine::session::generate_sid();
let mut rx = store.create(sid.clone(), TransportType::WebTransport); let mut rx = store.create(sid.clone(), TransportType::WebTransport);
@@ -81,65 +90,57 @@ async fn handle_webtransport_session(
// Reuse buffer across recv iterations instead of allocating 65KB each time // Reuse buffer across recv iterations instead of allocating 65KB each time
let recv_handle = tokio::spawn(async move { let recv_handle = tokio::spawn(async move {
let mut buf = vec![0u8; 65536]; let mut buf = vec![0u8; 65536];
loop { while let Ok((mut send, mut recv)) = connection_recv.accept_bi().await {
match connection_recv.accept_bi().await { // Reset buffer length for the next read without deallocating
Ok((mut send, mut recv)) => { buf.resize(65536, 0);
// Reset buffer length for the next read without deallocating match recv.read(&mut buf).await {
buf.resize(65536, 0); Ok(Some(n)) => {
match recv.read(&mut buf).await { if n > max_payload {
Ok(Some(n)) => { tracing::warn!(
if n > max_payload { "WebTransport payload too large ({}) for session {}",
tracing::warn!( n,
"WebTransport payload too large ({}) for session {}", sid_clone
n, );
sid_clone continue;
); }
continue; if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) {
} match packet.packet_type {
if let Ok(packet) = codec::decode_packet_ws(&buf[..n]) { PacketType::Ping => {
match packet.packet_type { let pong = Packet::pong("");
PacketType::Ping => { if send_wt_packet_on_stream(&mut send, &pong).await.is_err() {
let pong = Packet::pong(""); break;
if send_wt_packet_on_stream(&mut send, &pong)
.await
.is_err()
{
break;
}
}
PacketType::Pong => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.update_ping();
}
}
PacketType::Message => {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
PacketType::Close => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
break;
}
_ => {}
} }
} }
PacketType::Pong => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.update_ping();
}
}
PacketType::Message => {
let on_msg = on_message_clone.clone();
let sid = sid_clone.clone();
tokio::spawn(async move {
on_msg(sid, packet);
});
}
PacketType::Close => {
if let Some(s) = store_clone.get(&sid_clone) {
let mut s = s.write().await;
s.set_state(SessionState::Closed);
}
store_clone.remove(&sid_clone);
break;
}
_ => {}
} }
Ok(None) => break,
Err(_) => break,
} }
} }
Ok(None) => break,
Err(_) => break, Err(_) => break,
} }
} }
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(()) Ok::<(), ImksError>(())
}); });
let connection_send = connection.clone(); let connection_send = connection.clone();
@@ -191,18 +192,26 @@ async fn handle_webtransport_session(
Ok(()) Ok(())
} }
async fn send_wt_packet( async fn send_wt_packet(connection: &Connection, packet: &Packet) -> ImksResult<()> {
connection: &Connection, let (mut send, _recv) = connection
packet: &Packet, .open_bi()
) -> Result<(), Box<dyn std::error::Error>> { .await
let (mut send, _recv) = connection.open_bi().await?.await?; .map_err(|e| ImksError::WebTransport(e.to_string()))?
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
let encoded = codec::encode_packet_binary_ws(packet); let encoded = codec::encode_packet_binary_ws(packet);
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_)); let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
let header = codec::encode_webtransport_header(encoded.len(), is_binary); let header = codec::encode_webtransport_header(encoded.len(), is_binary);
send.write_all(&header).await?; send.write_all(&header)
send.write_all(&encoded).await?; .await
send.finish().await?; .map_err(|e| ImksError::WebTransport(e.to_string()))?;
send.write_all(&encoded)
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
send.finish()
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -210,14 +219,20 @@ async fn send_wt_packet(
async fn send_wt_packet_on_stream( async fn send_wt_packet_on_stream(
send: &mut wtransport::SendStream, send: &mut wtransport::SendStream,
packet: &Packet, packet: &Packet,
) -> Result<(), Box<dyn std::error::Error>> { ) -> ImksResult<()> {
let encoded = codec::encode_packet_binary_ws(packet); let encoded = codec::encode_packet_binary_ws(packet);
let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_)); let is_binary = matches!(packet.data, crate::engine::packet::PacketData::Binary(_));
let header = codec::encode_webtransport_header(encoded.len(), is_binary); let header = codec::encode_webtransport_header(encoded.len(), is_binary);
send.write_all(&header).await?; send.write_all(&header)
send.write_all(&encoded).await?; .await
send.finish().await?; .map_err(|e| ImksError::WebTransport(e.to_string()))?;
send.write_all(&encoded)
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
send.finish()
.await
.map_err(|e| ImksError::WebTransport(e.to_string()))?;
Ok(()) Ok(())
} }
+255
View File
@@ -0,0 +1,255 @@
//! Unified error type for imks.
//!
//! Consolidates all submodule-specific error enums into a single `ImksError`.
//! Public APIs return `ImksResult<T>`, private code may still use local enums
//! for internal dispatch and convert via `From` / `.map_err()` when crossing
//! module boundaries.
use std::string::FromUtf8Error;
use thiserror::Error;
/// Unified error enum for the entire imks crate.
#[derive(Debug, Error)]
pub enum ImksError {
// Protocol layer (engine)
#[error("invalid engine packet type: {0}")]
InvalidEnginePacketType(u8),
#[error("invalid engine packet type char: {0}")]
InvalidEnginePacketTypeChar(char),
#[error("empty engine packet")]
EmptyEnginePacket,
#[error("invalid base64: {0}")]
InvalidBase64(#[from] base64::DecodeError),
#[error("invalid utf8 in packet: {0}")]
InvalidPacketUtf8(#[from] FromUtf8Error),
#[error("engine serialization error: {0}")]
EngineSerialization(String),
// Transport upgrade
#[error("session not found for upgrade")]
UpgradeSessionNotFound,
#[error("session already closed, cannot upgrade")]
UpgradeSessionClosed,
#[error("invalid session state for upgrade")]
UpgradeInvalidState,
// Socket.IO layer
#[error("invalid socket packet type: {0}")]
InvalidSocketPacketType(u8),
#[error("invalid socket packet type char: {0}")]
InvalidSocketPacketTypeChar(char),
#[error("empty socket packet")]
EmptySocketPacket,
#[error("invalid socket packet format: {0}")]
InvalidSocketPacketFormat(String),
#[error("missing namespace in socket packet")]
MissingNamespace,
#[error("invalid attachment count in binary event")]
InvalidAttachmentCount,
// Socket namespace
#[error("namespace error: {0}")]
Namespace(String),
#[error("socket not found: {0}")]
SocketNotFound(String),
#[error("failed to send packet to socket: channel full")]
SocketSendFull,
// Adapter layer
#[error("adapter redis error: {0}")]
AdapterRedis(String),
#[error("adapter nats error: {0}")]
AdapterNats(String),
#[error("adapter message bus error: {0}")]
AdapterMessageBus(String),
#[error("adapter serialization error: {0}")]
AdapterSerialization(String),
#[error("adapter room error: {0}")]
AdapterRoom(String),
// Message bus
#[error("message bus connection closed")]
MessageBusConnectionClosed,
#[error("message bus channel not found: {0}")]
MessageBusChannelNotFound(String),
// Session store
#[error("session not found: {0}")]
SessionNotFound(String),
#[error("session expired: {0}")]
SessionExpired(String),
#[error("session store redis error: {0}")]
SessionRedis(String),
#[error("session serialization error: {0}")]
SessionSerialization(String),
// Database
#[error("database error: {0}")]
Database(#[from] sqlx::Error),
// gRPC
#[error("gRPC error: {0}")]
GrpcStatus(#[from] tonic::Status),
#[error("gRPC transport error: {0}")]
GrpcTransport(#[from] tonic::transport::Error),
// Serialization
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
// Transport
#[error("webtransport error: {0}")]
WebTransport(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
// Auth
#[error("auth error: {0}")]
Auth(String),
#[error("token expired")]
TokenExpired,
// General
#[error("not found: {0}")]
NotFound(String),
#[error("invalid input: {0}")]
InvalidInput(String),
#[error("internal error: {0}")]
Internal(String),
}
/// Convenience alias used across all public APIs.
pub type ImksResult<T> = Result<T, ImksError>;
// Conversions from submodule error types (for gradual migration).
impl From<crate::engine::packet::PacketError> for ImksError {
fn from(e: crate::engine::packet::PacketError) -> Self {
use crate::engine::packet::PacketError::*;
match e {
InvalidType(v) => ImksError::InvalidEnginePacketType(v),
InvalidTypeChar(c) => ImksError::InvalidEnginePacketTypeChar(c),
Empty => ImksError::EmptyEnginePacket,
InvalidBase64(e) => ImksError::InvalidBase64(e),
InvalidUtf8(e) => ImksError::InvalidPacketUtf8(e),
Serialization(s) => ImksError::EngineSerialization(s),
}
}
}
impl From<crate::engine::upgrade::UpgradeError> for ImksError {
fn from(e: crate::engine::upgrade::UpgradeError) -> Self {
use crate::engine::upgrade::UpgradeError::*;
match e {
SessionNotFound => ImksError::UpgradeSessionNotFound,
SessionClosed => ImksError::UpgradeSessionClosed,
InvalidState => ImksError::UpgradeInvalidState,
}
}
}
impl From<crate::socket::packet::PacketError> for ImksError {
fn from(e: crate::socket::packet::PacketError) -> Self {
use crate::socket::packet::PacketError::*;
match e {
InvalidType(v) => ImksError::InvalidSocketPacketType(v),
InvalidTypeChar(c) => ImksError::InvalidSocketPacketTypeChar(c),
Empty => ImksError::EmptySocketPacket,
InvalidFormat(s) => ImksError::InvalidSocketPacketFormat(s),
Json(e) => ImksError::Json(e),
MissingNamespace => ImksError::MissingNamespace,
InvalidAttachmentCount => ImksError::InvalidAttachmentCount,
}
}
}
impl From<crate::socket::adapter::AdapterError> for ImksError {
fn from(e: crate::socket::adapter::AdapterError) -> Self {
use crate::socket::adapter::AdapterError::*;
match e {
Redis(s) => ImksError::AdapterRedis(s),
Nats(s) => ImksError::AdapterNats(s),
MessageBus(s) => ImksError::AdapterMessageBus(s),
Serialization(s) => ImksError::AdapterSerialization(s),
Room(s) => ImksError::AdapterRoom(s),
}
}
}
impl From<crate::socket::message_bus::MessageBusError> for ImksError {
fn from(e: crate::socket::message_bus::MessageBusError) -> Self {
use crate::socket::message_bus::MessageBusError::*;
match e {
Redis(s) => ImksError::AdapterRedis(s),
Nats(s) => ImksError::AdapterNats(s),
ConnectionClosed => ImksError::MessageBusConnectionClosed,
ChannelNotFound(s) => ImksError::MessageBusChannelNotFound(s),
Serialization(s) => ImksError::AdapterSerialization(s),
}
}
}
impl From<crate::socket::session_store::SessionError> for ImksError {
fn from(e: crate::socket::session_store::SessionError) -> Self {
use crate::socket::session_store::SessionError::*;
match e {
Redis(s) => ImksError::SessionRedis(s),
NotFound(s) => ImksError::SessionNotFound(s),
Serialization(s) => ImksError::SessionSerialization(s),
Expired(s) => ImksError::SessionExpired(s),
}
}
}
impl From<tokio::sync::mpsc::error::TrySendError<crate::socket::packet::Packet>> for ImksError {
fn from(_: tokio::sync::mpsc::error::TrySendError<crate::socket::packet::Packet>) -> Self {
ImksError::SocketSendFull
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_imks_error_display() {
let err = ImksError::NotFound("message 01909abc".into());
assert_eq!(err.to_string(), "not found: message 01909abc");
}
#[test]
fn test_imks_error_from_base64() {
let b64_err = base64::DecodeError::InvalidByte(0, b'!');
let err: ImksError = b64_err.into();
assert!(matches!(err, ImksError::InvalidBase64(_)));
}
#[test]
fn test_imks_error_from_sqlx() {
// sqlx::Error doesn't impl PartialEq, so just check the variant
let db_err = sqlx::Error::PoolClosed;
let err: ImksError = db_err.into();
assert!(matches!(err, ImksError::Database(_)));
}
#[test]
fn test_imks_error_from_serde_json() {
let json_err = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
let err: ImksError = json_err.into();
assert!(matches!(err, ImksError::Json(_)));
}
#[test]
#[allow(clippy::unnecessary_literal_unwrap)]
fn test_imks_result_ok() {
let result: ImksResult<i32> = Ok(42);
assert_eq!(result.unwrap(), 42);
}
#[test]
fn test_imks_result_err() {
let result: ImksResult<i32> = Err(ImksError::TokenExpired);
assert!(result.is_err());
}
}
+10 -1
View File
@@ -1,3 +1,12 @@
pub mod auth;
pub mod database;
pub mod engine;
pub mod error;
pub mod models;
pub mod pb; pub mod pb;
pub mod repo;
pub mod rpc;
pub mod socket; pub mod socket;
pub mod engine; pub mod svc;
pub use error::{ImksError, ImksResult};
+243 -23
View File
@@ -1,9 +1,16 @@
use std::sync::Arc; use std::sync::{Arc, OnceLock};
use imks::database::{Database, DatabaseConfig};
use imks::engine::server::EngineConfig; use imks::engine::server::EngineConfig;
use imks::socket::server::SocketServer; use imks::repo::MessageRepo;
use imks::rpc::{AppksClients, RpcConfig};
use imks::socket::adapter::{LocalBroadcastFn, NatsAdapter, RedisAdapter};
use imks::socket::message_bus::{NatsMessageBus, RedisMessageBus};
fn main() { use imks::socket::server::SocketServerBuilder;
use imks::svc::{DeployConfig, MessageService};
fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter( .with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env() tracing_subscriber::EnvFilter::try_from_default_env()
@@ -11,27 +18,240 @@ fn main() {
) )
.init(); .init();
let config = EngineConfig::default(); let deploy = DeployConfig::from_env();
let socket_server = Arc::new(SocketServer::new(config)); tracing::info!(
adapter = %deploy.adapter_mode,
server_id = %deploy.server_id,
wt_enabled = deploy.webtransport_enabled,
"Starting imks server"
);
let addr = "0.0.0.0:3000"; let addr = "0.0.0.0:3000";
tracing::info!("Starting Socket.IO server on {}", addr);
tokio::runtime::Runtime::new() let rt = tokio::runtime::Runtime::new()?;
.expect("Failed to create Tokio runtime")
.block_on(async {
let namespace = socket_server.of("/");
namespace
.on_connect(|socket, _auth| {
tracing::info!(
"Socket {} connected (engine: {})",
socket.sid,
socket.engine_sid
);
Ok(())
})
.await;
socket_server.run_http(addr).await.expect("Server error"); rt.block_on(async {
}); let engine_config = EngineConfig::default();
} let mut builder = SocketServerBuilder::new(engine_config);
let namespace_holder: Arc<OnceLock<Arc<imks::socket::namespace::NamespaceManager>>> =
Arc::new(OnceLock::new());
// Pre-configure adapter for Redis/NATS mode.
// The callback resolves namespaces after SocketServer is built.
match deploy.adapter_mode.as_str() {
"redis" => {
let message_bus = Arc::new(
RedisMessageBus::new(&deploy.redis_url)
.await
.map_err(|e| format!("Failed to connect to Redis: {e}"))?,
);
let redis_client = message_bus.client().clone();
let server_id = deploy.server_id.clone();
let adapter = Arc::new(RedisAdapter::new(
message_bus.clone() as Arc<_>,
redis_client,
server_id,
"/".into(),
make_local_broadcast_fn(namespace_holder.clone()),
));
adapter
.init()
.await
.map_err(|e| format!("Failed to initialize Redis adapter: {e}"))?;
builder = builder.adapter(adapter);
tracing::info!("Redis adapter configured for multi-node");
}
"nats" => {
let message_bus = Arc::new(
NatsMessageBus::new(&deploy.nats_url)
.await
.map_err(|e| format!("Failed to connect to NATS: {e}"))?,
);
let server_id = deploy.server_id.clone();
let adapter = Arc::new(NatsAdapter::new(
message_bus.clone() as Arc<_>,
server_id,
"/".into(),
make_local_broadcast_fn(namespace_holder.clone()),
));
adapter
.init()
.await
.map_err(|e| format!("Failed to initialize NATS adapter: {e}"))?;
builder = builder.adapter(adapter);
tracing::info!("NATS adapter configured for multi-node");
}
_ => {
tracing::info!("Local adapter (single-node mode)");
}
};
let socket_server = Arc::new(builder.build());
let _ = namespace_holder.set(socket_server.namespaces.clone());
// Initialize database + gRPC + service
let service: Option<Arc<MessageService>> = {
let rpc_config = RpcConfig::from_env();
let db_config = DatabaseConfig::from_env();
match AppksClients::connect(&rpc_config).await {
Ok(clients) => {
let db = Database::connect(&db_config)
.await
.map_err(|e| format!("Database connection failed: {e}"))?;
imks::database::run_migrations(db.pool())
.await
.map_err(|e| format!("Database migration failed: {e}"))?;
let repo = MessageRepo::new(db.pool().clone());
let svc = MessageService::new(repo, clients, socket_server.namespaces.clone())
.await
.map_err(|e| format!("Failed to initialize message service: {e}"))?;
tracing::info!("Message service initialized with gRPC permission checks");
Some(Arc::new(svc))
}
Err(e) => {
tracing::warn!("gRPC unavailable: {e}. Running without permission checks.");
None
}
}
};
// Register connect handler
let namespace = socket_server.of("/");
let svc_connect = service.clone();
namespace
.on_connect(move |socket, auth_data| {
if let Some(ref svc) = svc_connect {
svc.authenticate_socket(socket, auth_data)
.map_err(|e| e.to_string())?;
}
tracing::info!(
"Socket {} connected (engine: {})",
socket.sid,
socket.engine_sid
);
Ok(())
})
.await;
// Register Socket.IO event handlers
if let Some(ref svc) = service {
macro_rules! register_event {
($svc:expr, $ns:expr, $event:expr, $method:ident) => {
let s = $svc.clone();
$ns.on_event($event, Arc::new(move |socket, data| {
let s = s.clone();
let data = data.clone();
tokio::spawn(async move {
if let Err(e) = s.$method(socket, &data).await {
tracing::error!(event = $event, error = %e, "Event handler failed");
}
});
})).await;
};
}
register_event!(svc, namespace, "channel:join", join_channel);
register_event!(svc, namespace, "channel:leave", leave_channel);
register_event!(svc, namespace, "message:send", send_message);
register_event!(svc, namespace, "message:edit", edit_message);
register_event!(svc, namespace, "message:delete", delete_message);
register_event!(svc, namespace, "reaction:add", toggle_reaction);
register_event!(svc, namespace, "pin:add", pin_message);
register_event!(svc, namespace, "pin:remove", unpin_message);
register_event!(svc, namespace, "poll:vote", poll_vote);
register_event!(svc, namespace, "poll:vote:remove", poll_remove_vote);
register_event!(svc, namespace, "typing:start", typing_start);
register_event!(svc, namespace, "typing:stop", typing_stop);
register_event!(svc, namespace, "presence:update", presence_update);
register_event!(svc, namespace, "draft:save", save_draft);
register_event!(svc, namespace, "draft:get", get_draft);
register_event!(svc, namespace, "draft:delete", delete_draft);
register_event!(svc, namespace, "read_state:mark", mark_read);
register_event!(svc, namespace, "read_state:get", get_read_state);
register_event!(svc, namespace, "notification:list", list_notifications);
register_event!(
svc,
namespace,
"notification:mark_read",
mark_notification_read
);
register_event!(
svc,
namespace,
"notification:mark_all_read",
mark_all_notifications_read
);
register_event!(svc, namespace, "bookmark:add", add_bookmark);
register_event!(svc, namespace, "bookmark:remove", remove_bookmark);
register_event!(svc, namespace, "bookmark:list", list_bookmarks);
register_event!(svc, namespace, "thread:create", create_thread);
register_event!(svc, namespace, "thread:resolve", resolve_thread);
register_event!(svc, namespace, "thread:join", join_thread);
register_event!(svc, namespace, "thread:leave", leave_thread);
register_event!(svc, namespace, "thread:list", list_threads);
register_event!(svc, namespace, "article:create", create_article);
register_event!(svc, namespace, "article:update", update_article);
register_event!(svc, namespace, "article:list", list_articles);
register_event!(svc, namespace, "article:delete", delete_article);
register_event!(svc, namespace, "component:interact", interact_component);
// Start scheduled message dispatcher (background task)
svc.clone().start_scheduled_dispatcher();
tracing::info!("Registered Socket.IO event handlers");
}
// Start servers
if deploy.webtransport_enabled && !deploy.cert_path.is_empty() {
let engine = socket_server.engine.clone();
let wt_port = deploy.webtransport_port;
let cert_path = deploy.cert_path.clone();
let key_path = deploy.key_path.clone();
let server = socket_server.clone();
tracing::info!("Starting HTTP on {} + WebTransport on {}", addr, wt_port);
tokio::select! {
result = server.run_http(addr) => {
result?;
}
result = engine.run_webtransport(wt_port, &cert_path, &key_path) => {
result?;
}
}
} else {
tracing::info!("Socket.IO HTTP server listening on {}", addr);
socket_server.run_http(addr).await?;
}
Ok::<(), Box<dyn std::error::Error>>(())
})?;
Ok(())
}
/// Create a local broadcast function for Redis/NATS adapters.
///
/// The callback is used both for same-node delivery and for cross-node messages
/// received from the message bus.
fn make_local_broadcast_fn(
namespaces: Arc<OnceLock<Arc<imks::socket::namespace::NamespaceManager>>>,
) -> LocalBroadcastFn {
Arc::new(move |packet, opts| {
let Some(manager) = namespaces.get() else {
tracing::warn!(namespace = %packet.namespace, "Namespace manager not initialized");
return;
};
let Some(namespace) = manager.get_namespace(&packet.namespace) else {
tracing::warn!(namespace = %packet.namespace, "Namespace not found for local broadcast");
return;
};
namespace.emit_local_filtered(packet, opts);
})
}
+23
View File
@@ -0,0 +1,23 @@
-- Create message_thread before migrations that reference it.
-- Safe for existing databases because the table may already exist from 004.
BEGIN;
CREATE TABLE IF NOT EXISTS message_thread (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
channel_id UUID NOT NULL,
root_message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
created_by UUID NOT NULL,
replies_count BIGINT NOT NULL DEFAULT 0,
participants_count BIGINT NOT NULL DEFAULT 0,
last_reply_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
last_reply_at TIMESTAMPTZ NULL,
resolved BOOLEAN NOT NULL DEFAULT FALSE,
resolved_by UUID NULL,
resolved_at TIMESTAMPTZ NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_thread_root UNIQUE (root_message_id)
);
COMMIT;
+115
View File
@@ -0,0 +1,115 @@
-- ============================================================
-- Migration: 001_message_rich_content.sql
-- Tables: message_attachment, message_embed, message_embed_field,
-- message_poll, message_poll_option, message_poll_vote
-- ============================================================
-- These tables extend the existing `message` table (from appks 001_init.sql)
-- with Discord-style rich content: file attachments, link preview embeds,
-- and interactive polls.
BEGIN;
-- models/message_attachment.rs → message_attachment
CREATE TABLE IF NOT EXISTS message_attachment (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
filename TEXT NOT NULL,
content_type TEXT NULL,
size BIGINT NOT NULL,
url TEXT NOT NULL,
storage_key TEXT NULL,
width INTEGER NULL,
height INTEGER NULL,
duration_secs DOUBLE PRECISION NULL,
blurhash TEXT NULL,
spoiler BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_attachment_message_id
ON message_attachment (message_id);
-- models/message_embed.rs → message_embed
CREATE TABLE IF NOT EXISTS message_embed (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
embed_type TEXT NOT NULL,
title TEXT NULL,
description TEXT NULL,
url TEXT NULL,
color INTEGER NULL,
image_url TEXT NULL,
image_width INTEGER NULL,
image_height INTEGER NULL,
thumbnail_url TEXT NULL,
thumbnail_width INTEGER NULL,
thumbnail_height INTEGER NULL,
video_url TEXT NULL,
video_width INTEGER NULL,
video_height INTEGER NULL,
author_name TEXT NULL,
author_url TEXT NULL,
author_icon_url TEXT NULL,
footer_text TEXT NULL,
footer_icon_url TEXT NULL,
provider_name TEXT NULL,
provider_url TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_embed_message_id
ON message_embed (message_id);
-- models/message_embed.rs → message_embed_field
CREATE TABLE IF NOT EXISTS message_embed_field (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
embed_id UUID NOT NULL REFERENCES message_embed(id) ON DELETE CASCADE,
name TEXT NOT NULL,
value TEXT NOT NULL,
inline BOOLEAN NOT NULL DEFAULT FALSE,
position INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_message_embed_field_embed_id
ON message_embed_field (embed_id);
-- models/message_poll.rs → message_poll
CREATE TABLE IF NOT EXISTS message_poll (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
question TEXT NOT NULL,
allow_multiselect BOOLEAN NOT NULL DEFAULT FALSE,
max_selections INTEGER NULL,
expires_at TIMESTAMPTZ NULL,
total_votes BIGINT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_poll_message UNIQUE (message_id)
);
CREATE INDEX IF NOT EXISTS idx_message_poll_message_id
ON message_poll (message_id);
-- models/message_poll.rs → message_poll_option
CREATE TABLE IF NOT EXISTS message_poll_option (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
poll_id UUID NOT NULL REFERENCES message_poll(id) ON DELETE CASCADE,
text TEXT NOT NULL,
emoji TEXT NULL,
vote_count BIGINT NOT NULL DEFAULT 0,
position INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_message_poll_option_poll_id
ON message_poll_option (poll_id);
-- models/message_poll.rs → message_poll_vote
CREATE TABLE IF NOT EXISTS message_poll_vote (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
poll_id UUID NOT NULL REFERENCES message_poll(id) ON DELETE CASCADE,
option_id UUID NOT NULL REFERENCES message_poll_option(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_poll_vote UNIQUE (poll_id, user_id, option_id)
);
CREATE INDEX IF NOT EXISTS idx_message_poll_vote_poll_id
ON message_poll_vote (poll_id);
CREATE INDEX IF NOT EXISTS idx_message_poll_vote_user_id
ON message_poll_vote (user_id);
COMMIT;
+76
View File
@@ -0,0 +1,76 @@
-- ============================================================
-- Migration: 002_message_social.sql
-- Tables: message_pin, message_read_state, message_draft, message_edit
-- ============================================================
-- Extends the message subsystem with pinned messages, read receipts,
-- drafts, and edit history.
BEGIN;
-- models/message_pin.rs → message_pin
CREATE TABLE IF NOT EXISTS message_pin (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
channel_id UUID NOT NULL,
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
pinned_by UUID NOT NULL,
position INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_pin_channel_message UNIQUE (channel_id, message_id)
);
CREATE INDEX IF NOT EXISTS idx_message_pin_channel_id
ON message_pin (channel_id);
-- models/message_read_state.rs → message_read_state
CREATE TABLE IF NOT EXISTS message_read_state (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
channel_id UUID NOT NULL,
user_id UUID NOT NULL,
last_read_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
last_read_at TIMESTAMPTZ NULL,
unread_count BIGINT NOT NULL DEFAULT 0,
unread_mentions BIGINT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_read_state_channel_user UNIQUE (channel_id, user_id)
);
CREATE INDEX IF NOT EXISTS idx_message_read_state_user_id
ON message_read_state (user_id);
CREATE INDEX IF NOT EXISTS idx_message_read_state_channel_id
ON message_read_state (channel_id);
-- models/message_draft.rs → message_draft
CREATE TABLE IF NOT EXISTS message_draft (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
channel_id UUID NOT NULL,
user_id UUID NOT NULL,
thread_id UUID NULL REFERENCES message_thread(id) ON DELETE CASCADE,
reply_to_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
body TEXT NOT NULL DEFAULT '',
metadata JSONB NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_draft_channel_user_thread
UNIQUE (channel_id, user_id, thread_id)
);
CREATE INDEX IF NOT EXISTS idx_message_draft_user_id
ON message_draft (user_id);
-- models/message_edit.rs → message_edit
CREATE TABLE IF NOT EXISTS message_edit (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
edited_by UUID NOT NULL,
old_body TEXT NOT NULL,
new_body TEXT NOT NULL,
edited_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_edit_message_id
ON message_edit (message_id);
CREATE INDEX IF NOT EXISTS idx_message_edit_edited_at
ON message_edit (edited_at);
COMMIT;
+45
View File
@@ -0,0 +1,45 @@
-- ============================================================
-- Migration: 003_message_article.sql
-- Tables: message_article
-- ============================================================
-- Extends the message subsystem with forum-style article posts.
-- Articles extend regular messages with title, cover image, tags,
-- and view/like stats. Rendered as waterfall cards in forum channels.
BEGIN;
-- models/message_article.rs → message_article
CREATE TABLE IF NOT EXISTS message_article (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
title TEXT NOT NULL,
summary TEXT NULL,
cover_url TEXT NULL,
cover_width INTEGER NULL,
cover_height INTEGER NULL,
cover_color TEXT NULL,
tags JSONB NULL,
view_count BIGINT NOT NULL DEFAULT 0,
like_count BIGINT NOT NULL DEFAULT 0,
bookmark_count BIGINT NOT NULL DEFAULT 0,
reply_count BIGINT NOT NULL DEFAULT 0,
last_reply_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
last_reply_at TIMESTAMPTZ NULL,
last_reply_user_id UUID NULL,
is_pinned_to_top BOOLEAN NOT NULL DEFAULT FALSE,
is_answered BOOLEAN NOT NULL DEFAULT FALSE,
answered_by UUID NULL,
answered_at TIMESTAMPTZ NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_article_message UNIQUE (message_id)
);
CREATE INDEX IF NOT EXISTS idx_message_article_last_reply_at
ON message_article (last_reply_at DESC NULLS LAST);
CREATE INDEX IF NOT EXISTS idx_message_article_is_pinned_to_top
ON message_article (is_pinned_to_top DESC, last_reply_at DESC NULLS LAST);
CREATE INDEX IF NOT EXISTS idx_message_article_view_count
ON message_article (view_count DESC);
COMMIT;
+98
View File
@@ -0,0 +1,98 @@
-- ============================================================
-- Migration: 004_message_social_part2.sql
-- Tables: message_reaction, message_bookmark, message_mention,
-- message_thread, message_thread_participant
-- ============================================================
BEGIN;
-- models/message_reaction.rs → message_reaction
CREATE TABLE IF NOT EXISTS message_reaction (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
channel_id UUID NOT NULL,
user_id UUID NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_reaction_user_content UNIQUE (message_id, user_id, content)
);
CREATE INDEX IF NOT EXISTS idx_message_reaction_message_id
ON message_reaction (message_id);
CREATE INDEX IF NOT EXISTS idx_message_reaction_user_id
ON message_reaction (user_id);
-- models/message_bookmark.rs → message_bookmark
CREATE TABLE IF NOT EXISTS message_bookmark (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
channel_id UUID NOT NULL,
user_id UUID NOT NULL,
note TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_bookmark_user_message UNIQUE (user_id, message_id)
);
CREATE INDEX IF NOT EXISTS idx_message_bookmark_user_id
ON message_bookmark (user_id);
CREATE INDEX IF NOT EXISTS idx_message_bookmark_message_id
ON message_bookmark (message_id);
-- models/message_mention.rs → message_mention
CREATE TABLE IF NOT EXISTS message_mention (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
channel_id UUID NOT NULL,
mentioned_user_id UUID NOT NULL,
mentioned_by UUID NOT NULL,
read_at TIMESTAMPTZ NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_mention_message_id
ON message_mention (message_id);
CREATE INDEX IF NOT EXISTS idx_message_mention_mentioned_user
ON message_mention (mentioned_user_id);
-- models/message_thread.rs → message_thread
CREATE TABLE IF NOT EXISTS message_thread (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
channel_id UUID NOT NULL,
root_message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
created_by UUID NOT NULL,
replies_count BIGINT NOT NULL DEFAULT 0,
participants_count BIGINT NOT NULL DEFAULT 0,
last_reply_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
last_reply_at TIMESTAMPTZ NULL,
resolved BOOLEAN NOT NULL DEFAULT FALSE,
resolved_by UUID NULL,
resolved_at TIMESTAMPTZ NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_message_thread_root UNIQUE (root_message_id)
);
CREATE INDEX IF NOT EXISTS idx_message_thread_channel_id
ON message_thread (channel_id);
CREATE INDEX IF NOT EXISTS idx_message_thread_last_reply_at
ON message_thread (last_reply_at DESC NULLS LAST);
-- models/message_thread_participant.rs → message_thread_participant
CREATE TABLE IF NOT EXISTS message_thread_participant (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
thread_id UUID NOT NULL REFERENCES message_thread(id) ON DELETE CASCADE,
user_id UUID NOT NULL,
joined_reason TEXT NULL,
last_read_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
last_read_at TIMESTAMPTZ NULL,
joined_at TIMESTAMPTZ NOT NULL DEFAULT now(),
CONSTRAINT uq_thread_participant UNIQUE (thread_id, user_id)
);
CREATE INDEX IF NOT EXISTS idx_thread_participant_thread_id
ON message_thread_participant (thread_id);
CREATE INDEX IF NOT EXISTS idx_thread_participant_user_id
ON message_thread_participant (user_id);
COMMIT;
+102
View File
@@ -0,0 +1,102 @@
-- ============================================================
-- Migration: 005_message_misc.sql
-- Tables: message_notification, message_scheduled, message_sticker,
-- message_forward, message_component
-- ============================================================
BEGIN;
-- models/message_notification.rs → message_notification
CREATE TABLE IF NOT EXISTS message_notification (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
channel_id UUID NOT NULL,
user_id UUID NOT NULL,
reason TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
delivery_channel TEXT NULL,
delivered_at TIMESTAMPTZ NULL,
read_at TIMESTAMPTZ NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_notification_user_id
ON message_notification (user_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_message_notification_status
ON message_notification (status);
-- models/message_scheduled.rs → message_scheduled
CREATE TABLE IF NOT EXISTS message_scheduled (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
channel_id UUID NOT NULL,
author_id UUID NOT NULL,
thread_id UUID NULL REFERENCES message_thread(id) ON DELETE SET NULL,
reply_to_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
body TEXT NOT NULL,
metadata JSONB NULL,
scheduled_at TIMESTAMPTZ NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
sent_message_id UUID NULL REFERENCES message(id) ON DELETE SET NULL,
error TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_scheduled_status_at
ON message_scheduled (status, scheduled_at);
-- models/message_sticker.rs → message_sticker
CREATE TABLE IF NOT EXISTS message_sticker (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
sticker_id UUID NOT NULL,
name TEXT NOT NULL,
image_url TEXT NOT NULL,
format_type TEXT NOT NULL DEFAULT 'png',
pack_name TEXT NULL,
tags TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_sticker_message_id
ON message_sticker (message_id);
-- models/message_forward.rs → message_forward
CREATE TABLE IF NOT EXISTS message_forward (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
source_message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
source_channel_id UUID NOT NULL,
forwarded_by UUID NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_forward_message_id
ON message_forward (message_id);
CREATE INDEX IF NOT EXISTS idx_message_forward_source_message_id
ON message_forward (source_message_id);
-- models/message_component.rs → message_component
CREATE TABLE IF NOT EXISTS message_component (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
message_id UUID NOT NULL REFERENCES message(id) ON DELETE CASCADE,
row INTEGER NOT NULL DEFAULT 0,
position INTEGER NOT NULL DEFAULT 0,
component_type TEXT NOT NULL,
custom_id TEXT NOT NULL,
label TEXT NULL,
emoji TEXT NULL,
style TEXT NULL,
url TEXT NULL,
disabled BOOLEAN NOT NULL DEFAULT FALSE,
placeholder TEXT NULL,
min_values INTEGER NULL,
max_values INTEGER NULL,
options JSONB NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_message_component_message_id
ON message_component (message_id);
COMMIT;
+175
View File
@@ -0,0 +1,175 @@
//! Core message model — maps to PostgreSQL `message` table.
//!
//! Discord-style: `body` holds plain text / markdown, rich content lives in
//! companion tables (attachment, embed, poll, reaction, mention).
//! IDs are UUID v7 (time-ordered) so `ORDER BY id` = chronological order.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Discriminator for system / event messages vs. regular user messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MessageType {
/// Regular user message (text / markdown).
#[default]
Text,
/// System-generated notice (e.g. "user joined the channel").
System,
/// Channel event (pinned a message, changed topic, etc.).
Event,
/// Forum article / long-form post displayed as waterfall cards.
Article,
}
impl MessageType {
pub fn as_str(&self) -> &'static str {
match self {
Self::Text => "text",
Self::System => "system",
Self::Event => "event",
Self::Article => "article",
}
}
pub fn from_str_lossy(s: &str) -> Self {
match s {
"system" => Self::System,
"event" => Self::Event,
"article" => Self::Article,
_ => Self::Text,
}
}
}
impl std::fmt::Display for MessageType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
/// Direct mapping of the `message` table row.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Message {
/// UUID v7 — time-ordered primary key.
pub id: Uuid,
pub channel_id: Uuid,
pub author_id: Uuid,
/// Thread this message belongs to (NULL = not threaded).
pub thread_id: Option<Uuid>,
/// Direct reply reference (NULL = top-level message).
pub reply_to_message_id: Option<Uuid>,
/// "text" | "system" | "event"
pub message_type: String,
/// Plain text or markdown body.
pub body: String,
/// Extensible metadata (flags, locale, interaction ref, etc.).
pub metadata: Option<serde_json::Value>,
pub pinned: bool,
/// True for bot / system generated messages.
pub system: bool,
pub edited_at: Option<DateTime<Utc>>,
pub deleted_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// Lightweight author info embedded in [`MessageDetail`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorInfo {
pub id: Uuid,
pub username: String,
pub display_name: Option<String>,
pub avatar_url: Option<String>,
pub is_bot: bool,
}
/// Message with resolved author and reaction/attachment aggregates.
/// Returned by read APIs; never stored directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageDetail {
#[serde(flatten)]
pub message: Message,
pub author: AuthorInfo,
/// Aggregated reaction counts: `{ "👍": 3, "🎉": 1 }`.
pub reactions: std::collections::HashMap<String, i64>,
pub attachment_count: i64,
pub embed_count: i64,
/// Whether the current user has bookmarked this message.
pub bookmarked: bool,
/// Reply chain depth (0 = top-level).
pub reply_depth: i32,
}
/// Generate a new UUID v7 (time-ordered) for message IDs.
pub fn new_message_id() -> Uuid {
Uuid::now_v7()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_type_roundtrip() {
let t = MessageType::Text;
assert_eq!(t.as_str(), "text");
assert_eq!(MessageType::from_str_lossy("text"), MessageType::Text);
assert_eq!(MessageType::from_str_lossy("system"), MessageType::System);
assert_eq!(MessageType::from_str_lossy("unknown"), MessageType::Text);
}
#[test]
fn test_message_id_ordering() {
// UUID v7 IDs generated later should sort after earlier ones.
let a = new_message_id();
std::thread::sleep(std::time::Duration::from_millis(2));
let b = new_message_id();
assert!(b > a, "UUID v7 should be time-ordered");
}
#[test]
fn test_message_detail_serialize() {
let msg = Message {
id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
author_id: Uuid::now_v7(),
thread_id: None,
reply_to_message_id: None,
message_type: "text".to_string(),
body: "hello world".to_string(),
metadata: None,
pinned: false,
system: false,
edited_at: None,
deleted_at: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let detail = MessageDetail {
message: msg,
author: AuthorInfo {
id: Uuid::now_v7(),
username: "alice".to_string(),
display_name: Some("Alice".to_string()),
avatar_url: None,
is_bot: false,
},
reactions: std::collections::HashMap::from([
("👍".to_string(), 3),
("🎉".to_string(), 1),
]),
attachment_count: 0,
embed_count: 0,
bookmarked: false,
reply_depth: 0,
};
let json = serde_json::to_value(&detail).unwrap();
assert_eq!(json["body"], "hello world");
assert_eq!(json["author"]["username"], "alice");
assert_eq!(json["reactions"]["👍"], 3);
}
}
+196
View File
@@ -0,0 +1,196 @@
//! Forum article / long-form post — maps to `message_article` table.
//!
//! Forum articles extend a regular [`Message`] with title, cover image, tags,
//! and view/like stats. Rendered as waterfall cards in forum channel views.
//! One article per message (1:1), linked via `message_id`.
//!
//! A message is an article when `message.message_type = "article"`.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Extended metadata for a forum article / post.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageArticle {
pub id: Uuid,
/// FK to `message.id`. UNIQUE constraint ensures 1:1.
pub message_id: Uuid,
/// Article title (plain text, max 256 chars).
pub title: String,
/// Short excerpt for card preview (plain text, max 512 chars).
pub summary: Option<String>,
/// Cover image URL (displayed at the top of the card).
pub cover_url: Option<String>,
/// Cover image width in pixels (for waterfall layout height calculation).
pub cover_width: Option<i32>,
/// Cover image height in pixels.
pub cover_height: Option<i32>,
/// Cover image dominant color (hex, for placeholder while loading).
pub cover_color: Option<String>,
/// Tag IDs referencing `forum_tag` table, stored as JSON array.
pub tags: Option<serde_json::Value>,
/// View count (denormalized, incremented on read).
pub view_count: i64,
/// Reaction / like count (denormalized).
pub like_count: i64,
/// Bookmark count (denormalized).
pub bookmark_count: i64,
/// Reply count (denormalized, threads inside the article).
pub reply_count: i64,
/// Most recent reply message id.
pub last_reply_message_id: Option<Uuid>,
/// Most recent reply timestamp.
pub last_reply_at: Option<DateTime<Utc>>,
/// User id of the last replier.
pub last_reply_user_id: Option<Uuid>,
/// Whether the article is pinned to the top of the forum channel.
pub is_pinned_to_top: bool,
/// Whether the question has been answered / resolved.
pub is_answered: bool,
/// Who marked it as answered.
pub answered_by: Option<Uuid>,
/// When it was marked as answered.
pub answered_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// Waterfall card view — article enriched with author info + first attachment
/// (cover fallback). Returned by forum list APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArticleCard {
#[serde(flatten)]
pub article: MessageArticle,
/// Author display info.
pub author: super::message::AuthorInfo,
/// Resolved tag names (e.g. ["Bug Report", "High Priority"]).
pub tag_names: Vec<String>,
/// First image attachment URL (fallback when cover_url is NULL).
pub first_image_url: Option<String>,
/// Whether the current user has bookmarked this article.
pub bookmarked: bool,
/// Whether the current user has liked this article.
pub liked: bool,
}
/// Input payload for creating a new forum article.
///
/// Sent by the client when composing a forum post.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateArticleInput {
pub channel_id: Uuid,
pub title: String,
pub body: String,
pub summary: Option<String>,
pub cover_url: Option<String>,
pub tags: Option<Vec<Uuid>>,
}
/// Input payload for updating an existing forum article.
///
/// All fields are optional — only provided fields are updated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateArticleInput {
pub title: Option<String>,
pub body: Option<String>,
pub summary: Option<String>,
pub cover_url: Option<String>,
pub cover_color: Option<String>,
pub tags: Option<Vec<Uuid>>,
}
/// Sort modes for forum article listing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ArticleSort {
/// Most recent activity first.
#[default]
LatestActivity,
/// Newest articles first.
Newest,
/// Most viewed first.
MostViewed,
/// Most liked first.
MostLiked,
/// Pinned articles first, then by latest activity.
PinnedFirst,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_article_card_serialize() {
let article = MessageArticle {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
title: "Bug Report: Login fails on Safari".into(),
summary: Some("Users report that login fails on Safari 18...".into()),
cover_url: Some("https://cdn.example.com/covers/bug.png".into()),
cover_width: Some(1200),
cover_height: Some(630),
cover_color: Some("#FF6B6B".into()),
tags: Some(serde_json::json!(["01909a", "01909b"])),
view_count: 142,
like_count: 7,
bookmark_count: 3,
reply_count: 5,
last_reply_message_id: Some(Uuid::now_v7()),
last_reply_at: Some(Utc::now()),
last_reply_user_id: Some(Uuid::now_v7()),
is_pinned_to_top: true,
is_answered: false,
answered_by: None,
answered_at: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let card = ArticleCard {
article,
author: super::super::message::AuthorInfo {
id: Uuid::now_v7(),
username: "alice".into(),
display_name: Some("Alice".into()),
avatar_url: None,
is_bot: false,
},
tag_names: vec!["Bug Report".into(), "High Priority".into()],
first_image_url: None,
bookmarked: false,
liked: true,
};
let json = serde_json::to_value(&card).unwrap();
assert_eq!(json["title"], "Bug Report: Login fails on Safari");
assert_eq!(json["view_count"], 142);
assert_eq!(json["author"]["username"], "alice");
assert_eq!(json["tag_names"][0], "Bug Report");
assert_eq!(json["is_pinned_to_top"], true);
}
#[test]
fn test_article_sort_serialize() {
let sort = ArticleSort::MostViewed;
let json = serde_json::to_value(sort).unwrap();
assert_eq!(json, "most_viewed");
}
#[test]
fn test_create_article_input_serialize() {
let input = CreateArticleInput {
channel_id: Uuid::now_v7(),
title: "New Feature Proposal".into(),
body: "I'd like to propose...".into(),
summary: Some("A proposal for...".into()),
cover_url: None,
tags: Some(vec![Uuid::now_v7()]),
};
let json = serde_json::to_value(&input).unwrap();
assert_eq!(json["title"], "New Feature Proposal");
assert_eq!(json["tags"].as_array().unwrap().len(), 1);
}
}
+96
View File
@@ -0,0 +1,96 @@
//! File / image attachment on a message — new table `message_attachment`.
//!
//! Discord-style: each message can have multiple attachments. Files are
//! uploaded to S3 first, then the URL is stored here.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A file or image attachment on a message.
///
/// Maps to the `message_attachment` table. Files are uploaded to object
/// storage first, then the URL is stored here.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageAttachment {
pub id: Uuid,
pub message_id: Uuid,
/// Original filename as uploaded by the user.
pub filename: String,
/// MIME type: "image/png", "application/pdf", etc.
pub content_type: Option<String>,
/// File size in bytes.
pub size: i64,
/// Public URL (S3 presigned or CDN).
pub url: String,
/// S3 / object-store key for backend access.
pub storage_key: Option<String>,
/// Image / video width in pixels.
pub width: Option<i32>,
/// Image / video height in pixels.
pub height: Option<i32>,
/// Audio / video duration in seconds.
pub duration_secs: Option<f64>,
/// Blurred low-res preview for progressive loading (base64 data URI).
pub blurhash: Option<String>,
/// Whether this attachment should be rendered as a spoiler (hidden until click).
pub spoiler: bool,
pub created_at: DateTime<Utc>,
}
/// Lightweight attachment summary for list views.
///
/// Omits URL and storage_key to reduce payload size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttachmentSummary {
pub id: Uuid,
pub filename: String,
pub content_type: Option<String>,
pub size: i64,
pub width: Option<i32>,
pub height: Option<i32>,
pub spoiler: bool,
}
impl From<MessageAttachment> for AttachmentSummary {
fn from(a: MessageAttachment) -> Self {
Self {
id: a.id,
filename: a.filename,
content_type: a.content_type,
size: a.size,
width: a.width,
height: a.height,
spoiler: a.spoiler,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_attachment_conversion() {
let att = MessageAttachment {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
filename: "photo.png".into(),
content_type: Some("image/png".into()),
size: 1024,
url: "https://cdn.example.com/photo.png".into(),
storage_key: None,
width: Some(800),
height: Some(600),
duration_secs: None,
blurhash: None,
spoiler: false,
created_at: Utc::now(),
};
let summary: AttachmentSummary = att.into();
assert_eq!(summary.filename, "photo.png");
assert_eq!(summary.size, 1024);
assert_eq!(summary.width, Some(800));
}
}
+46
View File
@@ -0,0 +1,46 @@
//! User bookmark on a message — maps to `message_bookmark` table.
//!
//! Similar to browser bookmarks: user saves a message for later reference,
//! optionally with a personal note.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A user bookmark on a message for later reference.
///
/// Maps to the `message_bookmark` table. Similar to browser bookmarks,
/// optionally with a personal note.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageBookmark {
pub id: Uuid,
pub message_id: Uuid,
pub channel_id: Uuid,
pub user_id: Uuid,
/// Personal note the user attached to this bookmark.
pub note: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bookmark_serialize() {
let bm = MessageBookmark {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
note: Some("Important reference".into()),
created_at: Utc::now(),
updated_at: Utc::now(),
};
let json = serde_json::to_value(&bm).unwrap();
assert_eq!(json["note"], "Important reference");
assert!(json["message_id"].is_string());
}
}
+95
View File
@@ -0,0 +1,95 @@
//! Interactive message components — maps to `message_component` table.
//!
//! Discord-style interactive elements attached to messages: buttons, select
//! menus, etc. Each component belongs to a message and has its own layout row.
//! Clicking a component emits an interaction event back to the bot/webhook.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// An interactive component attached to a message (button, select menu, etc.).
///
/// Maps to the `message_component` table. Each component belongs to a message
/// and has a layout row/position. User interactions emit callbacks.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageComponent {
pub id: Uuid,
pub message_id: Uuid,
/// Layout row within the message (0-based).
pub row: i32,
/// Position within the row (0-based).
pub position: i32,
/// "button" | "select_menu" | "text_input"
pub component_type: String,
/// Unique identifier sent back in interaction callbacks.
pub custom_id: String,
/// Display label.
pub label: Option<String>,
/// Emoji shown on the button (unicode or `:name:id`).
pub emoji: Option<String>,
/// Button style: "primary" | "secondary" | "success" | "danger" | "link"
pub style: Option<String>,
/// URL for link-style buttons.
pub url: Option<String>,
/// Whether the component is disabled.
pub disabled: bool,
/// Placeholder text for select menus.
pub placeholder: Option<String>,
/// Min/max selections for select menus.
pub min_values: Option<i32>,
pub max_values: Option<i32>,
/// Options for select menus, stored as JSON array.
pub options: Option<serde_json::Value>,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ComponentType {
#[default]
Button,
SelectMenu,
TextInput,
}
impl ComponentType {
pub fn as_str(&self) -> &'static str {
match self {
Self::Button => "button",
Self::SelectMenu => "select_menu",
Self::TextInput => "text_input",
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_component_serialize() {
let c = MessageComponent {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
row: 0,
position: 0,
component_type: ComponentType::Button.as_str().into(),
custom_id: "btn_approve".into(),
label: Some("Approve".into()),
emoji: Some("".into()),
style: Some("success".into()),
url: None,
disabled: false,
placeholder: None,
min_values: None,
max_values: None,
options: None,
created_at: Utc::now(),
};
let json = serde_json::to_value(&c).unwrap();
assert_eq!(json["component_type"], "button");
assert_eq!(json["style"], "success");
}
}
+77
View File
@@ -0,0 +1,77 @@
//! Message drafts — maps to `message_draft` table.
//!
//! Stores unsent messages so they survive browser refreshes and sync across
//! the user's connected devices. One draft per (channel, user, optional thread).
//! Drafts are upserted on every keystroke debounce and deleted on send.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A user's unsent message draft in a channel.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageDraft {
pub id: Uuid,
pub channel_id: Uuid,
pub user_id: Uuid,
/// Thread the draft belongs to (NULL = top-level).
pub thread_id: Option<Uuid>,
/// Message this draft is replying to (NULL = new message).
pub reply_to_message_id: Option<Uuid>,
/// Plain text or markdown body.
pub body: String,
/// Extensible metadata (attachments to be uploaded, etc.).
pub metadata: Option<serde_json::Value>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// Input for upserting a draft (sent from client on debounced keystroke).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DraftUpsertInput {
pub channel_id: Uuid,
pub thread_id: Option<Uuid>,
pub reply_to_message_id: Option<Uuid>,
pub body: String,
pub metadata: Option<serde_json::Value>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_draft_upsert_input_serialize() {
let input = DraftUpsertInput {
channel_id: Uuid::now_v7(),
thread_id: None,
reply_to_message_id: None,
body: "hello, this is a draft".to_string(),
metadata: None,
};
let json = serde_json::to_value(&input).unwrap();
assert_eq!(json["body"], "hello, this is a draft");
assert!(json["thread_id"].is_null());
}
#[test]
fn test_draft_serialize() {
let draft = MessageDraft {
id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
thread_id: Some(Uuid::now_v7()),
reply_to_message_id: None,
body: "draft body".to_string(),
metadata: Some(serde_json::json!({"pending_attachments": []})),
created_at: Utc::now(),
updated_at: Utc::now(),
};
let json = serde_json::to_value(&draft).unwrap();
assert_eq!(json["body"], "draft body");
assert!(json["thread_id"].is_string());
assert!(json["metadata"]["pending_attachments"].is_array());
}
}
+77
View File
@@ -0,0 +1,77 @@
//! Message edit history — maps to `message_edit` table.
//!
//! Immutable append-only log of every edit to a message.
//! Used for audit trails, "edited" indicators with hover-to-see-original,
//! and compliance / moderation review.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// One edit record. Stored every time a message body is modified.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageEdit {
pub id: Uuid,
pub message_id: Uuid,
/// Who made the edit (usually the author; can be a moderator).
pub edited_by: Uuid,
/// Body content before the edit.
pub old_body: String,
/// Body content after the edit.
pub new_body: String,
pub edited_at: DateTime<Utc>,
}
/// Lightweight summary for the "edited" tooltip (no full body).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EditSummary {
pub edit_count: i64,
pub last_edited_at: Option<DateTime<Utc>>,
pub last_edited_by: Option<Uuid>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_edit_serialize() {
let edit = MessageEdit {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
edited_by: Uuid::now_v7(),
old_body: "before edit".to_string(),
new_body: "after edit".to_string(),
edited_at: Utc::now(),
};
let json = serde_json::to_value(&edit).unwrap();
assert_eq!(json["old_body"], "before edit");
assert_eq!(json["new_body"], "after edit");
}
#[test]
fn test_edit_summary_serialize() {
let summary = EditSummary {
edit_count: 3,
last_edited_at: Some(Utc::now()),
last_edited_by: Some(Uuid::now_v7()),
};
let json = serde_json::to_value(&summary).unwrap();
assert_eq!(json["edit_count"], 3);
}
#[test]
fn test_edit_summary_no_edits() {
let summary = EditSummary {
edit_count: 0,
last_edited_at: None,
last_edited_by: None,
};
let json = serde_json::to_value(&summary).unwrap();
assert_eq!(json["edit_count"], 0);
assert!(json["last_edited_at"].is_null());
}
}
+163
View File
@@ -0,0 +1,163 @@
//! Rich embed on a message — new table `message_embed` + `message_embed_field`.
//!
//! Discord-style embeds: link previews, rich cards with title/description/
//! thumbnail/image/footer/fields. Generated by the server (link preview)
//! or sent by bots/webhooks.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A single embed attached to a message.
/// One message can have multiple embeds (e.g. link preview + bot embed).
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageEmbed {
pub id: Uuid,
pub message_id: Uuid,
/// "link" | "article" | "image" | "video" | "rich"
pub embed_type: String,
pub title: Option<String>,
pub description: Option<String>,
pub url: Option<String>,
/// Embed accent color as integer (Discord format: 0xRRGGBB).
pub color: Option<i32>,
// Media
/// Main image URL.
pub image_url: Option<String>,
pub image_width: Option<i32>,
pub image_height: Option<i32>,
/// Small thumbnail URL.
pub thumbnail_url: Option<String>,
pub thumbnail_width: Option<i32>,
pub thumbnail_height: Option<i32>,
/// Video URL (for video embeds).
pub video_url: Option<String>,
pub video_width: Option<i32>,
pub video_height: Option<i32>,
// Footer
pub author_name: Option<String>,
pub author_url: Option<String>,
pub author_icon_url: Option<String>,
pub footer_text: Option<String>,
pub footer_icon_url: Option<String>,
/// Provider name (e.g. "YouTube", "GitHub").
pub provider_name: Option<String>,
pub provider_url: Option<String>,
pub created_at: DateTime<Utc>,
}
/// A key-value field within an embed (Discord-style field rows).
/// Stored in a separate table for flexibility.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageEmbedField {
pub id: Uuid,
pub embed_id: Uuid,
pub name: String,
pub value: String,
/// Whether this field should display inline (side by side with other inline fields).
pub inline: bool,
pub position: i32,
}
/// An embed with its fields resolved in a single structure.
///
/// Convenience type returned by read APIs, joining embed and fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedDetail {
#[serde(flatten)]
pub embed: MessageEmbed,
pub fields: Vec<MessageEmbedField>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_embed_serialize() {
let embed = MessageEmbed {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
embed_type: "link".into(),
title: Some("Example".into()),
description: Some("A link preview".into()),
url: Some("https://example.com".into()),
color: Some(0x00FF00),
image_url: None,
image_width: None,
image_height: None,
thumbnail_url: None,
thumbnail_width: None,
thumbnail_height: None,
video_url: None,
video_width: None,
video_height: None,
author_name: None,
author_url: None,
author_icon_url: None,
footer_text: None,
footer_icon_url: None,
provider_name: Some("GitHub".into()),
provider_url: None,
created_at: Utc::now(),
};
let json = serde_json::to_value(&embed).unwrap();
assert_eq!(json["embed_type"], "link");
assert_eq!(json["title"], "Example");
}
#[test]
fn test_embed_detail_serialize() {
let embed = MessageEmbed {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
embed_type: "rich".into(),
title: Some("Article".into()),
description: None,
url: None,
color: None,
image_url: None,
image_width: None,
image_height: None,
thumbnail_url: None,
thumbnail_width: None,
thumbnail_height: None,
video_url: None,
video_width: None,
video_height: None,
author_name: None,
author_url: None,
author_icon_url: None,
footer_text: None,
footer_icon_url: None,
provider_name: None,
provider_url: None,
created_at: Utc::now(),
};
let field = MessageEmbedField {
id: Uuid::now_v7(),
embed_id: embed.id,
name: "Key".into(),
value: "Value".into(),
inline: true,
position: 0,
};
let detail = EmbedDetail {
embed,
fields: vec![field],
};
let json = serde_json::to_value(&detail).unwrap();
assert_eq!(json["embed_type"], "rich");
assert!(json["fields"].is_array());
}
}
+48
View File
@@ -0,0 +1,48 @@
//! Message forwarding trail — maps to `message_forward` table.
//!
//! When a user forwards a message from one channel to another, this table
//! records the provenance. The forwarded message is a new message with a
//! copy of the original body, linked back via `source_message_id`.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A forwarded message linking the copy to the original.
///
/// Maps to the `message_forward` table. Records provenance when a user
/// forwards a message from one channel to another.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageForward {
pub id: Uuid,
/// The new (forwarded) message.
pub message_id: Uuid,
/// The original message being forwarded.
pub source_message_id: Uuid,
/// The channel the original message came from.
pub source_channel_id: Uuid,
/// Who forwarded the message.
pub forwarded_by: Uuid,
pub created_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_forward_serialize() {
let f = MessageForward {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
source_message_id: Uuid::now_v7(),
source_channel_id: Uuid::now_v7(),
forwarded_by: Uuid::now_v7(),
created_at: Utc::now(),
};
let json = serde_json::to_value(&f).unwrap();
assert!(json["source_message_id"].is_string());
assert!(json["source_channel_id"].is_string());
}
}
+123
View File
@@ -0,0 +1,123 @@
//! @mention tracking — maps to `message_mention` table.
//!
//! Parsed from message body on send. Used for notification and
//! "mentions" feed.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// An @mention of a user in a message.
///
/// Maps to the `message_mention` table. Parsed from message body on send
/// and used for notifications and the mentions feed.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageMention {
pub id: Uuid,
pub message_id: Uuid,
pub channel_id: Uuid,
pub mentioned_user_id: Uuid,
pub mentioned_by: Uuid,
/// When the mentioned user read the notification.
pub read_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
}
/// Parse @username mentions from a message body.
/// Returns unique usernames (without the `@` prefix).
///
/// Matches `@word` where word is `[a-zA-Z0-9_-]+`, min 2 chars.
/// Ignores `@@` (escaped) and mentions inside code spans/blocks.
pub fn parse_mentions(body: &str) -> Vec<String> {
let mut seen = std::collections::HashSet::new();
let mut mentions = Vec::new();
// Simple state machine: skip content inside backtick spans.
let mut in_code = false;
let bytes = body.as_bytes();
let mut i = 0;
while i < bytes.len() {
if bytes[i] == b'`' {
in_code = !in_code;
i += 1;
continue;
}
if !in_code && bytes[i] == b'@' {
// Skip escaped @@
if i + 1 < bytes.len() && bytes[i + 1] == b'@' {
i += 2;
continue;
}
let start = i + 1;
let mut end = start;
while end < bytes.len()
&& (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_' || bytes[end] == b'-')
{
end += 1;
}
let len = end - start;
if len >= 2 {
let name = body[start..end].to_lowercase();
if seen.insert(name.clone()) {
mentions.push(name);
}
}
i = end;
} else {
i += 1;
}
}
mentions
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_mentions_basic() {
let m = parse_mentions("hey @alice, cc @bob");
assert_eq!(m, vec!["alice".to_string(), "bob".to_string()]);
}
#[test]
fn test_parse_mentions_dedup() {
let m = parse_mentions("@alice said hi to @alice");
assert_eq!(m, vec!["alice".to_string()]);
}
#[test]
fn test_parse_mentions_case_insensitive() {
let m = parse_mentions("@Alice and @ALICE");
assert_eq!(m, vec!["alice".to_string()]);
}
#[test]
fn test_parse_mentions_skip_code_span() {
let m = parse_mentions("use `@here` to notify @alice");
assert_eq!(m, vec!["alice".to_string()]);
}
#[test]
fn test_parse_mentions_skip_escaped() {
let m = parse_mentions("@@alice is not a mention, but @bob is");
assert_eq!(m, vec!["bob".to_string()]);
}
#[test]
fn test_parse_mentions_short_name_ignored() {
let m = parse_mentions("@a is too short, @ab is ok");
assert_eq!(m, vec!["ab".to_string()]);
}
#[test]
fn test_parse_mentions_empty() {
assert!(parse_mentions("no mentions here").is_empty());
assert!(parse_mentions("").is_empty());
}
}
+97
View File
@@ -0,0 +1,97 @@
//! Notification delivery tracking — maps to `message_notification` table.
//!
//! Records when a message triggers a notification for a user (mention, reply,
//! thread activity, etc.) and tracks the delivery/read lifecycle.
//! Separate from `message_mention` which only covers @mentions.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A notification triggered for a user by a message.
///
/// Maps to the `message_notification` table. Records when a message triggers
/// a notification (mention, reply, thread activity) and tracks delivery/read.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageNotification {
pub id: Uuid,
pub message_id: Uuid,
pub channel_id: Uuid,
pub user_id: Uuid,
/// "mention" | "reply" | "thread" | "watch"
pub reason: String,
/// "pending" | "delivered" | "read" | "dismissed"
pub status: String,
/// Channel of delivery: "push" | "email" | "in_app"
pub delivery_channel: Option<String>,
pub delivered_at: Option<DateTime<Utc>>,
pub read_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum NotificationReason {
#[default]
Mention,
Reply,
Thread,
Watch,
}
impl NotificationReason {
pub fn as_str(&self) -> &'static str {
match self {
Self::Mention => "mention",
Self::Reply => "reply",
Self::Thread => "thread",
Self::Watch => "watch",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum NotificationStatus {
#[default]
Pending,
Delivered,
Read,
Dismissed,
}
impl NotificationStatus {
pub fn as_str(&self) -> &'static str {
match self {
Self::Pending => "pending",
Self::Delivered => "delivered",
Self::Read => "read",
Self::Dismissed => "dismissed",
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_notification_serialize() {
let n = MessageNotification {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
reason: NotificationReason::Mention.as_str().into(),
status: NotificationStatus::Pending.as_str().into(),
delivery_channel: Some("push".into()),
delivered_at: None,
read_at: None,
created_at: Utc::now(),
};
let json = serde_json::to_value(&n).unwrap();
assert_eq!(json["reason"], "mention");
assert_eq!(json["status"], "pending");
}
}
+60
View File
@@ -0,0 +1,60 @@
//! Pinned message management — maps to `message_pin` table.
//!
//! A channel can have multiple pinned messages with explicit ordering.
//! Unlike the `message.pinned` boolean (which just marks the row),
//! this table tracks *who* pinned, *when*, and the display *position*.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// One pinned message entry. Ordered by `position` ascending.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessagePin {
pub id: Uuid,
pub channel_id: Uuid,
pub message_id: Uuid,
/// Who pinned this message.
pub pinned_by: Uuid,
/// Display position in the pinned list (0 = top).
pub position: i32,
pub created_at: DateTime<Utc>,
}
/// Summary view returned by list-pins APIs (joined with message content).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PinDetail {
#[serde(flatten)]
pub pin: MessagePin,
pub message_body: String,
pub message_author_id: Uuid,
pub message_created_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pin_detail_serialize() {
let pin = MessagePin {
id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
pinned_by: Uuid::now_v7(),
position: 0,
created_at: Utc::now(),
};
let detail = PinDetail {
pin,
message_body: "important announcement".to_string(),
message_author_id: Uuid::now_v7(),
message_created_at: Utc::now(),
};
let json = serde_json::to_value(&detail).unwrap();
assert_eq!(json["position"], 0);
assert_eq!(json["message_body"], "important announcement");
}
}
+170
View File
@@ -0,0 +1,170 @@
//! Poll on a message — new tables `message_poll`, `message_poll_option`,
//! `message_poll_vote`.
//!
//! Discord-style polls: attached to a message, with multiple options,
//! optional multi-vote, and an expiry time.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// The poll itself (one per message, optional).
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessagePoll {
pub id: Uuid,
pub message_id: Uuid,
/// The question displayed to voters.
pub question: String,
/// Whether users can select multiple options.
pub allow_multiselect: bool,
/// Maximum number of options a user can select (NULL = unlimited when multiselect).
pub max_selections: Option<i32>,
/// When voting closes (NULL = no expiry).
pub expires_at: Option<DateTime<Utc>>,
/// Total number of votes cast (denormalized for fast reads).
pub total_votes: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// A single selectable option within a poll.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessagePollOption {
pub id: Uuid,
pub poll_id: Uuid,
/// Display text for this option.
pub text: String,
/// Optional emoji prefix (Discord-style).
pub emoji: Option<String>,
/// Number of votes this option received (denormalized).
pub vote_count: i64,
/// Display order.
pub position: i32,
}
/// A single vote cast by a user.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessagePollVote {
pub id: Uuid,
pub poll_id: Uuid,
pub option_id: Uuid,
pub user_id: Uuid,
pub created_at: DateTime<Utc>,
}
/// Aggregated poll results for API responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PollResult {
pub poll: MessagePoll,
pub options: Vec<PollOptionResult>,
/// Which options the current user voted for (empty if not voted).
pub my_votes: Vec<Uuid>,
/// Whether the poll has expired.
pub is_expired: bool,
}
/// Option with its vote count and percentage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PollOptionResult {
#[serde(flatten)]
pub option: MessagePollOption,
/// Percentage of total votes (0.0100.0), rounded to 1 decimal.
pub percentage: f64,
}
impl PollResult {
/// Compute percentages from total_votes.
pub fn from_poll(
poll: MessagePoll,
options: Vec<MessagePollOption>,
my_votes: Vec<Uuid>,
) -> Self {
let total = poll.total_votes.max(1) as f64;
let now = Utc::now();
let is_expired = poll.expires_at.is_some_and(|exp| now >= exp);
let options = options
.into_iter()
.map(|opt| {
let pct = if poll.total_votes > 0 {
(opt.vote_count as f64 / total * 100.0 * 10.0).round() / 10.0
} else {
0.0
};
PollOptionResult {
option: opt,
percentage: pct,
}
})
.collect();
Self {
poll,
options,
my_votes,
is_expired,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_poll_result_percentages() {
let poll = MessagePoll {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
question: "Best language?".to_string(),
allow_multiselect: false,
max_selections: None,
expires_at: None,
total_votes: 10,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let options = vec![
MessagePollOption {
id: Uuid::now_v7(),
poll_id: poll.id,
text: "Rust".to_string(),
emoji: None,
vote_count: 7,
position: 0,
},
MessagePollOption {
id: Uuid::now_v7(),
poll_id: poll.id,
text: "Go".to_string(),
emoji: None,
vote_count: 3,
position: 1,
},
];
let result = PollResult::from_poll(poll, options, vec![]);
assert!(!result.is_expired);
assert_eq!(result.options[0].percentage, 70.0);
assert_eq!(result.options[1].percentage, 30.0);
}
#[test]
fn test_poll_result_zero_votes() {
let poll = MessagePoll {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
question: "Empty poll".to_string(),
allow_multiselect: false,
max_selections: None,
expires_at: None,
total_votes: 0,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let result = PollResult::from_poll(poll, vec![], vec![]);
assert!(result.options.is_empty());
}
}
+66
View File
@@ -0,0 +1,66 @@
//! Emoji reaction on a message — maps to `message_reaction` table.
//!
//! One row per (user, message, emoji). `content` stores the emoji string
//! (Unicode emoji or custom `:name:id` format like Discord).
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A single emoji reaction on a message by a user.
///
/// Maps to the `message_reaction` table. One row per (user, message, emoji).
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageReaction {
pub id: Uuid,
pub message_id: Uuid,
pub channel_id: Uuid,
pub user_id: Uuid,
/// Emoji string: "👍", "🎉", or custom ":appks:01909a..."
pub content: String,
pub created_at: DateTime<Utc>,
}
/// Aggregated reaction count for a single emoji on a message.
/// Returned in [`MessageDetail::reactions`] and list APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReactionCount {
pub content: String,
pub count: i64,
/// Whether the current user reacted with this emoji.
pub me: bool,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_reaction_serialize() {
let reaction = MessageReaction {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
content: "👍".into(),
created_at: Utc::now(),
};
let json = serde_json::to_value(&reaction).unwrap();
assert_eq!(json["content"], "👍");
}
#[test]
fn test_reaction_count_serialize() {
let count = ReactionCount {
content: "🎉".into(),
count: 3,
me: true,
};
let json = serde_json::to_value(&count).unwrap();
assert_eq!(json["content"], "🎉");
assert_eq!(json["count"], 3);
assert_eq!(json["me"], true);
}
}
+92
View File
@@ -0,0 +1,92 @@
//! Per-user read state — maps to `message_read_state` table.
//!
//! Tracks the last message each user has read in each channel.
//! Used for unread badges, "mark as read", and notification suppression.
//! One row per (channel_id, user_id), upserted on each read.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A user's read progress in one channel.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageReadState {
pub id: Uuid,
pub channel_id: Uuid,
pub user_id: Uuid,
/// The last message id the user has read (cursor).
pub last_read_message_id: Option<Uuid>,
/// When the user last opened / scrolled through this channel.
pub last_read_at: Option<DateTime<Utc>>,
/// Total unread message count (denormalized for fast badge display).
pub unread_count: i64,
/// Total unread @mentions for this channel (denormalized).
pub unread_mentions: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// Summary of a user's read state for the client-side channel list.
///
/// Includes unread badge count and mention count for a single channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadStateSummary {
pub channel_id: Uuid,
pub unread_count: i64,
pub unread_mentions: i64,
pub has_unread: bool,
}
impl From<MessageReadState> for ReadStateSummary {
fn from(s: MessageReadState) -> Self {
Self {
channel_id: s.channel_id,
unread_count: s.unread_count,
unread_mentions: s.unread_mentions,
has_unread: s.unread_count > 0,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_read_state_summary_conversion() {
let state = MessageReadState {
id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
last_read_message_id: None,
last_read_at: None,
unread_count: 5,
unread_mentions: 2,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let summary: ReadStateSummary = state.into();
assert!(summary.has_unread);
assert_eq!(summary.unread_count, 5);
assert_eq!(summary.unread_mentions, 2);
}
#[test]
fn test_read_state_summary_no_unread() {
let state = MessageReadState {
id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
last_read_message_id: None,
last_read_at: None,
unread_count: 0,
unread_mentions: 0,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let summary: ReadStateSummary = state.into();
assert!(!summary.has_unread);
}
}
+82
View File
@@ -0,0 +1,82 @@
//! Scheduled messages — maps to `message_scheduled` table.
//!
//! A message that the user has composed but wants to send at a future time.
//! A background job picks up rows where `scheduled_at <= now()` and dispatches
//! them through the normal send path.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A message scheduled to be sent at a future time.
///
/// Maps to the `message_scheduled` table. A background job picks up rows
/// where `scheduled_at <= now()` and dispatches them through the normal send path.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageScheduled {
pub id: Uuid,
pub channel_id: Uuid,
pub author_id: Uuid,
pub thread_id: Option<Uuid>,
pub reply_to_message_id: Option<Uuid>,
pub body: String,
pub metadata: Option<serde_json::Value>,
/// When the message should be sent.
pub scheduled_at: DateTime<Utc>,
/// "pending" | "sent" | "cancelled" | "failed"
pub status: String,
/// Set after the message is dispatched; points to the sent `message.id`.
pub sent_message_id: Option<Uuid>,
/// Error message if dispatch failed.
pub error: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ScheduledStatus {
#[default]
Pending,
Sent,
Cancelled,
Failed,
}
impl ScheduledStatus {
pub fn as_str(&self) -> &'static str {
match self {
Self::Pending => "pending",
Self::Sent => "sent",
Self::Cancelled => "cancelled",
Self::Failed => "failed",
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_scheduled_serialize() {
let s = MessageScheduled {
id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
author_id: Uuid::now_v7(),
thread_id: None,
reply_to_message_id: None,
body: "Good morning everyone!".into(),
metadata: None,
scheduled_at: Utc::now(),
status: ScheduledStatus::Pending.as_str().into(),
sent_message_id: None,
error: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let json = serde_json::to_value(&s).unwrap();
assert_eq!(json["status"], "pending");
}
}
+56
View File
@@ -0,0 +1,56 @@
//! Sticker attachment on a message — maps to `message_sticker` table.
//!
//! Large sticker images sent in messages, distinct from emoji reactions.
//! Stickers are either workspace-level (custom, defined in appks) or
//! system-level (built-in sticker packs).
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A sticker attached to a message.
///
/// Maps to the `message_sticker` table. Stickers are larger than emoji
/// and can be workspace-level (custom) or system-level (built-in packs).
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageSticker {
pub id: Uuid,
pub message_id: Uuid,
/// References a sticker defined in appks (workspace or system sticker).
pub sticker_id: Uuid,
/// Sticker name at time of send (snapshot for history).
pub name: String,
/// Image URL (snapshot).
pub image_url: String,
/// "png" | "apng" | "lottie"
pub format_type: String,
/// Pack name (e.g. "Wumpus" or workspace name).
pub pack_name: Option<String>,
/// Search tags for discovery.
pub tags: Option<String>,
pub created_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sticker_serialize() {
let s = MessageSticker {
id: Uuid::now_v7(),
message_id: Uuid::now_v7(),
sticker_id: Uuid::now_v7(),
name: "Hype!".into(),
image_url: "https://cdn.example.com/stickers/hype.png".into(),
format_type: "png".into(),
pack_name: Some("Wumpus".into()),
tags: Some("excited,hype".into()),
created_at: Utc::now(),
};
let json = serde_json::to_value(&s).unwrap();
assert_eq!(json["name"], "Hype!");
assert_eq!(json["format_type"], "png");
}
}
+60
View File
@@ -0,0 +1,60 @@
//! Thread metadata — maps to `message_thread` table.
//!
//! A thread is anchored by a root message. Reply messages in the same thread
//! set `message.thread_id = message_thread.id`.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A message thread anchored by a root message.
///
/// Maps to the `message_thread` table. Reply messages set
/// `message.thread_id = message_thread.id` to join the thread.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageThread {
pub id: Uuid,
pub channel_id: Uuid,
/// The first message that started this thread.
pub root_message_id: Uuid,
pub created_by: Uuid,
pub replies_count: i64,
pub participants_count: i64,
pub last_reply_message_id: Option<Uuid>,
pub last_reply_at: Option<DateTime<Utc>>,
/// Forum-style: mark thread as resolved / answered.
pub resolved: bool,
pub resolved_by: Option<Uuid>,
pub resolved_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_thread_serialize() {
let thread = MessageThread {
id: Uuid::now_v7(),
channel_id: Uuid::now_v7(),
root_message_id: Uuid::now_v7(),
created_by: Uuid::now_v7(),
replies_count: 5,
participants_count: 3,
last_reply_message_id: None,
last_reply_at: None,
resolved: false,
resolved_by: None,
resolved_at: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let json = serde_json::to_value(&thread).unwrap();
assert_eq!(json["replies_count"], 5);
assert_eq!(json["resolved"], false);
assert!(json["id"].is_string());
}
}
+71
View File
@@ -0,0 +1,71 @@
//! Thread participant membership — maps to `message_thread_participant` table.
//!
//! Tracks which users are part of a thread. A user becomes a participant when
//! they reply in a thread, get @mentioned, or are explicitly added.
//! Without this table, `message_thread.participants_count` is un-verifiable.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A user participating in a thread.
///
/// Maps to the `message_thread_participant` table. Users become participants
/// when they reply, get @mentioned, or are explicitly added.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct MessageThreadParticipant {
pub id: Uuid,
pub thread_id: Uuid,
pub user_id: Uuid,
/// How the user joined the thread.
pub joined_reason: Option<String>,
pub last_read_message_id: Option<Uuid>,
pub last_read_at: Option<DateTime<Utc>>,
pub joined_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum JoinReason {
/// User sent a reply in the thread.
#[default]
Reply,
/// User was @mentioned in a thread message.
Mentioned,
/// User was explicitly added by another participant.
Added,
/// User joined the thread themselves.
Joined,
}
impl JoinReason {
pub fn as_str(&self) -> &'static str {
match self {
Self::Reply => "reply",
Self::Mentioned => "mentioned",
Self::Added => "added",
Self::Joined => "joined",
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_participant_serialize() {
let p = MessageThreadParticipant {
id: Uuid::now_v7(),
thread_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
joined_reason: Some(JoinReason::Reply.as_str().into()),
last_read_message_id: None,
last_read_at: None,
joined_at: Utc::now(),
};
let json = serde_json::to_value(&p).unwrap();
assert_eq!(json["joined_reason"], "reply");
}
}
+43
View File
@@ -0,0 +1,43 @@
pub mod message;
pub mod message_article;
pub mod message_attachment;
pub mod message_bookmark;
pub mod message_component;
pub mod message_draft;
pub mod message_edit;
pub mod message_embed;
pub mod message_forward;
pub mod message_mention;
pub mod message_notification;
pub mod message_pin;
pub mod message_poll;
pub mod message_reaction;
pub mod message_read_state;
pub mod message_scheduled;
pub mod message_sticker;
pub mod message_thread;
pub mod message_thread_participant;
pub use message::{Message, MessageDetail, MessageType};
pub use message_article::{
ArticleCard, ArticleSort, CreateArticleInput, MessageArticle, UpdateArticleInput,
};
pub use message_attachment::{AttachmentSummary, MessageAttachment};
pub use message_bookmark::MessageBookmark;
pub use message_component::{ComponentType, MessageComponent};
pub use message_draft::{DraftUpsertInput, MessageDraft};
pub use message_edit::{EditSummary, MessageEdit};
pub use message_embed::{EmbedDetail, MessageEmbed, MessageEmbedField};
pub use message_forward::MessageForward;
pub use message_mention::MessageMention;
pub use message_notification::{MessageNotification, NotificationReason, NotificationStatus};
pub use message_pin::{MessagePin, PinDetail};
pub use message_poll::{
MessagePoll, MessagePollOption, MessagePollVote, PollOptionResult, PollResult,
};
pub use message_reaction::{MessageReaction, ReactionCount};
pub use message_read_state::{MessageReadState, ReadStateSummary};
pub use message_scheduled::{MessageScheduled, ScheduledStatus};
pub use message_sticker::MessageSticker;
pub use message_thread::MessageThread;
pub use message_thread_participant::{JoinReason, MessageThreadParticipant};
+1 -1
View File
@@ -1 +1 @@
include!(concat!(env!("OUT_DIR"), "/appks.core.v1.rs")); include!(concat!(env!("OUT_DIR"), "/appks.core.v1.rs"));
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod core; pub mod core;
pub mod im; pub mod im;
+263
View File
@@ -0,0 +1,263 @@
//! Forum article CRUD operations on `MessageRepo`.
//!
//! Articles extend regular messages with forum-specific metadata (title, cover,
//! view/like stats, tags). Rendered as waterfall cards in forum channels.
use chrono::Utc;
use sqlx::Row;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message::AuthorInfo;
use crate::models::message_article::{ArticleCard, ArticleSort, MessageArticle};
use super::message_repo::MessageRepo;
use super::pagination::{CursorPage, clamp_limit};
impl MessageRepo {
/// Create an article record linked to an existing message.
#[allow(clippy::too_many_arguments)]
pub async fn create_article(
&self,
message_id: Uuid,
title: &str,
summary: Option<&str>,
cover_url: Option<&str>,
cover_width: Option<i32>,
cover_height: Option<i32>,
cover_color: Option<&str>,
tags: Option<&serde_json::Value>,
) -> ImksResult<MessageArticle> {
let id = Uuid::now_v7();
let now = Utc::now();
sqlx::query_as::<_, MessageArticle>(
r#"
INSERT INTO message_article (
id, message_id, title, summary, cover_url, cover_width, cover_height,
cover_color, tags, view_count, like_count, bookmark_count, reply_count,
is_pinned_to_top, is_answered, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 0, 0, 0, 0, FALSE, FALSE, $10, $10)
RETURNING *
"#,
)
.bind(id)
.bind(message_id)
.bind(title)
.bind(summary)
.bind(cover_url)
.bind(cover_width)
.bind(cover_height)
.bind(cover_color)
.bind(tags)
.bind(now)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Update an existing article's metadata. Does NOT update the message body.
pub async fn update_article(
&self,
message_id: Uuid,
title: Option<&str>,
summary: Option<&str>,
cover_url: Option<&str>,
cover_color: Option<&str>,
tags: Option<&serde_json::Value>,
) -> ImksResult<Option<MessageArticle>> {
let now = Utc::now();
sqlx::query_as::<_, MessageArticle>(
r#"
UPDATE message_article
SET title = COALESCE($1, title),
summary = COALESCE($2, summary),
cover_url = COALESCE($3, cover_url),
cover_color = COALESCE($4, cover_color),
tags = COALESCE($5, tags),
updated_at = $6
WHERE message_id = $7
RETURNING *
"#,
)
.bind(title)
.bind(summary)
.bind(cover_url)
.bind(cover_color)
.bind(tags)
.bind(now)
.bind(message_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
/// Get an article by its message_id.
pub async fn get_article(&self, message_id: Uuid) -> ImksResult<Option<MessageArticle>> {
sqlx::query_as::<_, MessageArticle>("SELECT * FROM message_article WHERE message_id = $1")
.bind(message_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
/// List articles in a forum channel with id-based cursor pagination.
pub async fn list_articles(
&self,
channel_id: Uuid,
sort: ArticleSort,
before: Option<(i64, Uuid)>,
limit: Option<i64>,
) -> ImksResult<CursorPage<ArticleCard>> {
let effective_limit = clamp_limit(limit);
let fetch_limit = effective_limit + 1;
let cursor_id = before.map(|(_, id)| id);
let order_by = match sort {
ArticleSort::LatestActivity => "a.last_reply_at DESC NULLS LAST, m.id DESC",
ArticleSort::Newest => "m.id DESC",
ArticleSort::MostViewed => "a.view_count DESC, m.id DESC",
ArticleSort::MostLiked => "a.like_count DESC, m.id DESC",
ArticleSort::PinnedFirst => {
"a.is_pinned_to_top DESC, a.last_reply_at DESC NULLS LAST, m.id DESC"
}
};
let query = if cursor_id.is_some() {
format!(
r#"
SELECT a.*, m.author_id,
(
SELECT att.url
FROM message_attachment att
WHERE att.message_id = a.message_id
AND att.content_type LIKE 'image/%'
ORDER BY att.created_at
LIMIT 1
) AS first_image_url
FROM message_article a
JOIN message m ON m.id = a.message_id
WHERE m.channel_id = $1
AND m.deleted_at IS NULL
AND m.id < $2
ORDER BY {order_by}
LIMIT $3
"#,
)
} else {
format!(
r#"
SELECT a.*, m.author_id,
(
SELECT att.url
FROM message_attachment att
WHERE att.message_id = a.message_id
AND att.content_type LIKE 'image/%'
ORDER BY att.created_at
LIMIT 1
) AS first_image_url
FROM message_article a
JOIN message m ON m.id = a.message_id
WHERE m.channel_id = $1
AND m.deleted_at IS NULL
ORDER BY {order_by}
LIMIT $2
"#,
)
};
let rows = if let Some(cursor) = cursor_id {
sqlx::query(sqlx::AssertSqlSafe(query.as_str()))
.bind(channel_id)
.bind(cursor)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
} else {
sqlx::query(sqlx::AssertSqlSafe(query.as_str()))
.bind(channel_id)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
};
// Convert raw rows to MessageArticle
let articles: Vec<MessageArticle> = rows
.iter()
.map(|r| MessageArticle {
id: r.get("id"),
message_id: r.get("message_id"),
title: r.get("title"),
summary: r.get("summary"),
cover_url: r.get("cover_url"),
cover_width: r.get("cover_width"),
cover_height: r.get("cover_height"),
cover_color: r.get("cover_color"),
tags: r.get("tags"),
view_count: r.get("view_count"),
like_count: r.get("like_count"),
bookmark_count: r.get("bookmark_count"),
reply_count: r.get("reply_count"),
last_reply_message_id: r.get("last_reply_message_id"),
last_reply_at: r.get("last_reply_at"),
last_reply_user_id: r.get("last_reply_user_id"),
is_pinned_to_top: r.get("is_pinned_to_top"),
is_answered: r.get("is_answered"),
answered_by: r.get("answered_by"),
answered_at: r.get("answered_at"),
created_at: r.get("created_at"),
updated_at: r.get("updated_at"),
})
.collect();
let has_more = articles.len() > effective_limit as usize;
let items: Vec<MessageArticle> = articles
.into_iter()
.take(effective_limit as usize)
.collect();
let next_cursor = if has_more {
items.last().map(|a| a.message_id)
} else {
None
};
let cards: Vec<ArticleCard> = items
.into_iter()
.zip(rows.iter())
.map(|(article, row)| {
let author_id: Uuid = row.get("author_id");
ArticleCard {
article,
author: AuthorInfo {
id: author_id,
username: author_id.to_string(),
display_name: None,
avatar_url: None,
is_bot: false,
},
tag_names: Vec::new(),
first_image_url: row.get("first_image_url"),
bookmarked: false,
liked: false,
}
})
.collect();
Ok(CursorPage {
items: cards,
next_cursor,
has_more,
})
}
/// Increment the view count for an article.
pub async fn increment_article_view(&self, message_id: Uuid) -> ImksResult<()> {
sqlx::query("UPDATE message_article SET view_count = view_count + 1 WHERE message_id = $1")
.bind(message_id)
.execute(self.pool())
.await?;
Ok(())
}
}
+83
View File
@@ -0,0 +1,83 @@
//! Attachment CRUD operations on `MessageRepo`.
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_attachment::{AttachmentSummary, MessageAttachment};
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Create a single attachment record.
#[allow(clippy::too_many_arguments)]
pub async fn create_attachment(
&self,
message_id: Uuid,
filename: &str,
content_type: Option<&str>,
size: i64,
url: &str,
storage_key: Option<&str>,
width: Option<i32>,
height: Option<i32>,
spoiler: bool,
) -> ImksResult<MessageAttachment> {
let id = Uuid::now_v7();
sqlx::query_as::<_, MessageAttachment>(
r#"
INSERT INTO message_attachment (
id, message_id, filename, content_type, size, url,
storage_key, width, height, spoiler
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING *
"#,
)
.bind(id)
.bind(message_id)
.bind(filename)
.bind(content_type)
.bind(size)
.bind(url)
.bind(storage_key)
.bind(width)
.bind(height)
.bind(spoiler)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Get all attachments for a message.
pub async fn get_attachments(&self, message_id: Uuid) -> ImksResult<Vec<MessageAttachment>> {
sqlx::query_as::<_, MessageAttachment>(
"SELECT * FROM message_attachment WHERE message_id = $1 ORDER BY created_at",
)
.bind(message_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
/// Get lightweight attachment summaries (no URLs/storage keys).
pub async fn get_attachment_summaries(
&self,
message_id: Uuid,
) -> ImksResult<Vec<AttachmentSummary>> {
let attachments = self.get_attachments(message_id).await?;
Ok(attachments
.into_iter()
.map(AttachmentSummary::from)
.collect())
}
/// Delete a single attachment.
pub async fn delete_attachment(&self, attachment_id: Uuid) -> ImksResult<bool> {
let result = sqlx::query("DELETE FROM message_attachment WHERE id = $1")
.bind(attachment_id)
.execute(self.pool())
.await?;
Ok(result.rows_affected() > 0)
}
}
+112
View File
@@ -0,0 +1,112 @@
//! Bookmark CRUD operations on `MessageRepo`.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_bookmark::MessageBookmark;
use super::message_repo::MessageRepo;
use super::pagination::{CursorPage, clamp_limit};
impl MessageRepo {
/// Add a bookmark for a message.
pub async fn add_bookmark(
&self,
message_id: Uuid,
channel_id: Uuid,
user_id: Uuid,
note: Option<&str>,
) -> ImksResult<MessageBookmark> {
let id = Uuid::now_v7();
let now = Utc::now();
sqlx::query_as::<_, MessageBookmark>(
r#"
INSERT INTO message_bookmark (id, message_id, channel_id, user_id, note, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $6)
ON CONFLICT (user_id, message_id) DO UPDATE SET note = EXCLUDED.note, updated_at = EXCLUDED.updated_at
RETURNING *
"#,
)
.bind(id)
.bind(message_id)
.bind(channel_id)
.bind(user_id)
.bind(note)
.bind(now)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Remove a bookmark.
pub async fn remove_bookmark(&self, message_id: Uuid, user_id: Uuid) -> ImksResult<bool> {
let result =
sqlx::query("DELETE FROM message_bookmark WHERE message_id = $1 AND user_id = $2")
.bind(message_id)
.bind(user_id)
.execute(self.pool())
.await?;
Ok(result.rows_affected() > 0)
}
/// Check if a user has bookmarked a message.
pub async fn is_bookmarked(&self, message_id: Uuid, user_id: Uuid) -> ImksResult<bool> {
let exists: Option<Uuid> = sqlx::query_scalar(
"SELECT id FROM message_bookmark WHERE message_id = $1 AND user_id = $2",
)
.bind(message_id)
.bind(user_id)
.fetch_optional(self.pool())
.await?;
Ok(exists.is_some())
}
/// List a user's bookmarks with cursor-based pagination (newest first).
pub async fn list_bookmarks(
&self,
user_id: Uuid,
before: Option<Uuid>,
limit: Option<i64>,
) -> ImksResult<CursorPage<MessageBookmark>> {
let effective_limit = clamp_limit(limit);
let fetch_limit = effective_limit + 1;
let rows = match before {
Some(cursor) => {
sqlx::query_as::<_, MessageBookmark>(
r#"
SELECT * FROM message_bookmark
WHERE user_id = $1 AND id < $2
ORDER BY id DESC
LIMIT $3
"#,
)
.bind(user_id)
.bind(cursor)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
None => {
sqlx::query_as::<_, MessageBookmark>(
r#"
SELECT * FROM message_bookmark
WHERE user_id = $1
ORDER BY id DESC
LIMIT $2
"#,
)
.bind(user_id)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
};
Ok(CursorPage::from_raw(rows, effective_limit, |b| b.id))
}
}
+88
View File
@@ -0,0 +1,88 @@
//! Interactive component CRUD operations on `MessageRepo`.
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_component::MessageComponent;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Create an interactive component (button/select menu) on a message.
#[allow(clippy::too_many_arguments)]
pub async fn create_component(
&self,
message_id: Uuid,
component_type: &str,
custom_id: &str,
label: Option<&str>,
emoji: Option<&str>,
style: Option<&str>,
url: Option<&str>,
disabled: bool,
row: i32,
position: i32,
) -> ImksResult<MessageComponent> {
sqlx::query_as::<_, MessageComponent>(
r#"
INSERT INTO message_component (
id, message_id, row, position, component_type, custom_id,
label, emoji, style, url, disabled
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING *
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(row)
.bind(position)
.bind(component_type)
.bind(custom_id)
.bind(label)
.bind(emoji)
.bind(style)
.bind(url)
.bind(disabled)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Get all components on a message, ordered by layout.
pub async fn get_components(&self, message_id: Uuid) -> ImksResult<Vec<MessageComponent>> {
sqlx::query_as::<_, MessageComponent>(
r#"
SELECT * FROM message_component
WHERE message_id = $1
ORDER BY row, position
"#,
)
.bind(message_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
/// Update a component's label and/or disabled state (e.g. after interaction).
pub async fn update_component(
&self,
component_id: Uuid,
label: Option<&str>,
disabled: bool,
) -> ImksResult<Option<MessageComponent>> {
sqlx::query_as::<_, MessageComponent>(
r#"
UPDATE message_component
SET label = COALESCE($1, label), disabled = $2
WHERE id = $3
RETURNING *
"#,
)
.bind(label)
.bind(disabled)
.bind(component_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
}
+116
View File
@@ -0,0 +1,116 @@
//! Message write operations — insert, update body, soft delete.
//!
//! All mutations use parameterized queries and return the affected row(s).
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message::{Message, new_message_id};
use super::message_repo::MessageRepo;
/// Input payload for creating a new message.
#[derive(Debug, Clone)]
pub struct CreateMessageInput {
/// Target channel UUID.
pub channel_id: Uuid,
/// Author (user) UUID — extracted from JWT `sub` claim.
pub author_id: Uuid,
/// Thread this message belongs to (`None` = top-level).
pub thread_id: Option<Uuid>,
/// Direct reply reference (`None` = not a reply).
pub reply_to_message_id: Option<Uuid>,
/// Discriminator: `"text"`, `"system"`, `"event"`, `"article"`.
pub message_type: String,
/// Plain text or markdown body.
pub body: String,
/// Extensible metadata (flags, locale, etc.).
pub metadata: Option<serde_json::Value>,
/// Whether this is a system/bot-generated message.
pub system: bool,
}
impl MessageRepo {
/// Insert a new message row and return it.
///
/// The message ID is a fresh UUID v7 (time-ordered).
pub async fn create(&self, input: &CreateMessageInput) -> ImksResult<Message> {
let id = new_message_id();
let now = Utc::now();
let row = sqlx::query_as::<_, Message>(
r#"
INSERT INTO message (
id, channel_id, author_id, thread_id, reply_to_message_id,
message_type, body, metadata, pinned, system,
edited_at, deleted_at, created_at, updated_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7, $8, FALSE, $9,
NULL, NULL, $10, $10
)
RETURNING *
"#,
)
.bind(id)
.bind(input.channel_id)
.bind(input.author_id)
.bind(input.thread_id)
.bind(input.reply_to_message_id)
.bind(&input.message_type)
.bind(&input.body)
.bind(&input.metadata)
.bind(input.system)
.bind(now)
.fetch_one(self.pool())
.await?;
Ok(row)
}
/// Update the body of an existing message. Sets `edited_at` and `updated_at`.
///
/// Returns the updated row, or an error if the message is not found or deleted.
pub async fn update_body(&self, message_id: Uuid, new_body: &str) -> ImksResult<Message> {
let now = Utc::now();
let row = sqlx::query_as::<_, Message>(
r#"
UPDATE message
SET body = $1, edited_at = $2, updated_at = $2
WHERE id = $3 AND deleted_at IS NULL
RETURNING *
"#,
)
.bind(new_body)
.bind(now)
.bind(message_id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| crate::ImksError::NotFound(format!("message {message_id}")))?;
Ok(row)
}
/// Soft-delete a message by setting `deleted_at`.
///
/// Returns `Ok(())` even if the message was already deleted.
pub async fn soft_delete(&self, message_id: Uuid) -> ImksResult<()> {
let now = Utc::now();
sqlx::query(
r#"
UPDATE message
SET deleted_at = $1, updated_at = $1
WHERE id = $2 AND deleted_at IS NULL
"#,
)
.bind(now)
.bind(message_id)
.execute(self.pool())
.await?;
Ok(())
}
}
+113
View File
@@ -0,0 +1,113 @@
//! Draft CRUD operations on `MessageRepo`.
//!
//! One draft per (channel, user, thread). Upserted on every keystroke debounce,
//! deleted on send. Thread_id=NULL uses a dedicated conflict target.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_draft::MessageDraft;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Upsert a draft for the given (channel, user, thread) key.
/// Uses NULL-safe conflict handling via COALESCE.
pub async fn upsert_draft(
&self,
channel_id: Uuid,
user_id: Uuid,
thread_id: Option<Uuid>,
body: &str,
reply_to_message_id: Option<Uuid>,
metadata: Option<serde_json::Value>,
) -> ImksResult<MessageDraft> {
let id = Uuid::now_v7();
let now = Utc::now();
let query = if thread_id.is_some() {
r#"
INSERT INTO message_draft (
id, channel_id, user_id, thread_id, reply_to_message_id,
body, metadata, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8)
ON CONFLICT (channel_id, user_id, thread_id) DO UPDATE SET
body = EXCLUDED.body,
reply_to_message_id = EXCLUDED.reply_to_message_id,
metadata = EXCLUDED.metadata,
updated_at = EXCLUDED.updated_at
RETURNING *
"#
} else {
r#"
INSERT INTO message_draft (
id, channel_id, user_id, thread_id, reply_to_message_id,
body, metadata, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8)
ON CONFLICT (channel_id, user_id) WHERE thread_id IS NULL DO UPDATE SET
body = EXCLUDED.body,
reply_to_message_id = EXCLUDED.reply_to_message_id,
metadata = EXCLUDED.metadata,
updated_at = EXCLUDED.updated_at
RETURNING *
"#
};
sqlx::query_as::<_, MessageDraft>(query)
.bind(id)
.bind(channel_id)
.bind(user_id)
.bind(thread_id)
.bind(reply_to_message_id)
.bind(body)
.bind(metadata)
.bind(now)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Get a user's draft for a channel (optionally scoped to a thread).
pub async fn get_draft(
&self,
channel_id: Uuid,
user_id: Uuid,
thread_id: Option<Uuid>,
) -> ImksResult<Option<MessageDraft>> {
sqlx::query_as::<_, MessageDraft>(
r#"
SELECT * FROM message_draft
WHERE channel_id = $1 AND user_id = $2 AND thread_id IS NOT DISTINCT FROM $3
"#,
)
.bind(channel_id)
.bind(user_id)
.bind(thread_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
/// Delete a draft after the message is sent.
pub async fn delete_draft(
&self,
channel_id: Uuid,
user_id: Uuid,
thread_id: Option<Uuid>,
) -> ImksResult<bool> {
let result = sqlx::query(
r#"
DELETE FROM message_draft
WHERE channel_id = $1 AND user_id = $2 AND thread_id IS NOT DISTINCT FROM $3
"#,
)
.bind(channel_id)
.bind(user_id)
.bind(thread_id)
.execute(self.pool())
.await?;
Ok(result.rows_affected() > 0)
}
}
+77
View File
@@ -0,0 +1,77 @@
//! Edit history CRUD operations on `MessageRepo`.
//!
//! Immutable append-only log. Every message body edit creates one row.
use chrono::Utc;
use sqlx::Row;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_edit::{EditSummary, MessageEdit};
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Record an edit to a message's body.
pub async fn record_edit(
&self,
message_id: Uuid,
edited_by: Uuid,
old_body: &str,
new_body: &str,
) -> ImksResult<MessageEdit> {
let id = Uuid::now_v7();
let now = Utc::now();
sqlx::query_as::<_, MessageEdit>(
r#"
INSERT INTO message_edit (id, message_id, edited_by, old_body, new_body, edited_at)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING *
"#,
)
.bind(id)
.bind(message_id)
.bind(edited_by)
.bind(old_body)
.bind(new_body)
.bind(now)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Get the full edit history for a message, oldest first.
pub async fn get_edit_history(&self, message_id: Uuid) -> ImksResult<Vec<MessageEdit>> {
sqlx::query_as::<_, MessageEdit>(
"SELECT * FROM message_edit WHERE message_id = $1 ORDER BY edited_at ASC",
)
.bind(message_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
/// Get a summary of edits for a message (count + last editor).
pub async fn get_edit_summary(&self, message_id: Uuid) -> ImksResult<EditSummary> {
let row = sqlx::query(
r#"
SELECT
COUNT(*)::BIGINT AS edit_count,
MAX(edited_at) AS last_edited_at,
(ARRAY_AGG(edited_by ORDER BY edited_at DESC))[1] AS last_edited_by
FROM message_edit
WHERE message_id = $1
"#,
)
.bind(message_id)
.fetch_one(self.pool())
.await?;
Ok(EditSummary {
edit_count: row.get("edit_count"),
last_edited_at: row.get("last_edited_at"),
last_edited_by: row.get("last_edited_by"),
})
}
}
+106
View File
@@ -0,0 +1,106 @@
//! Embed CRUD operations on `MessageRepo`.
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_embed::{EmbedDetail, MessageEmbed, MessageEmbedField};
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Create an embed with its fields. Returns the embed (fields fetched separately).
#[allow(clippy::too_many_arguments)]
pub async fn create_embed(
&self,
message_id: Uuid,
embed_type: &str,
title: Option<&str>,
description: Option<&str>,
url: Option<&str>,
color: Option<i32>,
image_url: Option<&str>,
author_name: Option<&str>,
author_url: Option<&str>,
footer_text: Option<&str>,
provider_name: Option<&str>,
fields: &[(String, String, bool)],
) -> ImksResult<MessageEmbed> {
let embed_id = Uuid::now_v7();
let embed = sqlx::query_as::<_, MessageEmbed>(
r#"
INSERT INTO message_embed (
id, message_id, embed_type, title, description, url, color,
image_url, author_name, author_url, footer_text, provider_name
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
RETURNING *
"#,
)
.bind(embed_id)
.bind(message_id)
.bind(embed_type)
.bind(title)
.bind(description)
.bind(url)
.bind(color)
.bind(image_url)
.bind(author_name)
.bind(author_url)
.bind(footer_text)
.bind(provider_name)
.fetch_one(self.pool())
.await?;
for (i, (name, value, inline)) in fields.iter().enumerate() {
sqlx::query(
r#"
INSERT INTO message_embed_field (id, embed_id, name, value, inline, position)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
)
.bind(Uuid::now_v7())
.bind(embed_id)
.bind(name)
.bind(value)
.bind(inline)
.bind(i as i32)
.execute(self.pool())
.await?;
}
Ok(embed)
}
/// Get all embeds for a message, including their fields.
pub async fn get_embeds(&self, message_id: Uuid) -> ImksResult<Vec<EmbedDetail>> {
let embeds: Vec<MessageEmbed> =
sqlx::query_as("SELECT * FROM message_embed WHERE message_id = $1 ORDER BY created_at")
.bind(message_id)
.fetch_all(self.pool())
.await?;
let mut result = Vec::with_capacity(embeds.len());
for embed in embeds {
let fields: Vec<MessageEmbedField> = sqlx::query_as(
"SELECT * FROM message_embed_field WHERE embed_id = $1 ORDER BY position",
)
.bind(embed.id)
.fetch_all(self.pool())
.await?;
result.push(EmbedDetail { embed, fields });
}
Ok(result)
}
/// Delete an embed (fields cascade via FK).
pub async fn delete_embed(&self, embed_id: Uuid) -> ImksResult<bool> {
let result = sqlx::query("DELETE FROM message_embed WHERE id = $1")
.bind(embed_id)
.execute(self.pool())
.await?;
Ok(result.rows_affected() > 0)
}
}
+44
View File
@@ -0,0 +1,44 @@
//! Message forward CRUD operations on `MessageRepo`.
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_forward::MessageForward;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Record a forwarded message's provenance.
pub async fn record_forward(
&self,
message_id: Uuid,
source_message_id: Uuid,
source_channel_id: Uuid,
forwarded_by: Uuid,
) -> ImksResult<MessageForward> {
sqlx::query_as::<_, MessageForward>(
r#"
INSERT INTO message_forward (id, message_id, source_message_id, source_channel_id, forwarded_by)
VALUES ($1, $2, $3, $4, $5)
RETURNING *
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(source_message_id)
.bind(source_channel_id)
.bind(forwarded_by)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Get forwarding info for a message.
pub async fn get_forward_info(&self, message_id: Uuid) -> ImksResult<Option<MessageForward>> {
sqlx::query_as::<_, MessageForward>("SELECT * FROM message_forward WHERE message_id = $1")
.bind(message_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
}
+111
View File
@@ -0,0 +1,111 @@
//! Mention CRUD operations on `MessageRepo`.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_mention::MessageMention;
use super::message_repo::MessageRepo;
use super::pagination::{CursorPage, clamp_limit};
impl MessageRepo {
/// Bulk record mentions for a message. Called right after message creation.
pub async fn record_mentions(
&self,
message_id: Uuid,
channel_id: Uuid,
mentioned_by: Uuid,
mentioned_user_ids: &[Uuid],
) -> ImksResult<()> {
if mentioned_user_ids.is_empty() {
return Ok(());
}
let now = Utc::now();
for &user_id in mentioned_user_ids {
sqlx::query(
r#"
INSERT INTO message_mention (id, message_id, channel_id, mentioned_user_id, mentioned_by, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(channel_id)
.bind(user_id)
.bind(mentioned_by)
.bind(now)
.execute(self.pool())
.await?;
}
Ok(())
}
/// List mentions for a specific user, newest first.
pub async fn list_mentions_for_user(
&self,
user_id: Uuid,
before: Option<Uuid>,
limit: Option<i64>,
) -> ImksResult<CursorPage<MessageMention>> {
let effective_limit = clamp_limit(limit);
let fetch_limit = effective_limit + 1;
let rows = match before {
Some(cursor) => {
sqlx::query_as::<_, MessageMention>(
r#"
SELECT * FROM message_mention
WHERE mentioned_user_id = $1 AND id < $2
ORDER BY id DESC
LIMIT $3
"#,
)
.bind(user_id)
.bind(cursor)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
None => {
sqlx::query_as::<_, MessageMention>(
r#"
SELECT * FROM message_mention
WHERE mentioned_user_id = $1
ORDER BY id DESC
LIMIT $2
"#,
)
.bind(user_id)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
};
Ok(CursorPage::from_raw(rows, effective_limit, |m| m.id))
}
/// Mark a mention as read.
pub async fn mark_mention_read(&self, mention_id: Uuid) -> ImksResult<()> {
sqlx::query("UPDATE message_mention SET read_at = $1 WHERE id = $2")
.bind(Utc::now())
.bind(mention_id)
.execute(self.pool())
.await?;
Ok(())
}
/// Delete all mentions for a message (e.g. on message delete).
pub async fn delete_mentions(&self, message_id: Uuid) -> ImksResult<()> {
sqlx::query("DELETE FROM message_mention WHERE message_id = $1")
.bind(message_id)
.execute(self.pool())
.await?;
Ok(())
}
}
+140
View File
@@ -0,0 +1,140 @@
//! Notification CRUD operations on `MessageRepo`.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_notification::MessageNotification;
use super::message_repo::MessageRepo;
use super::pagination::{CursorPage, clamp_limit};
impl MessageRepo {
/// Create a notification for a user triggered by a message.
pub async fn create_notification(
&self,
message_id: Uuid,
channel_id: Uuid,
user_id: Uuid,
reason: &str,
delivery_channel: Option<&str>,
) -> ImksResult<MessageNotification> {
let id = Uuid::now_v7();
let now = Utc::now();
sqlx::query_as::<_, MessageNotification>(
r#"
INSERT INTO message_notification (
id, message_id, channel_id, user_id, reason, status, delivery_channel, created_at
) VALUES ($1, $2, $3, $4, $5, 'pending', $6, $7)
RETURNING *
"#,
)
.bind(id)
.bind(message_id)
.bind(channel_id)
.bind(user_id)
.bind(reason)
.bind(delivery_channel)
.bind(now)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Mark a notification as read.
pub async fn mark_notification_read(&self, notification_id: Uuid) -> ImksResult<()> {
let now = Utc::now();
sqlx::query(
r#"
UPDATE message_notification
SET status = 'read', read_at = $1
WHERE id = $2
"#,
)
.bind(now)
.bind(notification_id)
.execute(self.pool())
.await?;
Ok(())
}
/// Mark all of a user's notifications as read.
pub async fn mark_all_notifications_read(&self, user_id: Uuid) -> ImksResult<u64> {
let now = Utc::now();
let result = sqlx::query(
r#"
UPDATE message_notification
SET status = 'read', read_at = $1
WHERE user_id = $2 AND status != 'read'
"#,
)
.bind(now)
.bind(user_id)
.execute(self.pool())
.await?;
Ok(result.rows_affected())
}
/// List notifications for a user, newest first.
pub async fn list_notifications(
&self,
user_id: Uuid,
before: Option<Uuid>,
limit: Option<i64>,
) -> ImksResult<CursorPage<MessageNotification>> {
let effective_limit = clamp_limit(limit);
let fetch_limit = effective_limit + 1;
let rows = match before {
Some(cursor) => {
sqlx::query_as::<_, MessageNotification>(
r#"
SELECT * FROM message_notification
WHERE user_id = $1 AND id < $2
ORDER BY id DESC
LIMIT $3
"#,
)
.bind(user_id)
.bind(cursor)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
None => {
sqlx::query_as::<_, MessageNotification>(
r#"
SELECT * FROM message_notification
WHERE user_id = $1
ORDER BY id DESC
LIMIT $2
"#,
)
.bind(user_id)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
};
Ok(CursorPage::from_raw(rows, effective_limit, |n| n.id))
}
/// Get unread notification count for a user.
pub async fn get_unread_notification_count(&self, user_id: Uuid) -> ImksResult<i64> {
let count: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)::BIGINT FROM message_notification
WHERE user_id = $1 AND status = 'pending'
"#,
)
.bind(user_id)
.fetch_one(self.pool())
.await?;
Ok(count)
}
}
+110
View File
@@ -0,0 +1,110 @@
//! Pin CRUD operations on `MessageRepo`.
//!
//! One row per pinned message. Channels can have multiple pinned messages.
//! `position` auto-calculated as `MAX(position) + 1` within the channel.
use chrono::Utc;
use sqlx::Row;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_pin::{MessagePin, PinDetail};
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Pin a message in a channel. Computes the next position automatically.
pub async fn pin_message(
&self,
channel_id: Uuid,
message_id: Uuid,
pinned_by: Uuid,
) -> ImksResult<MessagePin> {
let id = Uuid::now_v7();
let now = Utc::now();
let mut tx = self.pool().begin().await?;
sqlx::query("SELECT pg_advisory_xact_lock(hashtextextended($1, 0))")
.bind(channel_id.to_string())
.execute(&mut *tx)
.await?;
let max_pos: Option<i32> = sqlx::query_scalar(
"SELECT COALESCE(MAX(position), -1) FROM message_pin WHERE channel_id = $1",
)
.bind(channel_id)
.fetch_one(&mut *tx)
.await?;
let position = max_pos.unwrap_or(-1) + 1;
let pin = sqlx::query_as::<_, MessagePin>(
r#"
INSERT INTO message_pin (id, channel_id, message_id, pinned_by, position, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (channel_id, message_id) DO NOTHING
RETURNING *
"#,
)
.bind(id)
.bind(channel_id)
.bind(message_id)
.bind(pinned_by)
.bind(position)
.bind(now)
.fetch_optional(&mut *tx)
.await?
.ok_or_else(|| crate::ImksError::InvalidInput("Message already pinned".into()))?;
tx.commit().await?;
Ok(pin)
}
/// Unpin a message from a channel.
pub async fn unpin_message(&self, channel_id: Uuid, message_id: Uuid) -> ImksResult<bool> {
let result =
sqlx::query("DELETE FROM message_pin WHERE channel_id = $1 AND message_id = $2")
.bind(channel_id)
.bind(message_id)
.execute(self.pool())
.await?;
Ok(result.rows_affected() > 0)
}
/// List all pinned messages in a channel, newest first, joined with message content.
pub async fn list_pins(&self, channel_id: Uuid) -> ImksResult<Vec<PinDetail>> {
let rows = sqlx::query(
r#"
SELECT p.*, m.body AS message_body, m.author_id AS message_author_id,
m.created_at AS message_created_at
FROM message_pin p
JOIN message m ON m.id = p.message_id
WHERE p.channel_id = $1 AND m.deleted_at IS NULL
ORDER BY p.position ASC
"#,
)
.bind(channel_id)
.fetch_all(self.pool())
.await?;
let result = rows
.into_iter()
.map(|row| PinDetail {
pin: MessagePin {
id: row.get("id"),
channel_id: row.get("channel_id"),
message_id: row.get("message_id"),
pinned_by: row.get("pinned_by"),
position: row.get("position"),
created_at: row.get("created_at"),
},
message_body: row.get("message_body"),
message_author_id: row.get("message_author_id"),
message_created_at: row.get("message_created_at"),
})
.collect();
Ok(result)
}
}
+396
View File
@@ -0,0 +1,396 @@
//! Poll CRUD operations on `MessageRepo`.
//!
//! Handles poll creation (with options), voting (with denormalized counts),
//! and result retrieval.
use chrono::Utc;
use sqlx::Row;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_poll::{MessagePoll, MessagePollOption, MessagePollVote, PollResult};
use super::message_repo::MessageRepo;
/// Canonical poll target resolved from the database.
#[derive(Debug, Clone, Copy)]
pub struct PollTarget {
pub poll_id: Uuid,
pub message_id: Uuid,
pub channel_id: Uuid,
}
impl MessageRepo {
/// Resolve and validate a poll option's canonical message/channel target.
pub async fn get_poll_target(&self, poll_id: Uuid, option_id: Uuid) -> ImksResult<PollTarget> {
let row = sqlx::query(
r#"
SELECT p.id AS poll_id, p.message_id, m.channel_id, o.id AS option_id
FROM message_poll p
JOIN message m ON m.id = p.message_id
JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2
WHERE p.id = $1 AND m.deleted_at IS NULL
"#,
)
.bind(poll_id)
.bind(option_id)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?;
Ok(PollTarget {
poll_id: row.get("poll_id"),
message_id: row.get("message_id"),
channel_id: row.get("channel_id"),
})
}
/// Create a poll with its options. Returns the poll (options fetched separately).
pub async fn create_poll(
&self,
message_id: Uuid,
question: &str,
allow_multiselect: bool,
max_selections: Option<i32>,
expires_at: Option<chrono::DateTime<Utc>>,
options: &[(String, Option<String>)],
) -> ImksResult<MessagePoll> {
let poll_id = Uuid::now_v7();
let now = Utc::now();
let poll = sqlx::query_as::<_, MessagePoll>(
r#"
INSERT INTO message_poll (
id, message_id, question, allow_multiselect,
max_selections, expires_at, total_votes, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, 0, $7, $7)
RETURNING *
"#,
)
.bind(poll_id)
.bind(message_id)
.bind(question)
.bind(allow_multiselect)
.bind(max_selections)
.bind(expires_at)
.bind(now)
.fetch_one(self.pool())
.await?;
for (i, (text, emoji)) in options.iter().enumerate() {
let opt_id = Uuid::now_v7();
sqlx::query(
r#"
INSERT INTO message_poll_option (id, poll_id, text, emoji, vote_count, position)
VALUES ($1, $2, $3, $4, 0, $5)
"#,
)
.bind(opt_id)
.bind(poll_id)
.bind(text)
.bind(emoji.as_deref())
.bind(i as i32)
.execute(self.pool())
.await?;
}
Ok(poll)
}
/// Cast a validated vote and return the canonical message/channel target.
pub async fn cast_vote_checked(
&self,
poll_id: Uuid,
option_id: Uuid,
user_id: Uuid,
) -> ImksResult<PollTarget> {
let mut tx = self.pool().begin().await?;
let now = Utc::now();
let row = sqlx::query(
r#"
SELECT p.id AS poll_id, p.message_id, p.allow_multiselect, p.max_selections,
p.expires_at, m.channel_id, o.id AS option_id
FROM message_poll p
JOIN message m ON m.id = p.message_id
JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2
WHERE p.id = $1 AND m.deleted_at IS NULL
FOR UPDATE OF p
"#,
)
.bind(poll_id)
.bind(option_id)
.fetch_optional(&mut *tx)
.await?
.ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?;
let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
if expires_at.is_some_and(|exp| now >= exp) {
return Err(crate::ImksError::InvalidInput("Poll has expired".into()));
}
let allow_multiselect: bool = row.get("allow_multiselect");
let max_selections: Option<i32> = row.get("max_selections");
let current_votes: Vec<Uuid> = sqlx::query_scalar(
"SELECT option_id FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2",
)
.bind(poll_id)
.bind(user_id)
.fetch_all(&mut *tx)
.await?;
if current_votes.contains(&option_id) {
return Err(crate::ImksError::InvalidInput(
"Already voted for this option".into(),
));
}
if !allow_multiselect && !current_votes.is_empty() {
return Err(crate::ImksError::InvalidInput(
"Poll allows only one selection".into(),
));
}
if let Some(max) = max_selections
&& current_votes.len() >= max.max(1) as usize
{
return Err(crate::ImksError::InvalidInput(
"Poll selection limit exceeded".into(),
));
}
let vote_id = Uuid::now_v7();
sqlx::query(
r#"
INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, created_at)
VALUES ($1, $2, $3, $4, $5)
"#,
)
.bind(vote_id)
.bind(poll_id)
.bind(option_id)
.bind(user_id)
.bind(now)
.execute(&mut *tx)
.await?;
sqlx::query("UPDATE message_poll_option SET vote_count = vote_count + 1 WHERE id = $1")
.bind(option_id)
.execute(&mut *tx)
.await?;
sqlx::query(
"UPDATE message_poll SET total_votes = total_votes + 1, updated_at = $1 WHERE id = $2",
)
.bind(now)
.bind(poll_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(PollTarget {
poll_id,
message_id: row.get("message_id"),
channel_id: row.get("channel_id"),
})
}
/// Remove a validated vote and return the canonical message/channel target.
pub async fn remove_vote_checked(
&self,
poll_id: Uuid,
option_id: Uuid,
user_id: Uuid,
) -> ImksResult<Option<PollTarget>> {
let mut tx = self.pool().begin().await?;
let now = Utc::now();
let row = sqlx::query(
r#"
SELECT p.id AS poll_id, p.message_id, m.channel_id, o.id AS option_id
FROM message_poll p
JOIN message m ON m.id = p.message_id
JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2
WHERE p.id = $1 AND m.deleted_at IS NULL
FOR UPDATE OF p
"#,
)
.bind(poll_id)
.bind(option_id)
.fetch_optional(&mut *tx)
.await?
.ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?;
let result = sqlx::query(
r#"
DELETE FROM message_poll_vote
WHERE poll_id = $1 AND option_id = $2 AND user_id = $3
"#,
)
.bind(poll_id)
.bind(option_id)
.bind(user_id)
.execute(&mut *tx)
.await?;
if result.rows_affected() == 0 {
tx.commit().await?;
return Ok(None);
}
sqlx::query(
"UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1",
)
.bind(option_id)
.execute(&mut *tx)
.await?;
sqlx::query(
"UPDATE message_poll SET total_votes = GREATEST(total_votes - 1, 0), updated_at = $1 WHERE id = $2",
)
.bind(now)
.bind(poll_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(Some(PollTarget {
poll_id,
message_id: row.get("message_id"),
channel_id: row.get("channel_id"),
}))
}
/// Cast a vote. Increments denormalized counts atomically.
pub async fn vote(
&self,
poll_id: Uuid,
option_id: Uuid,
user_id: Uuid,
) -> ImksResult<MessagePollVote> {
let id = Uuid::now_v7();
let now = Utc::now();
let vote = sqlx::query_as::<_, MessagePollVote>(
r#"
INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, created_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (poll_id, user_id, option_id) DO NOTHING
RETURNING *
"#,
)
.bind(id)
.bind(poll_id)
.bind(option_id)
.bind(user_id)
.bind(now)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| crate::ImksError::InvalidInput("Already voted for this option".into()))?;
sqlx::query("UPDATE message_poll_option SET vote_count = vote_count + 1 WHERE id = $1")
.bind(option_id)
.execute(self.pool())
.await?;
sqlx::query(
"UPDATE message_poll SET total_votes = total_votes + 1, updated_at = $1 WHERE id = $2",
)
.bind(now)
.bind(poll_id)
.execute(self.pool())
.await?;
Ok(vote)
}
/// Remove a vote. Decrements denormalized counts.
pub async fn remove_vote(
&self,
poll_id: Uuid,
option_id: Uuid,
user_id: Uuid,
) -> ImksResult<bool> {
let result = sqlx::query(
r#"
DELETE FROM message_poll_vote
WHERE poll_id = $1 AND option_id = $2 AND user_id = $3
"#,
)
.bind(poll_id)
.bind(option_id)
.bind(user_id)
.execute(self.pool())
.await?;
if result.rows_affected() == 0 {
return Ok(false);
}
sqlx::query(
"UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1",
)
.bind(option_id)
.execute(self.pool())
.await?;
sqlx::query(
"UPDATE message_poll SET total_votes = GREATEST(total_votes - 1, 0), updated_at = $1 WHERE id = $2",
)
.bind(Utc::now())
.bind(poll_id)
.execute(self.pool())
.await?;
Ok(true)
}
/// Get full poll results including options, vote counts, and the given user's votes.
pub async fn get_poll_result(
&self,
message_id: Uuid,
user_id: Uuid,
) -> ImksResult<Option<PollResult>> {
let poll =
sqlx::query_as::<_, MessagePoll>("SELECT * FROM message_poll WHERE message_id = $1")
.bind(message_id)
.fetch_optional(self.pool())
.await?;
let Some(poll) = poll else {
return Ok(None);
};
let options: Vec<MessagePollOption> = sqlx::query_as(
"SELECT * FROM message_poll_option WHERE poll_id = $1 ORDER BY position",
)
.bind(poll.id)
.fetch_all(self.pool())
.await?;
let my_votes: Vec<Uuid> = sqlx::query_scalar(
"SELECT option_id FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2",
)
.bind(poll.id)
.bind(user_id)
.fetch_all(self.pool())
.await?;
Ok(Some(PollResult::from_poll(poll, options, my_votes)))
}
/// Close a poll by setting its expiration to now.
pub async fn close_poll(&self, message_id: Uuid) -> ImksResult<()> {
let now = Utc::now();
sqlx::query(
r#"
UPDATE message_poll
SET expires_at = $1, updated_at = $1
WHERE message_id = $2 AND (expires_at IS NULL OR expires_at > $1)
"#,
)
.bind(now)
.bind(message_id)
.execute(self.pool())
.await?;
Ok(())
}
}
+169
View File
@@ -0,0 +1,169 @@
//! Message read operations — get single, list by channel, list by thread.
//!
//! All list queries use UUID v7 cursor-based pagination (no OFFSET).
use sqlx::Row;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message::Message;
use super::message_repo::MessageRepo;
use super::pagination::{CursorPage, clamp_limit};
impl MessageRepo {
/// Fetch a single message by ID.
///
/// Returns `None` if the message doesn't exist or has been soft-deleted.
pub async fn get(&self, message_id: Uuid) -> ImksResult<Option<Message>> {
let row = sqlx::query_as::<_, Message>(
"SELECT * FROM message WHERE id = $1 AND deleted_at IS NULL",
)
.bind(message_id)
.fetch_optional(self.pool())
.await?;
Ok(row)
}
/// List messages in a channel with cursor-based pagination.
///
/// Returns messages in reverse chronological order (newest first).
/// Pass `before` as the last message ID from the previous page to
/// fetch the next page.
pub async fn list_by_channel(
&self,
channel_id: Uuid,
before: Option<Uuid>,
limit: Option<i64>,
) -> ImksResult<CursorPage<Message>> {
let effective_limit = clamp_limit(limit);
// Fetch one extra row to determine `has_more`.
let fetch_limit = effective_limit + 1;
let rows = match before {
Some(cursor) => {
sqlx::query_as::<_, Message>(
r#"
SELECT * FROM message
WHERE channel_id = $1
AND deleted_at IS NULL
AND id < $2
ORDER BY id DESC
LIMIT $3
"#,
)
.bind(channel_id)
.bind(cursor)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
None => {
sqlx::query_as::<_, Message>(
r#"
SELECT * FROM message
WHERE channel_id = $1
AND deleted_at IS NULL
ORDER BY id DESC
LIMIT $2
"#,
)
.bind(channel_id)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
};
Ok(CursorPage::from_raw(rows, effective_limit, |m| m.id))
}
/// List messages in a thread with cursor-based pagination.
pub async fn list_by_thread(
&self,
thread_id: Uuid,
before: Option<Uuid>,
limit: Option<i64>,
) -> ImksResult<CursorPage<Message>> {
let effective_limit = clamp_limit(limit);
let fetch_limit = effective_limit + 1;
let rows = match before {
Some(cursor) => {
sqlx::query_as::<_, Message>(
r#"
SELECT * FROM message
WHERE thread_id = $1
AND deleted_at IS NULL
AND id < $2
ORDER BY id DESC
LIMIT $3
"#,
)
.bind(thread_id)
.bind(cursor)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
None => {
sqlx::query_as::<_, Message>(
r#"
SELECT * FROM message
WHERE thread_id = $1
AND deleted_at IS NULL
ORDER BY id DESC
LIMIT $2
"#,
)
.bind(thread_id)
.bind(fetch_limit)
.fetch_all(self.pool())
.await?
}
};
Ok(CursorPage::from_raw(rows, effective_limit, |m| m.id))
}
/// Get message reaction counts grouped by emoji content.
///
/// Returns `(content, count)` pairs for the given message.
pub async fn get_reaction_counts(&self, message_id: Uuid) -> ImksResult<Vec<(String, i64)>> {
let rows = sqlx::query(
r#"
SELECT content, COUNT(*)::BIGINT AS cnt
FROM message_reaction
WHERE message_id = $1
GROUP BY content
"#,
)
.bind(message_id)
.fetch_all(self.pool())
.await?;
Ok(rows
.into_iter()
.map(|r| (r.get("content"), r.get("cnt")))
.collect())
}
/// Count attachments and embeds for a message.
pub async fn get_content_counts(&self, message_id: Uuid) -> ImksResult<(i64, i64)> {
let att_row = sqlx::query(
"SELECT COUNT(*)::BIGINT AS cnt FROM message_attachment WHERE message_id = $1",
)
.bind(message_id)
.fetch_one(self.pool())
.await?;
let emb_row =
sqlx::query("SELECT COUNT(*)::BIGINT AS cnt FROM message_embed WHERE message_id = $1")
.bind(message_id)
.fetch_one(self.pool())
.await?;
Ok((att_row.get("cnt"), emb_row.get("cnt")))
}
}
+94
View File
@@ -0,0 +1,94 @@
//! Reaction CRUD operations on `MessageRepo`.
//!
//! Each (message, user, content) tuple is unique via ON CONFLICT.
//! Toggle semantics: same request adds/removes the reaction.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_reaction::MessageReaction;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Add or toggle a reaction. Returns the reaction if added, None if already exists.
pub async fn add_reaction(
&self,
message_id: Uuid,
channel_id: Uuid,
user_id: Uuid,
content: &str,
) -> ImksResult<Option<MessageReaction>> {
let id = Uuid::now_v7();
let now = Utc::now();
let row = sqlx::query_as::<_, MessageReaction>(
r#"
INSERT INTO message_reaction (id, message_id, channel_id, user_id, content, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (message_id, user_id, content) DO NOTHING
RETURNING *
"#,
)
.bind(id)
.bind(message_id)
.bind(channel_id)
.bind(user_id)
.bind(content)
.bind(now)
.fetch_optional(self.pool())
.await?;
Ok(row)
}
/// Remove a user's reaction from a message.
pub async fn remove_reaction(
&self,
message_id: Uuid,
user_id: Uuid,
content: &str,
) -> ImksResult<bool> {
let result = sqlx::query(
r#"
DELETE FROM message_reaction
WHERE message_id = $1 AND user_id = $2 AND content = $3
"#,
)
.bind(message_id)
.bind(user_id)
.bind(content)
.execute(self.pool())
.await?;
Ok(result.rows_affected() > 0)
}
/// Get all reactions on a message.
pub async fn get_reactions(&self, message_id: Uuid) -> ImksResult<Vec<MessageReaction>> {
sqlx::query_as::<_, MessageReaction>(
"SELECT * FROM message_reaction WHERE message_id = $1 ORDER BY created_at",
)
.bind(message_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
/// Get a specific user's reactions on a message.
pub async fn get_user_reactions(
&self,
message_id: Uuid,
user_id: Uuid,
) -> ImksResult<Vec<MessageReaction>> {
sqlx::query_as::<_, MessageReaction>(
"SELECT * FROM message_reaction WHERE message_id = $1 AND user_id = $2 ORDER BY created_at",
)
.bind(message_id)
.bind(user_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
}
+110
View File
@@ -0,0 +1,110 @@
//! Read state CRUD operations on `MessageRepo`.
//!
//! One row per (channel, user). Upserted on each read; ON CONFLICT DO UPDATE
//! advances the cursor and recalculates unread counts.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_read_state::MessageReadState;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Mark a channel as read up to a given message for a user.
/// Recalculates unread_count and unread_mentions from the database.
pub async fn mark_read(
&self,
channel_id: Uuid,
user_id: Uuid,
last_read_message_id: Uuid,
) -> ImksResult<MessageReadState> {
let now = Utc::now();
let unread_count: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)::BIGINT
FROM message
WHERE channel_id = $1
AND deleted_at IS NULL
AND id > $2
AND author_id != $3
"#,
)
.bind(channel_id)
.bind(last_read_message_id)
.bind(user_id)
.fetch_one(self.pool())
.await?;
let unread_mentions: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)::BIGINT
FROM message_mention mm
WHERE mm.channel_id = $1
AND mm.mentioned_user_id = $2
AND mm.message_id > $3
AND mm.read_at IS NULL
"#,
)
.bind(channel_id)
.bind(user_id)
.bind(last_read_message_id)
.fetch_one(self.pool())
.await?;
let id = Uuid::now_v7();
sqlx::query_as::<_, MessageReadState>(
r#"
INSERT INTO message_read_state (
id, channel_id, user_id, last_read_message_id, last_read_at,
unread_count, unread_mentions, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $8)
ON CONFLICT (channel_id, user_id) DO UPDATE SET
last_read_message_id = EXCLUDED.last_read_message_id,
last_read_at = EXCLUDED.last_read_at,
unread_count = EXCLUDED.unread_count,
unread_mentions = EXCLUDED.unread_mentions,
updated_at = EXCLUDED.updated_at
RETURNING *
"#,
)
.bind(id)
.bind(channel_id)
.bind(user_id)
.bind(last_read_message_id)
.bind(now)
.bind(unread_count)
.bind(unread_mentions)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Get a user's read state for a channel.
pub async fn get_read_state(
&self,
channel_id: Uuid,
user_id: Uuid,
) -> ImksResult<Option<MessageReadState>> {
sqlx::query_as::<_, MessageReadState>(
"SELECT * FROM message_read_state WHERE channel_id = $1 AND user_id = $2",
)
.bind(channel_id)
.bind(user_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
/// Get read state summaries for all channels a user participates in.
pub async fn get_user_read_states(&self, user_id: Uuid) -> ImksResult<Vec<MessageReadState>> {
sqlx::query_as::<_, MessageReadState>("SELECT * FROM message_read_state WHERE user_id = $1")
.bind(user_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
}
+27
View File
@@ -0,0 +1,27 @@
//! Message repository — struct definition and pool accessor.
//!
//! CRUD operations are split across `message_create.rs` (writes)
//! and `message_query.rs` (reads) as separate `impl` blocks.
use sqlx::PgPool;
/// Repository for message CRUD operations.
///
/// All queries use parameterized statements via sqlx.
/// IDs are UUID v7 (time-ordered) for efficient cursor pagination.
#[derive(Clone)]
pub struct MessageRepo {
pool: PgPool,
}
impl MessageRepo {
/// Create a new repository backed by the given connection pool.
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
/// Access the inner `PgPool` for advanced queries.
pub fn pool(&self) -> &PgPool {
&self.pool
}
}
+159
View File
@@ -0,0 +1,159 @@
//! Scheduled message CRUD operations on `MessageRepo`.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_scheduled::MessageScheduled;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Schedule a message to be sent later.
#[allow(clippy::too_many_arguments)]
pub async fn schedule_message(
&self,
channel_id: Uuid,
author_id: Uuid,
thread_id: Option<Uuid>,
reply_to_message_id: Option<Uuid>,
body: &str,
metadata: Option<serde_json::Value>,
scheduled_at: chrono::DateTime<Utc>,
) -> ImksResult<MessageScheduled> {
let id = Uuid::now_v7();
let now = Utc::now();
sqlx::query_as::<_, MessageScheduled>(
r#"
INSERT INTO message_scheduled (
id, channel_id, author_id, thread_id, reply_to_message_id,
body, metadata, scheduled_at, status, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'pending', $9, $9)
RETURNING *
"#,
)
.bind(id)
.bind(channel_id)
.bind(author_id)
.bind(thread_id)
.bind(reply_to_message_id)
.bind(body)
.bind(metadata)
.bind(scheduled_at)
.bind(now)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Cancel a scheduled message (only if still pending).
pub async fn cancel_scheduled(&self, scheduled_id: Uuid) -> ImksResult<bool> {
let result = sqlx::query(
r#"
UPDATE message_scheduled
SET status = 'cancelled', updated_at = $1
WHERE id = $2 AND status = 'pending'
"#,
)
.bind(Utc::now())
.bind(scheduled_id)
.execute(self.pool())
.await?;
Ok(result.rows_affected() > 0)
}
/// List a user's scheduled messages.
pub async fn list_scheduled(
&self,
channel_id: Uuid,
author_id: Uuid,
) -> ImksResult<Vec<MessageScheduled>> {
sqlx::query_as::<_, MessageScheduled>(
r#"
SELECT * FROM message_scheduled
WHERE channel_id = $1 AND author_id = $2 AND status = 'pending'
ORDER BY scheduled_at ASC
"#,
)
.bind(channel_id)
.bind(author_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
/// Atomically claim due scheduled messages for background dispatch.
pub async fn claim_due_scheduled(&self) -> ImksResult<Vec<MessageScheduled>> {
let mut tx = self.pool().begin().await?;
let now = Utc::now();
let rows = sqlx::query_as::<_, MessageScheduled>(
r#"
UPDATE message_scheduled
SET status = 'processing', updated_at = $1
WHERE id IN (
SELECT id
FROM message_scheduled
WHERE status = 'pending' AND scheduled_at <= $1
ORDER BY scheduled_at ASC
LIMIT 100
FOR UPDATE SKIP LOCKED
)
RETURNING *
"#,
)
.bind(now)
.fetch_all(&mut *tx)
.await?;
tx.commit().await?;
Ok(rows)
}
/// Get all pending scheduled messages whose time has come (for background dispatch).
pub async fn get_due_scheduled(&self) -> ImksResult<Vec<MessageScheduled>> {
self.claim_due_scheduled().await
}
/// Mark a scheduled message as sent.
pub async fn mark_scheduled_sent(
&self,
scheduled_id: Uuid,
sent_message_id: Uuid,
) -> ImksResult<()> {
sqlx::query(
r#"
UPDATE message_scheduled
SET status = 'sent', sent_message_id = $1, updated_at = $2
WHERE id = $3 AND status = 'processing'
"#,
)
.bind(sent_message_id)
.bind(Utc::now())
.bind(scheduled_id)
.execute(self.pool())
.await?;
Ok(())
}
/// Mark a scheduled message as failed.
pub async fn mark_scheduled_failed(&self, scheduled_id: Uuid, error: &str) -> ImksResult<()> {
sqlx::query(
r#"
UPDATE message_scheduled
SET status = 'failed', error = $1, updated_at = $2
WHERE id = $3 AND status = 'processing'
"#,
)
.bind(error)
.bind(Utc::now())
.bind(scheduled_id)
.execute(self.pool())
.await?;
Ok(())
}
}
+53
View File
@@ -0,0 +1,53 @@
//! Sticker CRUD operations on `MessageRepo`.
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_sticker::MessageSticker;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Record a sticker used in a message.
#[allow(clippy::too_many_arguments)]
pub async fn record_sticker(
&self,
message_id: Uuid,
sticker_id: Uuid,
name: &str,
image_url: &str,
format_type: &str,
pack_name: Option<&str>,
tags: Option<&str>,
) -> ImksResult<MessageSticker> {
sqlx::query_as::<_, MessageSticker>(
r#"
INSERT INTO message_sticker (id, message_id, sticker_id, name, image_url, format_type, pack_name, tags)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING *
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(sticker_id)
.bind(name)
.bind(image_url)
.bind(format_type)
.bind(pack_name)
.bind(tags)
.fetch_one(self.pool())
.await
.map_err(Into::into)
}
/// Get all stickers on a message.
pub async fn get_stickers(&self, message_id: Uuid) -> ImksResult<Vec<MessageSticker>> {
sqlx::query_as::<_, MessageSticker>(
"SELECT * FROM message_sticker WHERE message_id = $1 ORDER BY created_at",
)
.bind(message_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
}
+218
View File
@@ -0,0 +1,218 @@
//! Thread CRUD operations on `MessageRepo`.
use chrono::Utc;
use uuid::Uuid;
use crate::ImksResult;
use crate::models::message_thread::MessageThread;
use crate::models::message_thread_participant::MessageThreadParticipant;
use super::message_repo::MessageRepo;
impl MessageRepo {
/// Create a new thread anchored on a root message.
pub async fn create_thread(
&self,
root_message_id: Uuid,
channel_id: Uuid,
created_by: Uuid,
) -> ImksResult<MessageThread> {
let id = Uuid::now_v7();
let now = Utc::now();
sqlx::query_as::<_, MessageThread>(
r#"
INSERT INTO message_thread (
id, channel_id, root_message_id, created_by,
replies_count, participants_count, resolved,
created_at, updated_at
) VALUES ($1, $2, $3, $4, 0, 0, FALSE, $5, $5)
ON CONFLICT (root_message_id) DO NOTHING
RETURNING *
"#,
)
.bind(id)
.bind(channel_id)
.bind(root_message_id)
.bind(created_by)
.bind(now)
.fetch_optional(self.pool())
.await?
.ok_or_else(|| crate::ImksError::InvalidInput("Thread already exists".into()))
}
/// Get a thread by its ID.
pub async fn get_thread(&self, thread_id: Uuid) -> ImksResult<Option<MessageThread>> {
sqlx::query_as::<_, MessageThread>("SELECT * FROM message_thread WHERE id = $1")
.bind(thread_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
/// Get a thread by its root message ID.
pub async fn get_thread_by_root(
&self,
root_message_id: Uuid,
) -> ImksResult<Option<MessageThread>> {
sqlx::query_as::<_, MessageThread>(
"SELECT * FROM message_thread WHERE root_message_id = $1",
)
.bind(root_message_id)
.fetch_optional(self.pool())
.await
.map_err(Into::into)
}
/// List threads in a channel.
pub async fn list_threads(&self, channel_id: Uuid) -> ImksResult<Vec<MessageThread>> {
sqlx::query_as::<_, MessageThread>(
r#"
SELECT * FROM message_thread
WHERE channel_id = $1
ORDER BY last_reply_at DESC NULLS LAST
"#,
)
.bind(channel_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
/// Increment thread reply counter and update last reply info.
pub async fn bump_thread(&self, thread_id: Uuid, message_id: Uuid) -> ImksResult<()> {
let now = Utc::now();
sqlx::query(
r#"
UPDATE message_thread
SET replies_count = replies_count + 1,
last_reply_message_id = $1,
last_reply_at = $2,
updated_at = $2
WHERE id = $3
"#,
)
.bind(message_id)
.bind(now)
.bind(thread_id)
.execute(self.pool())
.await?;
Ok(())
}
/// Resolve or unresolve a thread.
pub async fn resolve_thread(
&self,
thread_id: Uuid,
resolved_by: Uuid,
resolved: bool,
) -> ImksResult<()> {
if resolved {
sqlx::query(
r#"
UPDATE message_thread
SET resolved = TRUE, resolved_by = $1, resolved_at = $2, updated_at = $2
WHERE id = $3
"#,
)
.bind(resolved_by)
.bind(Utc::now())
.bind(thread_id)
.execute(self.pool())
.await?;
} else {
sqlx::query(
r#"
UPDATE message_thread
SET resolved = FALSE, resolved_by = NULL, resolved_at = NULL, updated_at = $1
WHERE id = $2
"#,
)
.bind(Utc::now())
.bind(thread_id)
.execute(self.pool())
.await?;
}
Ok(())
}
}
impl MessageRepo {
/// Add a participant to a thread (or update their join reason).
pub async fn add_thread_participant(
&self,
thread_id: Uuid,
user_id: Uuid,
joined_reason: &str,
) -> ImksResult<MessageThreadParticipant> {
let id = Uuid::now_v7();
let now = Utc::now();
let participant = sqlx::query_as::<_, MessageThreadParticipant>(
r#"
INSERT INTO message_thread_participant (id, thread_id, user_id, joined_reason, joined_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (thread_id, user_id) DO UPDATE SET joined_reason = EXCLUDED.joined_reason
RETURNING *
"#,
)
.bind(id)
.bind(thread_id)
.bind(user_id)
.bind(joined_reason)
.bind(now)
.fetch_one(self.pool())
.await?;
sqlx::query(
"UPDATE message_thread SET participants_count = (SELECT COUNT(*) FROM message_thread_participant WHERE thread_id = $1) WHERE id = $1",
)
.bind(thread_id)
.execute(self.pool())
.await?;
Ok(participant)
}
/// Remove a participant from a thread.
pub async fn remove_thread_participant(
&self,
thread_id: Uuid,
user_id: Uuid,
) -> ImksResult<bool> {
let result = sqlx::query(
"DELETE FROM message_thread_participant WHERE thread_id = $1 AND user_id = $2",
)
.bind(thread_id)
.bind(user_id)
.execute(self.pool())
.await?;
if result.rows_affected() > 0 {
sqlx::query(
"UPDATE message_thread SET participants_count = GREATEST(participants_count - 1, 0) WHERE id = $1",
)
.bind(thread_id)
.execute(self.pool())
.await?;
}
Ok(result.rows_affected() > 0)
}
/// List all participants in a thread.
pub async fn list_thread_participants(
&self,
thread_id: Uuid,
) -> ImksResult<Vec<MessageThreadParticipant>> {
sqlx::query_as::<_, MessageThreadParticipant>(
"SELECT * FROM message_thread_participant WHERE thread_id = $1 ORDER BY joined_at",
)
.bind(thread_id)
.fetch_all(self.pool())
.await
.map_err(Into::into)
}
}
+25
View File
@@ -0,0 +1,25 @@
pub mod message_article;
pub mod message_attachment;
pub mod message_bookmark;
pub mod message_component;
pub mod message_create;
pub mod message_draft;
pub mod message_edit;
pub mod message_embed;
pub mod message_forward;
pub mod message_mention;
pub mod message_notification;
pub mod message_pin;
pub mod message_poll;
pub mod message_query;
pub mod message_reaction;
pub mod message_read_state;
pub mod message_repo;
pub mod message_scheduled;
pub mod message_sticker;
pub mod message_thread;
pub mod pagination;
pub use message_create::CreateMessageInput;
pub use message_repo::MessageRepo;
pub use pagination::CursorPage;
+114
View File
@@ -0,0 +1,114 @@
//! Cursor-based pagination helpers for repository queries.
//!
//! UUID v7 IDs are time-ordered, so `WHERE id < $cursor ORDER BY id DESC`
//! naturally yields reverse-chronological pages without OFFSET.
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Default number of items per page.
pub const DEFAULT_PAGE_SIZE: i64 = 50;
/// Hard upper bound on page size to prevent abuse.
pub const MAX_PAGE_SIZE: i64 = 100;
/// Generic cursor-based page response.
///
/// Returned by list operations in the repo layer. The `next_cursor`
/// is the last item's UUID — pass it as `before` in the next request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CursorPage<T> {
/// Items in this page (ordered by `id DESC`).
pub items: Vec<T>,
/// Opaque cursor for the next page. `None` when no more results exist.
pub next_cursor: Option<Uuid>,
/// Whether there are more results beyond this page.
pub has_more: bool,
}
impl<T> CursorPage<T> {
/// Build a page from a raw result set that may contain one extra row.
///
/// If `raw_items.len() > limit`, the extra row is dropped and
/// `has_more` is set to `true`.
pub fn from_raw(mut raw_items: Vec<T>, limit: i64, get_id: impl Fn(&T) -> Uuid) -> Self {
let has_more = raw_items.len() > limit as usize;
if has_more {
raw_items.truncate(limit as usize);
}
let next_cursor = if has_more {
raw_items.last().map(get_id)
} else {
None
};
Self {
items: raw_items,
next_cursor,
has_more,
}
}
/// Empty page (no results).
pub fn empty() -> Self {
Self {
items: Vec::new(),
next_cursor: None,
has_more: false,
}
}
}
/// Clamp a caller-requested limit to `[1, MAX_PAGE_SIZE]`, defaulting to `DEFAULT_PAGE_SIZE`.
pub fn clamp_limit(limit: Option<i64>) -> i64 {
match limit {
Some(n) if n < 1 => DEFAULT_PAGE_SIZE,
Some(n) if n > MAX_PAGE_SIZE => MAX_PAGE_SIZE,
Some(n) => n,
None => DEFAULT_PAGE_SIZE,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_clamp_limit_none() {
assert_eq!(clamp_limit(None), DEFAULT_PAGE_SIZE);
}
#[test]
fn test_clamp_limit_zero() {
assert_eq!(clamp_limit(Some(0)), DEFAULT_PAGE_SIZE);
}
#[test]
fn test_clamp_limit_negative() {
assert_eq!(clamp_limit(Some(-5)), DEFAULT_PAGE_SIZE);
}
#[test]
fn test_clamp_limit_over_max() {
assert_eq!(clamp_limit(Some(200)), MAX_PAGE_SIZE);
}
#[test]
fn test_clamp_limit_valid() {
assert_eq!(clamp_limit(Some(25)), 25);
}
#[test]
fn test_cursor_page_empty() {
let page: CursorPage<String> = CursorPage::empty();
assert!(page.items.is_empty());
assert!(!page.has_more);
assert!(page.next_cursor.is_none());
}
#[test]
fn test_cursor_page_from_raw_no_overflow() {
let items = vec!["a".to_string(), "b".to_string()];
let page = CursorPage::from_raw(items, 5, |_| Uuid::nil());
assert_eq!(page.items.len(), 2);
assert!(!page.has_more);
}
}
+108
View File
@@ -0,0 +1,108 @@
//! Aggregate gRPC client holder for all appks core services.
//!
//! A single TCP `Channel` is shared across all four service clients
//! to avoid redundant connections.
use std::fs;
use std::time::Duration;
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity};
use crate::pb::core::token_service_client::TokenServiceClient;
use crate::pb::im::{
channel_service_client::ChannelServiceClient, member_service_client::MemberServiceClient,
permission_service_client::PermissionServiceClient,
};
use crate::{ImksError, ImksResult};
use super::config::RpcConfig;
/// Holds gRPC clients for all appks core services consumed by imks.
///
/// Cheaply cloneable — each inner client wraps a shared `Arc<Channel>`.
#[derive(Clone)]
pub struct AppksClients {
/// JWT token lifecycle: issue, refresh, revoke, verify, signing keys.
pub token: TokenServiceClient<Channel>,
/// Channel and category CRUD + statistics.
pub channel: ChannelServiceClient<Channel>,
/// Channel member invite / kick / join / leave.
pub member: MemberServiceClient<Channel>,
/// Permission checks and overwrite rules.
pub permission: PermissionServiceClient<Channel>,
}
impl AppksClients {
/// Connect to all appks services using a shared gRPC channel.
pub async fn connect(config: &RpcConfig) -> ImksResult<Self> {
let mut endpoint = Endpoint::from_shared(config.appks_addr.clone())
.map_err(|e| ImksError::Internal(format!("Invalid gRPC endpoint: {e}")))?
.connect_timeout(Duration::from_secs(config.connect_timeout_secs));
if config.appks_addr.starts_with("https://")
|| config.tls_ca_cert_path.is_some()
|| config.tls_client_cert_path.is_some()
|| config.tls_client_key_path.is_some()
{
endpoint = endpoint.tls_config(build_tls_config(config)?)?;
}
let channel = endpoint
.connect()
.await
.map_err(|e| ImksError::Internal(format!("gRPC connect failed: {e}")))?;
tracing::info!(addr = %config.appks_addr, "Connected to appks gRPC services");
Ok(Self {
token: TokenServiceClient::new(channel.clone()),
channel: ChannelServiceClient::new(channel.clone()),
member: MemberServiceClient::new(channel.clone()),
permission: PermissionServiceClient::new(channel),
})
}
/// Build from pre-connected clients (useful for tests with mock servers).
pub fn new(
token: TokenServiceClient<Channel>,
channel: ChannelServiceClient<Channel>,
member: MemberServiceClient<Channel>,
permission: PermissionServiceClient<Channel>,
) -> Self {
Self {
token,
channel,
member,
permission,
}
}
}
fn build_tls_config(config: &RpcConfig) -> ImksResult<ClientTlsConfig> {
let mut tls = ClientTlsConfig::new();
if let Some(domain) = &config.tls_domain_name {
tls = tls.domain_name(domain);
}
if let Some(path) = &config.tls_ca_cert_path {
let pem = fs::read(path)?;
tls = tls.ca_certificate(Certificate::from_pem(pem));
}
match (&config.tls_client_cert_path, &config.tls_client_key_path) {
(Some(cert_path), Some(key_path)) => {
let cert = fs::read(cert_path)?;
let key = fs::read(key_path)?;
tls = tls.identity(Identity::from_pem(cert, key));
}
(None, None) => {}
_ => {
return Err(ImksError::InvalidInput(
"Both APPKS_GRPC_TLS_CLIENT_CERT and APPKS_GRPC_TLS_CLIENT_KEY are required for mTLS".into(),
));
}
}
Ok(tls)
}
+65
View File
@@ -0,0 +1,65 @@
//! gRPC client configuration for connecting to appks core services.
//!
//! Reads the appks address and timeout from environment variables.
use std::env;
/// Configuration for appks gRPC connections.
#[derive(Debug, Clone)]
pub struct RpcConfig {
/// appks gRPC endpoint, e.g. `http://localhost:50051`.
pub appks_addr: String,
/// Connection establishment timeout (seconds).
pub connect_timeout_secs: u64,
/// Optional CA certificate PEM path for appks mTLS.
pub tls_ca_cert_path: Option<String>,
/// Optional client certificate PEM path for appks mTLS.
pub tls_client_cert_path: Option<String>,
/// Optional client private key PEM path for appks mTLS.
pub tls_client_key_path: Option<String>,
/// TLS domain name used for certificate verification.
pub tls_domain_name: Option<String>,
}
impl RpcConfig {
/// Build config from environment variables with defaults.
pub fn from_env() -> Self {
Self {
appks_addr: env::var("APPKS_GRPC_ADDR")
.unwrap_or_else(|_| "http://localhost:50051".to_string()),
connect_timeout_secs: env::var("APPKS_GRPC_TIMEOUT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10),
tls_ca_cert_path: env::var("APPKS_GRPC_TLS_CA_CERT").ok(),
tls_client_cert_path: env::var("APPKS_GRPC_TLS_CLIENT_CERT").ok(),
tls_client_key_path: env::var("APPKS_GRPC_TLS_CLIENT_KEY").ok(),
tls_domain_name: env::var("APPKS_GRPC_TLS_DOMAIN").ok(),
}
}
}
impl Default for RpcConfig {
fn default() -> Self {
Self::from_env()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let cfg = RpcConfig {
appks_addr: "http://localhost:50051".to_string(),
connect_timeout_secs: 10,
tls_ca_cert_path: None,
tls_client_cert_path: None,
tls_client_key_path: None,
tls_domain_name: None,
};
assert_eq!(cfg.connect_timeout_secs, 10);
assert!(cfg.appks_addr.starts_with("http"));
}
}
+5
View File
@@ -0,0 +1,5 @@
pub mod clients;
pub mod config;
pub use clients::AppksClients;
pub use config::RpcConfig;
+39 -13
View File
@@ -5,7 +5,7 @@ use async_trait::async_trait;
use dashmap::DashMap; use dashmap::DashMap;
use uuid::Uuid; use uuid::Uuid;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, SocketInfo}; use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, LocalSendFn, SocketInfo};
use crate::socket::packet::Packet; use crate::socket::packet::Packet;
pub struct LocalAdapter { pub struct LocalAdapter {
@@ -16,7 +16,7 @@ pub struct LocalAdapter {
pub socket_sids: Arc<DashMap<String, String>>, pub socket_sids: Arc<DashMap<String, String>>,
/// socket_sid → namespace path /// socket_sid → namespace path
socket_namespace: Arc<DashMap<String, String>>, socket_namespace: Arc<DashMap<String, String>>,
send_fn: Arc<dyn Fn(&str, &Packet) -> Result<(), String> + Send + Sync>, send_fn: LocalSendFn,
} }
impl LocalAdapter { impl LocalAdapter {
@@ -68,7 +68,11 @@ impl LocalAdapter {
#[async_trait] #[async_trait]
impl Adapter for LocalAdapter { impl Adapter for LocalAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { async fn broadcast(
&self,
packet: &Packet,
opts: &BroadcastOptions,
) -> Result<(), AdapterError> {
let namespace = &packet.namespace; let namespace = &packet.namespace;
let sids = self.collect_matching_sids(opts, namespace); let sids = self.collect_matching_sids(opts, namespace);
for sid in &sids { for sid in &sids {
@@ -87,9 +91,16 @@ impl Adapter for LocalAdapter {
Ok(()) Ok(())
} }
async fn register(&self, socket_sid: &str, engine_sid: &str, ns: &str) -> Result<(), AdapterError> { async fn register(
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string()); &self,
self.socket_namespace.insert(socket_sid.to_string(), ns.to_string()); socket_sid: &str,
engine_sid: &str,
ns: &str,
) -> Result<(), AdapterError> {
self.socket_sids
.insert(socket_sid.to_string(), engine_sid.to_string());
self.socket_namespace
.insert(socket_sid.to_string(), ns.to_string());
Ok(()) Ok(())
} }
@@ -99,8 +110,16 @@ impl Adapter for LocalAdapter {
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> { async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError> {
let key = Self::room_key(ns, room); let key = Self::room_key(ns, room);
self.rooms.entry(key).or_insert_with(HashSet::new).value_mut().insert(sid.to_string()); self.rooms
self.socket_rooms.entry(sid.to_string()).or_insert_with(HashSet::new).value_mut().insert(room.to_string()); .entry(key)
.or_default()
.value_mut()
.insert(sid.to_string());
self.socket_rooms
.entry(sid.to_string())
.or_default()
.value_mut()
.insert(room.to_string());
Ok(()) Ok(())
} }
@@ -137,10 +156,14 @@ impl Adapter for LocalAdapter {
} }
} }
self.socket_sids.remove(sid); self.socket_sids.remove(sid);
self.socket_namespace.remove(sid);
Ok(()) Ok(())
} }
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> { async fn fetch_sockets(
&self,
opts: &BroadcastOptions,
) -> Result<Vec<SocketInfo>, AdapterError> {
// fetch_sockets needs namespace context; use an empty namespace to match all // fetch_sockets needs namespace context; use an empty namespace to match all
// (this method is typically called for inspection, not delivery) // (this method is typically called for inspection, not delivery)
let sids: Vec<String> = if opts.rooms.is_empty() { let sids: Vec<String> = if opts.rooms.is_empty() {
@@ -164,11 +187,13 @@ impl Adapter for LocalAdapter {
continue; continue;
} }
if self.socket_sids.contains_key(sid) { if self.socket_sids.contains_key(sid) {
let namespace = self.socket_namespace let namespace = self
.socket_namespace
.get(sid) .get(sid)
.map(|r| r.value().clone()) .map(|r| r.value().clone())
.unwrap_or_default(); .unwrap_or_default();
let rooms = self.socket_rooms let rooms = self
.socket_rooms
.get(sid) .get(sid)
.map(|r| r.value().clone()) .map(|r| r.value().clone())
.unwrap_or_default(); .unwrap_or_default();
@@ -183,7 +208,8 @@ impl Adapter for LocalAdapter {
} }
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> { async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms Ok(self
.socket_rooms
.get(sid) .get(sid)
.map(|r| r.value().clone()) .map(|r| r.value().clone())
.unwrap_or_default()) .unwrap_or_default())
@@ -196,4 +222,4 @@ impl Adapter for LocalAdapter {
async fn close(&self) -> Result<(), AdapterError> { async fn close(&self) -> Result<(), AdapterError> {
Ok(()) Ok(())
} }
} }
+18 -5
View File
@@ -1,14 +1,20 @@
pub mod local; pub mod local;
pub mod redis;
pub mod nats; pub mod nats;
pub mod redis;
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use thiserror::Error; use thiserror::Error;
use crate::socket::packet::Packet; use crate::socket::packet::Packet;
/// Alias for cross-node broadcast callback functions.
pub type LocalBroadcastFn = Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>;
/// Alias for local send-to-socket callback functions.
pub type LocalSendFn = Arc<dyn Fn(&str, &Packet) -> Result<(), String> + Send + Sync>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum AdapterError { pub enum AdapterError {
#[error("Redis error: {0}")] #[error("Redis error: {0}")]
@@ -72,11 +78,13 @@ pub enum BusMessage {
#[async_trait] #[async_trait]
pub trait Adapter: Send + Sync + 'static { pub trait Adapter: Send + Sync + 'static {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError>; async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions)
-> Result<(), AdapterError>;
async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>; async fn add(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>; async fn del(&self, sid: &str, room: &str, ns: &str) -> Result<(), AdapterError>;
async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError>; async fn del_all(&self, sid: &str, ns: &str) -> Result<(), AdapterError>;
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError>; async fn fetch_sockets(&self, opts: &BroadcastOptions)
-> Result<Vec<SocketInfo>, AdapterError>;
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError>; async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError>;
fn server_id(&self) -> &str; fn server_id(&self) -> &str;
async fn close(&self) -> Result<(), AdapterError>; async fn close(&self) -> Result<(), AdapterError>;
@@ -84,7 +92,12 @@ pub trait Adapter: Send + Sync + 'static {
/// Register a socket SID → engine SID mapping in the adapter. /// Register a socket SID → engine SID mapping in the adapter.
/// Must be called when a socket first connects, before any room operations. /// Must be called when a socket first connects, before any room operations.
/// The `ns` parameter is the namespace path this socket belongs to. /// The `ns` parameter is the namespace path this socket belongs to.
async fn register(&self, _socket_sid: &str, _engine_sid: &str, _ns: &str) -> Result<(), AdapterError> { async fn register(
&self,
_socket_sid: &str,
_engine_sid: &str,
_ns: &str,
) -> Result<(), AdapterError> {
Ok(()) Ok(())
} }
@@ -95,5 +108,5 @@ pub trait Adapter: Send + Sync + 'static {
} }
pub use local::LocalAdapter; pub use local::LocalAdapter;
pub use nats::NatsAdapter;
pub use redis::RedisAdapter; pub use redis::RedisAdapter;
pub use nats::NatsAdapter;
+90 -32
View File
@@ -5,7 +5,9 @@ use async_trait::async_trait;
use dashmap::DashMap; use dashmap::DashMap;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo}; use crate::socket::adapter::{
Adapter, AdapterError, BroadcastOptions, BusMessage, LocalBroadcastFn, SocketInfo,
};
use crate::socket::message_bus::MessageBus; use crate::socket::message_bus::MessageBus;
use crate::socket::packet::Packet; use crate::socket::packet::Packet;
use crate::socket::parser; use crate::socket::parser;
@@ -15,11 +17,16 @@ use crate::socket::socket::Socket;
/// Only performs local dispatch — no remote state writes needed. /// Only performs local dispatch — no remote state writes needed.
async fn handle_bus_message( async fn handle_bus_message(
msg: BusMessage, msg: BusMessage,
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>, on_local_broadcast: &LocalBroadcastFn,
server_id: &str, server_id: &str,
) { ) {
match msg { match msg {
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => { BusMessage::Broadcast {
namespace: _,
packet,
opts,
server_id: sender_id,
} => {
if sender_id == server_id { if sender_id == server_id {
return; return;
} }
@@ -29,13 +36,18 @@ async fn handle_bus_message(
} }
// NATS adapter manages room state locally; cross-server join/leave/disconnect // NATS adapter manages room state locally; cross-server join/leave/disconnect
// are informational only and don't require duplicate state writes. // are informational only and don't require duplicate state writes.
BusMessage::SocketJoin { server_id: sender_id, .. } BusMessage::SocketJoin {
| BusMessage::SocketLeave { server_id: sender_id, .. } server_id: sender_id,
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => { ..
if sender_id == server_id {
return;
}
} }
| BusMessage::SocketLeave {
server_id: sender_id,
..
}
| BusMessage::SocketDisconnect {
server_id: sender_id,
..
} => if sender_id == server_id {},
} }
} }
@@ -51,7 +63,7 @@ pub struct NatsAdapter {
sockets: DashMap<String, Arc<Socket>>, sockets: DashMap<String, Arc<Socket>>,
server_id: String, server_id: String,
namespace: String, namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>, on_local_broadcast: LocalBroadcastFn,
} }
impl NatsAdapter { impl NatsAdapter {
@@ -59,7 +71,7 @@ impl NatsAdapter {
message_bus: Arc<dyn MessageBus>, message_bus: Arc<dyn MessageBus>,
server_id: String, server_id: String,
namespace: String, namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>, on_local_broadcast: LocalBroadcastFn,
) -> Self { ) -> Self {
Self { Self {
message_bus, message_bus,
@@ -133,7 +145,11 @@ impl NatsAdapter {
#[async_trait] #[async_trait]
impl Adapter for NatsAdapter { impl Adapter for NatsAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { async fn broadcast(
&self,
packet: &Packet,
opts: &BroadcastOptions,
) -> Result<(), AdapterError> {
if opts.flags.local_only { if opts.flags.local_only {
(self.on_local_broadcast)(packet, opts); (self.on_local_broadcast)(packet, opts);
return Ok(()); return Ok(());
@@ -146,8 +162,8 @@ impl Adapter for NatsAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:broadcast", self.namespace), &payload) .publish(&format!("socket.io:{}:broadcast", self.namespace), &payload)
@@ -158,20 +174,30 @@ impl Adapter for NatsAdapter {
Ok(()) Ok(())
} }
async fn register(&self, socket_sid: &str, engine_sid: &str, _ns: &str) -> Result<(), AdapterError> { async fn register(
self.socket_sids.insert(socket_sid.to_string(), engine_sid.to_string()); &self,
socket_sid: &str,
engine_sid: &str,
_ns: &str,
) -> Result<(), AdapterError> {
self.socket_sids
.insert(socket_sid.to_string(), engine_sid.to_string());
Ok(()) Ok(())
} }
async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> { async fn add(&self, sid: &str, room: &str, _ns: &str) -> Result<(), AdapterError> {
self.socket_rooms self.socket_rooms
.entry(sid.to_string()) .entry(sid.to_string())
.and_modify(|set| { set.insert(room.to_string()); }) .and_modify(|set| {
set.insert(room.to_string());
})
.or_insert_with(|| HashSet::from([room.to_string()])); .or_insert_with(|| HashSet::from([room.to_string()]));
self.rooms self.rooms
.entry(room.to_string()) .entry(room.to_string())
.and_modify(|set| { set.insert(sid.to_string()); }) .and_modify(|set| {
set.insert(sid.to_string());
})
.or_insert_with(|| HashSet::from([sid.to_string()])); .or_insert_with(|| HashSet::from([sid.to_string()]));
let msg = BusMessage::SocketJoin { let msg = BusMessage::SocketJoin {
@@ -181,8 +207,8 @@ impl Adapter for NatsAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:join", self.namespace), &payload) .publish(&format!("socket.io:{}:join", self.namespace), &payload)
@@ -196,14 +222,24 @@ impl Adapter for NatsAdapter {
if let Some(mut entry) = self.socket_rooms.get_mut(sid) { if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
entry.value_mut().remove(room); entry.value_mut().remove(room);
} }
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) { if self
.socket_rooms
.get(sid)
.map(|e| e.value().is_empty())
.unwrap_or(true)
{
self.socket_rooms.remove(sid); self.socket_rooms.remove(sid);
} }
if let Some(mut entry) = self.rooms.get_mut(room) { if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid); entry.value_mut().remove(sid);
} }
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { if self
.rooms
.get(room)
.map(|e| e.value().is_empty())
.unwrap_or(true)
{
self.rooms.remove(room); self.rooms.remove(room);
} }
@@ -214,8 +250,8 @@ impl Adapter for NatsAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:leave", self.namespace), &payload) .publish(&format!("socket.io:{}:leave", self.namespace), &payload)
@@ -231,7 +267,12 @@ impl Adapter for NatsAdapter {
if let Some(mut entry) = self.rooms.get_mut(room) { if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid); entry.value_mut().remove(sid);
} }
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { if self
.rooms
.get(room)
.map(|e| e.value().is_empty())
.unwrap_or(true)
{
self.rooms.remove(room); self.rooms.remove(room);
} }
} }
@@ -246,18 +287,24 @@ impl Adapter for NatsAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:disconnect", self.namespace), &payload) .publish(
&format!("socket.io:{}:disconnect", self.namespace),
&payload,
)
.await .await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?; .map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(()) Ok(())
} }
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> { async fn fetch_sockets(
&self,
opts: &BroadcastOptions,
) -> Result<Vec<SocketInfo>, AdapterError> {
let mut result = Vec::new(); let mut result = Vec::new();
let target_sids: HashSet<String> = if opts.rooms.is_empty() { let target_sids: HashSet<String> = if opts.rooms.is_empty() {
@@ -276,7 +323,11 @@ impl Adapter for NatsAdapter {
if opts.except.contains(&sid) { if opts.except.contains(&sid) {
continue; continue;
} }
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default(); let rooms = self
.socket_rooms
.get(&sid)
.map(|e| e.value().clone())
.unwrap_or_default();
result.push(SocketInfo { result.push(SocketInfo {
sid: sid.clone(), sid: sid.clone(),
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
@@ -288,7 +339,11 @@ impl Adapter for NatsAdapter {
} }
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> { async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default()) Ok(self
.socket_rooms
.get(sid)
.map(|e| e.value().clone())
.unwrap_or_default())
} }
fn server_id(&self) -> &str { fn server_id(&self) -> &str {
@@ -296,7 +351,10 @@ impl Adapter for NatsAdapter {
} }
async fn close(&self) -> Result<(), AdapterError> { async fn close(&self) -> Result<(), AdapterError> {
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?; self.message_bus
.close()
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(()) Ok(())
} }
} }
+107 -34
View File
@@ -7,7 +7,9 @@ use fred::clients::Client;
use fred::interfaces::{KeysInterface, SetsInterface}; use fred::interfaces::{KeysInterface, SetsInterface};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::socket::adapter::{Adapter, AdapterError, BroadcastOptions, BusMessage, SocketInfo}; use crate::socket::adapter::{
Adapter, AdapterError, BroadcastOptions, BusMessage, LocalBroadcastFn, SocketInfo,
};
use crate::socket::message_bus::MessageBus; use crate::socket::message_bus::MessageBus;
use crate::socket::packet::Packet; use crate::socket::packet::Packet;
use crate::socket::parser; use crate::socket::parser;
@@ -28,11 +30,16 @@ fn socket_rooms_key(ns: &str, sid: &str) -> String {
/// Only performs local state updates — the remote server already wrote to Redis. /// Only performs local state updates — the remote server already wrote to Redis.
async fn handle_bus_message( async fn handle_bus_message(
msg: BusMessage, msg: BusMessage,
on_local_broadcast: &Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>, on_local_broadcast: &LocalBroadcastFn,
server_id: &str, server_id: &str,
) { ) {
match msg { match msg {
BusMessage::Broadcast { namespace: _, packet, opts, server_id: sender_id } => { BusMessage::Broadcast {
namespace: _,
packet,
opts,
server_id: sender_id,
} => {
if sender_id == server_id { if sender_id == server_id {
return; return;
} }
@@ -40,13 +47,20 @@ async fn handle_bus_message(
on_local_broadcast(&decoded_packet, &opts); on_local_broadcast(&decoded_packet, &opts);
} }
} }
BusMessage::SocketJoin { server_id: sender_id, .. } BusMessage::SocketJoin {
| BusMessage::SocketLeave { server_id: sender_id, .. } server_id: sender_id,
| BusMessage::SocketDisconnect { server_id: sender_id, .. } => { ..
}
| BusMessage::SocketLeave {
server_id: sender_id,
..
}
| BusMessage::SocketDisconnect {
server_id: sender_id,
..
} => {
// Skip messages from this server; remote server already updated Redis // Skip messages from this server; remote server already updated Redis
if sender_id == server_id { if sender_id == server_id {}
return;
}
// No duplicate Redis writes — the sender already persisted the state change // No duplicate Redis writes — the sender already persisted the state change
} }
} }
@@ -58,10 +72,12 @@ pub struct RedisAdapter {
room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>, room_subscribers: DashMap<String, mpsc::Receiver<Vec<u8>>>,
socket_rooms: DashMap<String, HashSet<String>>, socket_rooms: DashMap<String, HashSet<String>>,
rooms: DashMap<String, HashSet<String>>, rooms: DashMap<String, HashSet<String>>,
/// socket_sid → engine_sid mapping for local inspection.
socket_sids: DashMap<String, String>,
sockets: DashMap<String, Arc<Socket>>, sockets: DashMap<String, Arc<Socket>>,
server_id: String, server_id: String,
namespace: String, namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>, on_local_broadcast: LocalBroadcastFn,
} }
impl RedisAdapter { impl RedisAdapter {
@@ -70,7 +86,7 @@ impl RedisAdapter {
redis_client: Client, redis_client: Client,
server_id: String, server_id: String,
namespace: String, namespace: String,
on_local_broadcast: Arc<dyn Fn(&Packet, &BroadcastOptions) + Send + Sync + 'static>, on_local_broadcast: LocalBroadcastFn,
) -> Self { ) -> Self {
Self { Self {
message_bus, message_bus,
@@ -81,6 +97,7 @@ impl RedisAdapter {
room_subscribers: DashMap::new(), room_subscribers: DashMap::new(),
socket_rooms: DashMap::new(), socket_rooms: DashMap::new(),
rooms: DashMap::new(), rooms: DashMap::new(),
socket_sids: DashMap::new(),
sockets: DashMap::new(), sockets: DashMap::new(),
} }
} }
@@ -144,7 +161,11 @@ impl RedisAdapter {
#[async_trait] #[async_trait]
impl Adapter for RedisAdapter { impl Adapter for RedisAdapter {
async fn broadcast(&self, packet: &Packet, opts: &BroadcastOptions) -> Result<(), AdapterError> { async fn broadcast(
&self,
packet: &Packet,
opts: &BroadcastOptions,
) -> Result<(), AdapterError> {
if opts.flags.local_only { if opts.flags.local_only {
(self.on_local_broadcast)(packet, opts); (self.on_local_broadcast)(packet, opts);
return Ok(()); return Ok(());
@@ -157,11 +178,11 @@ impl Adapter for RedisAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:broadcast", packet.namespace), &payload) .publish(&format!("socket.io:{}:broadcast", self.namespace), &payload)
.await .await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?; .map_err(|e| AdapterError::MessageBus(e.to_string()))?;
@@ -185,12 +206,16 @@ impl Adapter for RedisAdapter {
self.socket_rooms self.socket_rooms
.entry(sid.to_string()) .entry(sid.to_string())
.and_modify(|set| { set.insert(room.to_string()); }) .and_modify(|set| {
set.insert(room.to_string());
})
.or_insert_with(|| HashSet::from([room.to_string()])); .or_insert_with(|| HashSet::from([room.to_string()]));
self.rooms self.rooms
.entry(room.to_string()) .entry(room.to_string())
.and_modify(|set| { set.insert(sid.to_string()); }) .and_modify(|set| {
set.insert(sid.to_string());
})
.or_insert_with(|| HashSet::from([sid.to_string()])); .or_insert_with(|| HashSet::from([sid.to_string()]));
let msg = BusMessage::SocketJoin { let msg = BusMessage::SocketJoin {
@@ -200,11 +225,11 @@ impl Adapter for RedisAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:join", ns), &payload) .publish(&format!("socket.io:{}:join", self.namespace), &payload)
.await .await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?; .map_err(|e| AdapterError::MessageBus(e.to_string()))?;
@@ -228,14 +253,24 @@ impl Adapter for RedisAdapter {
if let Some(mut entry) = self.socket_rooms.get_mut(sid) { if let Some(mut entry) = self.socket_rooms.get_mut(sid) {
entry.value_mut().remove(room); entry.value_mut().remove(room);
} }
if self.socket_rooms.get(sid).map(|e| e.value().is_empty()).unwrap_or(true) { if self
.socket_rooms
.get(sid)
.map(|e| e.value().is_empty())
.unwrap_or(true)
{
self.socket_rooms.remove(sid); self.socket_rooms.remove(sid);
} }
if let Some(mut entry) = self.rooms.get_mut(room) { if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid); entry.value_mut().remove(sid);
} }
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { if self
.rooms
.get(room)
.map(|e| e.value().is_empty())
.unwrap_or(true)
{
self.rooms.remove(room); self.rooms.remove(room);
} }
@@ -246,11 +281,11 @@ impl Adapter for RedisAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:leave", ns), &payload) .publish(&format!("socket.io:{}:leave", self.namespace), &payload)
.await .await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?; .map_err(|e| AdapterError::MessageBus(e.to_string()))?;
@@ -263,7 +298,12 @@ impl Adapter for RedisAdapter {
if let Some(mut entry) = self.rooms.get_mut(room) { if let Some(mut entry) = self.rooms.get_mut(room) {
entry.value_mut().remove(sid); entry.value_mut().remove(sid);
} }
if self.rooms.get(room).map(|e| e.value().is_empty()).unwrap_or(true) { if self
.rooms
.get(room)
.map(|e| e.value().is_empty())
.unwrap_or(true)
{
self.rooms.remove(room); self.rooms.remove(room);
} }
@@ -280,6 +320,7 @@ impl Adapter for RedisAdapter {
.await .await
.map_err(|e| AdapterError::Redis(e.to_string()))?; .map_err(|e| AdapterError::Redis(e.to_string()))?;
self.socket_sids.remove(sid);
self.sockets.remove(sid); self.sockets.remove(sid);
let msg = BusMessage::SocketDisconnect { let msg = BusMessage::SocketDisconnect {
@@ -288,22 +329,43 @@ impl Adapter for RedisAdapter {
server_id: self.server_id.clone(), server_id: self.server_id.clone(),
}; };
let payload = serde_json::to_vec(&msg) let payload =
.map_err(|e| AdapterError::Serialization(e.to_string()))?; serde_json::to_vec(&msg).map_err(|e| AdapterError::Serialization(e.to_string()))?;
self.message_bus self.message_bus
.publish(&format!("socket.io:{}:disconnect", ns), &payload) .publish(
&format!("socket.io:{}:disconnect", self.namespace),
&payload,
)
.await .await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?; .map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(()) Ok(())
} }
async fn fetch_sockets(&self, opts: &BroadcastOptions) -> Result<Vec<SocketInfo>, AdapterError> { async fn register(
&self,
socket_sid: &str,
engine_sid: &str,
_ns: &str,
) -> Result<(), AdapterError> {
self.socket_sids
.insert(socket_sid.to_string(), engine_sid.to_string());
Ok(())
}
async fn unregister(&self, socket_sid: &str, ns: &str) -> Result<(), AdapterError> {
self.del_all(socket_sid, ns).await
}
async fn fetch_sockets(
&self,
opts: &BroadcastOptions,
) -> Result<Vec<SocketInfo>, AdapterError> {
let mut result = Vec::new(); let mut result = Vec::new();
let target_sids: HashSet<String> = if opts.rooms.is_empty() { let target_sids: HashSet<String> = if opts.rooms.is_empty() {
self.sockets.iter().map(|e| e.key().clone()).collect() self.socket_sids.iter().map(|e| e.key().clone()).collect()
} else { } else {
let mut sids = HashSet::new(); let mut sids = HashSet::new();
for room in &opts.rooms { for room in &opts.rooms {
@@ -318,7 +380,11 @@ impl Adapter for RedisAdapter {
if opts.except.contains(&sid) { if opts.except.contains(&sid) {
continue; continue;
} }
let rooms = self.socket_rooms.get(&sid).map(|e| e.value().clone()).unwrap_or_default(); let rooms = self
.socket_rooms
.get(&sid)
.map(|e| e.value().clone())
.unwrap_or_default();
result.push(SocketInfo { result.push(SocketInfo {
sid: sid.clone(), sid: sid.clone(),
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
@@ -330,7 +396,11 @@ impl Adapter for RedisAdapter {
} }
async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> { async fn socket_rooms(&self, sid: &str) -> Result<HashSet<String>, AdapterError> {
Ok(self.socket_rooms.get(sid).map(|e| e.value().clone()).unwrap_or_default()) Ok(self
.socket_rooms
.get(sid)
.map(|e| e.value().clone())
.unwrap_or_default())
} }
fn server_id(&self) -> &str { fn server_id(&self) -> &str {
@@ -338,7 +408,10 @@ impl Adapter for RedisAdapter {
} }
async fn close(&self) -> Result<(), AdapterError> { async fn close(&self) -> Result<(), AdapterError> {
self.message_bus.close().await.map_err(|e| AdapterError::MessageBus(e.to_string()))?; self.message_bus
.close()
.await
.map_err(|e| AdapterError::MessageBus(e.to_string()))?;
Ok(()) Ok(())
} }
} }
+2 -2
View File
@@ -1,5 +1,5 @@
pub mod redis;
pub mod nats; pub mod nats;
pub mod redis;
use async_trait::async_trait; use async_trait::async_trait;
use thiserror::Error; use thiserror::Error;
@@ -27,5 +27,5 @@ pub trait MessageBus: Send + Sync + 'static {
async fn close(&self) -> Result<(), MessageBusError>; async fn close(&self) -> Result<(), MessageBusError>;
} }
pub use nats::NatsMessageBus;
pub use redis::RedisMessageBus; pub use redis::RedisMessageBus;
pub use nats::NatsMessageBus;
+3 -2
View File
@@ -34,7 +34,8 @@ impl MessageBus for NatsMessageBus {
async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> { async fn subscribe(&self, channel: &str) -> Result<mpsc::Receiver<Vec<u8>>, MessageBusError> {
let (tx, rx) = mpsc::channel::<Vec<u8>>(256); let (tx, rx) = mpsc::channel::<Vec<u8>>(256);
let mut subscriber = self.client let mut subscriber = self
.client
.subscribe(channel.to_string()) .subscribe(channel.to_string())
.await .await
.map_err(|e| MessageBusError::Nats(e.to_string()))?; .map_err(|e| MessageBusError::Nats(e.to_string()))?;
@@ -85,4 +86,4 @@ impl MessageBus for NatsMessageBus {
self.shutdowns.clear(); self.shutdowns.clear();
Ok(()) Ok(())
} }
} }
+5 -6
View File
@@ -13,8 +13,8 @@ pub struct RedisMessageBus {
impl RedisMessageBus { impl RedisMessageBus {
pub async fn new(redis_url: &str) -> Result<Self, MessageBusError> { pub async fn new(redis_url: &str) -> Result<Self, MessageBusError> {
let config = Config::from_url(redis_url) let config =
.map_err(|e| MessageBusError::Redis(e.to_string()))?; Config::from_url(redis_url).map_err(|e| MessageBusError::Redis(e.to_string()))?;
let client = Client::new(config.clone(), None, None, None); let client = Client::new(config.clone(), None, None, None);
let subscriber = SubscriberClient::new(config, None, None, None); let subscriber = SubscriberClient::new(config, None, None, None);
@@ -64,9 +64,8 @@ impl MessageBus for RedisMessageBus {
tokio::spawn(async move { tokio::spawn(async move {
while let Ok(message) = message_rx.recv().await { while let Ok(message) = message_rx.recv().await {
if &message.channel == &channel_owned { if message.channel == channel_owned {
let data: Vec<u8> = FromValue::from_value(message.value) let data: Vec<u8> = FromValue::from_value(message.value).unwrap_or_default();
.unwrap_or_default();
if tx.send(data).await.is_err() { if tx.send(data).await.is_err() {
break; break;
} }
@@ -96,4 +95,4 @@ impl MessageBus for RedisMessageBus {
.map_err(|e| MessageBusError::Redis(e.to_string()))?; .map_err(|e| MessageBusError::Redis(e.to_string()))?;
Ok(()) Ok(())
} }
} }
+11 -5
View File
@@ -5,12 +5,18 @@ pub mod packet;
pub mod parser; pub mod parser;
pub mod server; pub mod server;
pub mod session_store; pub mod session_store;
#[allow(clippy::module_inception)]
pub mod socket; pub mod socket;
pub use adapter::{Adapter, AdapterError, BroadcastOptions, BroadcastFlags, BusMessage, LocalAdapter, RedisAdapter, NatsAdapter, SocketInfo}; pub use adapter::{
pub use message_bus::{MessageBus, MessageBusError, RedisMessageBus, NatsMessageBus}; Adapter, AdapterError, BroadcastFlags, BroadcastOptions, BusMessage, LocalAdapter, NatsAdapter,
pub use namespace::{is_valid_namespace, Namespace, NamespaceManager}; RedisAdapter, SocketInfo,
};
pub use message_bus::{MessageBus, MessageBusError, NatsMessageBus, RedisMessageBus};
pub use namespace::{Namespace, NamespaceManager, is_valid_namespace};
pub use packet::{Packet, PacketType}; pub use packet::{Packet, PacketType};
pub use server::{SocketServer, SocketServerBuilder}; pub use server::{SocketServer, SocketServerBuilder};
pub use session_store::{InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait}; pub use session_store::{
pub use socket::Socket; InMemorySessionStore, RedisSessionStore, SessionError, SessionInfo, SessionStoreTrait,
};
pub use socket::Socket;
+148 -21
View File
@@ -4,12 +4,13 @@ use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::socket::adapter::{Adapter, BroadcastOptions, BroadcastFlags}; use crate::socket::adapter::{Adapter, BroadcastFlags, BroadcastOptions};
use crate::socket::packet::Packet; use crate::socket::packet::Packet;
use crate::socket::socket::Socket; use crate::socket::socket::Socket;
pub type EventHandler = Arc<dyn Fn(&Socket, &serde_json::Value) + Send + Sync>; pub type EventHandler = Arc<dyn Fn(Arc<Socket>, &serde_json::Value) + Send + Sync>;
type ConnectHandler = Arc<dyn Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync>; type ConnectHandler =
Arc<dyn Fn(&Socket, Option<&serde_json::Value>) -> Result<(), String> + Send + Sync>;
pub struct Namespace { pub struct Namespace {
pub path: String, pub path: String,
@@ -19,6 +20,8 @@ pub struct Namespace {
engine_to_socket: DashMap<String, String>, engine_to_socket: DashMap<String, String>,
handlers: RwLock<HashMap<String, Vec<EventHandler>>>, handlers: RwLock<HashMap<String, Vec<EventHandler>>>,
connect_handler: RwLock<Option<ConnectHandler>>, connect_handler: RwLock<Option<ConnectHandler>>,
rooms: DashMap<String, HashSet<String>>,
socket_rooms: DashMap<String, HashSet<String>>,
pub(crate) adapter: RwLock<Option<Arc<dyn Adapter>>>, pub(crate) adapter: RwLock<Option<Arc<dyn Adapter>>>,
} }
@@ -30,6 +33,8 @@ impl Namespace {
engine_to_socket: DashMap::new(), engine_to_socket: DashMap::new(),
handlers: RwLock::new(HashMap::new()), handlers: RwLock::new(HashMap::new()),
connect_handler: RwLock::new(None), connect_handler: RwLock::new(None),
rooms: DashMap::new(),
socket_rooms: DashMap::new(),
adapter: RwLock::new(None), adapter: RwLock::new(None),
} }
} }
@@ -40,11 +45,15 @@ impl Namespace {
} }
/// Add a socket to this namespace. Returns Err if the connect handler rejects. /// Add a socket to this namespace. Returns Err if the connect handler rejects.
pub async fn add_socket(&self, socket: Arc<Socket>) -> Result<(), String> { pub async fn add_socket(
&self,
socket: Arc<Socket>,
auth_data: Option<&serde_json::Value>,
) -> Result<(), String> {
// Run connect handler before adding to storage // Run connect handler before adding to storage
let handler = self.connect_handler.read().await; let handler = self.connect_handler.read().await;
if let Some(ref h) = *handler { if let Some(ref h) = *handler {
h(&socket, None)?; h(&socket, auth_data)?;
} }
drop(handler); drop(handler);
@@ -53,10 +62,10 @@ impl Namespace {
// Register with adapter (socket_sid → engine_sid mapping) // Register with adapter (socket_sid → engine_sid mapping)
let adapter = self.adapter.read().await; let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter { if let Some(ref adapter) = *adapter
if let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await { && let Err(e) = adapter.register(&socket_sid, &engine_sid, &self.path).await
tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e); {
} tracing::warn!("Adapter register error for socket {}: {}", socket_sid, e);
} }
// Store socket by socket_sid, plus reverse index // Store socket by socket_sid, plus reverse index
@@ -69,12 +78,13 @@ impl Namespace {
pub async fn remove_socket_by_sid(&self, socket_sid: &str) { pub async fn remove_socket_by_sid(&self, socket_sid: &str) {
if let Some((_, socket)) = self.sockets.remove(socket_sid) { if let Some((_, socket)) = self.sockets.remove(socket_sid) {
self.engine_to_socket.remove(&socket.engine_sid); self.engine_to_socket.remove(&socket.engine_sid);
self.remove_socket_from_local_rooms(socket_sid);
let adapter = self.adapter.read().await; let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter { if let Some(ref adapter) = *adapter
if let Err(e) = adapter.del_all(socket_sid, &self.path).await { && let Err(e) = adapter.del_all(socket_sid, &self.path).await
tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e); {
} tracing::warn!("Adapter del_all error for socket {}: {}", socket_sid, e);
} }
} }
} }
@@ -130,7 +140,12 @@ impl Namespace {
} }
} }
pub async fn emit_to_room(&self, room: &str, event: impl Into<String>, data: serde_json::Value) { pub async fn emit_to_room(
&self,
room: &str,
event: impl Into<String>,
data: serde_json::Value,
) {
let event_name = event.into(); let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None); let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
@@ -145,20 +160,64 @@ impl Namespace {
tracing::warn!("Adapter broadcast to room error: {}", e); tracing::warn!("Adapter broadcast to room error: {}", e);
} }
} else { } else {
self.emit_local(&packet); self.emit_local_to_room(&packet, room, &HashSet::new());
} }
} }
pub fn emit_local(&self, packet: &Packet) { pub fn emit_local(&self, packet: &Packet) {
for entry in self.sockets.iter() { for entry in self.sockets.iter() {
let socket = entry.value(); self.send_local_packet(entry.value(), packet);
if socket.send_packet(packet).is_err() { }
tracing::warn!("Failed to send event to socket {}", socket.sid); }
pub fn emit_local_filtered(&self, packet: &Packet, opts: &BroadcastOptions) {
if opts.rooms.is_empty() {
for entry in self.sockets.iter() {
if !opts.except.contains(entry.key()) {
self.send_local_packet(entry.value(), packet);
}
}
return;
}
let mut target_sids = HashSet::new();
for room in &opts.rooms {
if let Some(room_sids) = self.rooms.get(room) {
target_sids.extend(room_sids.value().iter().cloned());
}
}
for sid in target_sids {
if opts.except.contains(&sid) {
continue;
}
if let Some(socket) = self.get_socket(&sid) {
self.send_local_packet(&socket, packet);
} }
} }
} }
pub async fn emit_to(&self, socket_sid: &str, event: impl Into<String>, data: serde_json::Value) { fn emit_local_to_room(&self, packet: &Packet, room: &str, except: &HashSet<String>) {
let opts = BroadcastOptions {
rooms: HashSet::from([room.to_string()]),
except: except.clone(),
flags: BroadcastFlags::default(),
};
self.emit_local_filtered(packet, &opts);
}
fn send_local_packet(&self, socket: &Socket, packet: &Packet) {
if socket.send_packet(packet).is_err() {
tracing::warn!("Failed to send event to socket {}", socket.sid);
}
}
pub async fn emit_to(
&self,
socket_sid: &str,
event: impl Into<String>,
data: serde_json::Value,
) {
if let Some(socket) = self.get_socket(socket_sid) { if let Some(socket) = self.get_socket(socket_sid) {
let event_name = event.into(); let event_name = event.into();
let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None); let packet = Packet::event(&self.path, serde_json::json!([event_name, data]), None);
@@ -168,11 +227,79 @@ impl Namespace {
} }
} }
pub async fn handle_event(&self, socket: &Socket, event: &str, data: &serde_json::Value) { pub async fn handle_event(&self, socket: Arc<Socket>, event: &str, data: &serde_json::Value) {
let handlers = self.handlers.read().await; let handlers = self.handlers.read().await;
if let Some(event_handlers) = handlers.get(event) { if let Some(event_handlers) = handlers.get(event) {
for handler in event_handlers { for handler in event_handlers {
handler(socket, data); handler(Arc::clone(&socket), data);
}
}
}
pub async fn join_room(&self, socket_sid: &str, room: &str) -> crate::ImksResult<()> {
if !self.sockets.contains_key(socket_sid) {
return Err(crate::ImksError::SocketNotFound(socket_sid.to_string()));
}
self.rooms
.entry(room.to_string())
.or_default()
.value_mut()
.insert(socket_sid.to_string());
self.socket_rooms
.entry(socket_sid.to_string())
.or_default()
.value_mut()
.insert(room.to_string());
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter
&& let Err(e) = adapter.add(socket_sid, room, &self.path).await
{
self.remove_local_room(socket_sid, room);
return Err(e.into());
}
Ok(())
}
pub async fn leave_room(&self, socket_sid: &str, room: &str) -> crate::ImksResult<()> {
let adapter = self.adapter.read().await;
if let Some(ref adapter) = *adapter {
adapter.del(socket_sid, room, &self.path).await?;
}
self.remove_local_room(socket_sid, room);
Ok(())
}
fn remove_local_room(&self, socket_sid: &str, room: &str) {
if let Some(mut sids) = self.rooms.get_mut(room) {
sids.value_mut().remove(socket_sid);
if sids.value().is_empty() {
drop(sids);
self.rooms.remove(room);
}
}
if let Some(mut rooms) = self.socket_rooms.get_mut(socket_sid) {
rooms.value_mut().remove(room);
if rooms.value().is_empty() {
drop(rooms);
self.socket_rooms.remove(socket_sid);
}
}
}
fn remove_socket_from_local_rooms(&self, socket_sid: &str) {
if let Some((_, rooms)) = self.socket_rooms.remove(socket_sid) {
for room in rooms {
if let Some(mut sids) = self.rooms.get_mut(&room) {
sids.value_mut().remove(socket_sid);
if sids.value().is_empty() {
drop(sids);
self.rooms.remove(&room);
}
}
} }
} }
} }
+31 -29
View File
@@ -24,19 +24,18 @@ pub fn encode(packet: &Packet) -> String {
if let Some(ref data) = packet.data { if let Some(ref data) = packet.data {
if packet.has_binary() { if packet.has_binary() {
let data_with_placeholders = replace_binary_with_placeholders(data, packet.attachment_count()); let data_with_placeholders =
let encoded_data = serde_json::to_string(&data_with_placeholders) replace_binary_with_placeholders(data, packet.attachment_count());
.unwrap_or_else(|e| { let encoded_data = serde_json::to_string(&data_with_placeholders).unwrap_or_else(|e| {
tracing::error!("Failed to serialize socket packet data: {}", e); tracing::error!("Failed to serialize socket packet data: {}", e);
"null".to_string() "null".to_string()
}); });
result.push_str(&encoded_data); result.push_str(&encoded_data);
} else { } else {
let encoded_data = serde_json::to_string(data) let encoded_data = serde_json::to_string(data).unwrap_or_else(|e| {
.unwrap_or_else(|e| { tracing::error!("Failed to serialize socket packet data: {}", e);
tracing::error!("Failed to serialize socket packet data: {}", e); "null".to_string()
"null".to_string() });
});
result.push_str(&encoded_data); result.push_str(&encoded_data);
} }
} }
@@ -67,7 +66,8 @@ pub fn decode(input: &str) -> Result<Packet, PacketError> {
let type_char = chars.next().ok_or(PacketError::Empty)?; let type_char = chars.next().ok_or(PacketError::Empty)?;
let packet_type = PacketType::try_from(type_char)?; let packet_type = PacketType::try_from(type_char)?;
let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck) { let attachment_count = if matches!(packet_type, PacketType::BinaryEvent | PacketType::BinaryAck)
{
let mut count_str = String::new(); let mut count_str = String::new();
while let Some(&c) = chars.peek() { while let Some(&c) = chars.peek() {
if c == '-' { if c == '-' {
@@ -126,7 +126,11 @@ pub fn decode(input: &str) -> Result<Packet, PacketError> {
id, id,
attachments: Vec::new(), attachments: Vec::new(),
// Store attachment_count for binary packets; actual attachments come via decode_with_attachments // Store attachment_count for binary packets; actual attachments come via decode_with_attachments
expected_attachments: if attachment_count > 0 { Some(attachment_count) } else { None }, expected_attachments: if attachment_count > 0 {
Some(attachment_count)
} else {
None
},
}) })
} }
@@ -144,10 +148,10 @@ pub fn decode_with_attachments(
packet.attachments = attachments; packet.attachments = attachments;
packet.expected_attachments = None; packet.expected_attachments = None;
if packet.has_binary() { if packet.has_binary()
if let Some(ref data) = packet.data { && let Some(ref data) = packet.data
packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments)); {
} packet.data = Some(replace_placeholders_with_binary(data, &packet.attachments));
} }
Ok(packet) Ok(packet)
@@ -204,12 +208,12 @@ fn replace_binary_with_placeholders(value: &Value, total_attachments: usize) ->
} }
} }
fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut usize) -> Value { fn replace_binary_with_placeholders_inner(value: &Value, _placeholder_idx: &mut usize) -> Value {
match value { match value {
Value::Array(arr) => { Value::Array(arr) => {
let new_arr: Vec<Value> = arr let new_arr: Vec<Value> = arr
.iter() .iter()
.map(|v| replace_binary_with_placeholders_inner(v, placeholder_idx)) .map(|v| replace_binary_with_placeholders_inner(v, _placeholder_idx))
.collect(); .collect();
Value::Array(new_arr) Value::Array(new_arr)
} }
@@ -218,7 +222,7 @@ fn replace_binary_with_placeholders_inner(value: &Value, placeholder_idx: &mut u
for (k, v) in map { for (k, v) in map {
new_map.insert( new_map.insert(
k.clone(), k.clone(),
replace_binary_with_placeholders_inner(v, placeholder_idx), replace_binary_with_placeholders_inner(v, _placeholder_idx),
); );
} }
Value::Object(new_map) Value::Object(new_map)
@@ -236,15 +240,13 @@ fn replace_placeholders_with_binary(value: &Value, attachments: &[Vec<u8>]) -> V
// Check if this is a placeholder object: { "_placeholder": true, "num": N } // Check if this is a placeholder object: { "_placeholder": true, "num": N }
if let (Some(Value::Bool(true)), Some(Value::Number(num))) = if let (Some(Value::Bool(true)), Some(Value::Number(num))) =
(map.get("_placeholder"), map.get("num")) (map.get("_placeholder"), map.get("num"))
&& let Some(idx) = num.as_u64()
&& let Some(attachment) = attachments.get(idx as usize)
{ {
if let Some(idx) = num.as_u64() { return Value::String(base64::Engine::encode(
if let Some(attachment) = attachments.get(idx as usize) { &base64::engine::general_purpose::STANDARD,
return Value::String(base64::Engine::encode( attachment,
&base64::engine::general_purpose::STANDARD, ));
attachment,
));
}
}
} }
let mut new_map = serde_json::Map::new(); let mut new_map = serde_json::Map::new();
@@ -389,4 +391,4 @@ mod tests {
assert_eq!(packet.expected_attachments, Some(1)); assert_eq!(packet.expected_attachments, Some(1));
assert_eq!(packet.namespace, "/"); assert_eq!(packet.namespace, "/");
} }
} }
+55 -41
View File
@@ -103,8 +103,14 @@ impl SocketServerBuilder {
let adapter = adapter_clone.clone(); let adapter = adapter_clone.clone();
tokio::spawn(async move { tokio::spawn(async move {
handle_engine_message( handle_engine_message(
sid, engine_packet, &namespaces, &socket_txs, &engine_store, &adapter, sid,
).await; engine_packet,
&namespaces,
&socket_txs,
&engine_store,
&adapter,
)
.await;
}); });
}, },
)); ));
@@ -136,10 +142,18 @@ async fn handle_engine_message(
adapter: &Arc<dyn Adapter>, adapter: &Arc<dyn Adapter>,
) { ) {
if let EnginePacketData::Text(ref text) = engine_packet.data { if let EnginePacketData::Text(ref text) = engine_packet.data {
if let Ok(socket_packet) = parser::decode(text) { match parser::decode(text) {
match socket_packet.packet_type { Ok(socket_packet) => match socket_packet.packet_type {
PacketType::Connect => { PacketType::Connect => {
handle_connect(&engine_sid, &socket_packet, namespaces, socket_txs, engine_store, adapter).await; handle_connect(
&engine_sid,
&socket_packet,
namespaces,
socket_txs,
engine_store,
adapter,
)
.await;
} }
PacketType::Disconnect => { PacketType::Disconnect => {
handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs); handle_disconnect(&engine_sid, &socket_packet, namespaces, socket_txs);
@@ -151,6 +165,9 @@ async fn handle_engine_message(
handle_ack(&engine_sid, &socket_packet); handle_ack(&engine_sid, &socket_packet);
} }
_ => {} _ => {}
},
Err(e) => {
tracing::warn!(engine_sid = %engine_sid, error = %e, "Invalid Socket.IO packet");
} }
} }
} }
@@ -166,22 +183,21 @@ async fn handle_connect(
) { ) {
// Validate namespace path to prevent DoS via arbitrary namespace creation // Validate namespace path to prevent DoS via arbitrary namespace creation
if !crate::socket::namespace::is_valid_namespace(&packet.namespace) { if !crate::socket::namespace::is_valid_namespace(&packet.namespace) {
tracing::warn!("Rejected connect with invalid namespace: {}", packet.namespace); tracing::warn!(
"Rejected connect with invalid namespace: {}",
packet.namespace
);
return; return;
} }
let namespace = namespaces.get_or_create_namespace(&packet.namespace); let namespace = namespaces.get_or_create_namespace(&packet.namespace);
// Ensure newly created namespaces get the shared adapter // Ensure newly created namespaces get the shared adapter before registration.
{ {
let ns_adapter = namespace.adapter.read().await; let ns_adapter = namespace.adapter.read().await;
if ns_adapter.is_none() { if ns_adapter.is_none() {
drop(ns_adapter); drop(ns_adapter);
let adapter_ref = adapter.clone(); namespace.set_adapter(adapter.clone()).await;
let ns_clone = namespace.clone();
tokio::spawn(async move {
ns_clone.set_adapter(adapter_ref).await;
});
} }
} }
@@ -198,7 +214,10 @@ async fn handle_connect(
// Run connect handler and add to namespace. // Run connect handler and add to namespace.
// If the handler rejects, clean up and do NOT send a Connect response. // If the handler rejects, clean up and do NOT send a Connect response.
if let Err(msg) = namespace.add_socket(socket.clone()).await { if let Err(msg) = namespace
.add_socket(socket.clone(), packet.data.as_ref())
.await
{
tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg); tracing::warn!("Socket {} connection rejected: {}", socket_sid, msg);
socket_txs.remove(&socket_sid); socket_txs.remove(&socket_sid);
return; return;
@@ -227,7 +246,9 @@ async fn handle_connect(
} }
// Forwarding task ended — ensure socket is cleaned up from namespace // Forwarding task ended — ensure socket is cleaned up from namespace
socket_txs_clone.remove(&socket_sid_clone); socket_txs_clone.remove(&socket_sid_clone);
namespace_clone.remove_socket_by_sid(&socket_sid_clone).await; namespace_clone
.remove_socket_by_sid(&socket_sid_clone)
.await;
}); });
// Send Connect response (only after handler passed) // Send Connect response (only after handler passed)
@@ -260,34 +281,27 @@ fn handle_disconnect(
} }
} }
fn handle_event( fn handle_event(engine_sid: &str, packet: &Packet, namespaces: &Arc<NamespaceManager>) {
engine_sid: &str, if let Some(namespace) = namespaces.get_namespace(&packet.namespace)
packet: &Packet, && let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid)
namespaces: &Arc<NamespaceManager>, && let Some(ref data) = packet.data
) { && let Some(arr) = data.as_array()
if let Some(namespace) = namespaces.get_namespace(&packet.namespace) { && let Some(event) = arr.first().and_then(|v| v.as_str())
if let Some(socket) = namespace.get_socket_by_engine_sid(engine_sid) { {
if let Some(ref data) = packet.data { let event_data = if arr.len() > 1 {
if let Some(arr) = data.as_array() { serde_json::Value::Array(arr[1..].to_vec())
if let Some(event) = arr.first().and_then(|v| v.as_str()) { } else {
let event_data = if arr.len() > 1 { serde_json::Value::Null
serde_json::Value::Array(arr[1..].to_vec()) };
} else {
serde_json::Value::Null
};
let namespace_clone = namespace.clone(); let namespace_clone = namespace.clone();
let event = event.to_string(); let event = event.to_string();
let socket_clone = socket.clone(); let socket_clone = socket.clone();
tokio::spawn(async move { tokio::spawn(async move {
namespace_clone namespace_clone
.handle_event(&socket_clone, &event, &event_data) .handle_event(socket_clone, &event, &event_data)
.await; .await;
}); });
}
}
}
}
} }
} }
+7 -2
View File
@@ -33,7 +33,12 @@ fn now_millis() -> u64 {
#[async_trait] #[async_trait]
impl SessionStoreTrait for InMemorySessionStore { impl SessionStoreTrait for InMemorySessionStore {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> { async fn create(
&self,
sid: &str,
transport: &str,
server_id: &str,
) -> Result<(), SessionError> {
let info = SessionInfo { let info = SessionInfo {
sid: sid.to_string(), sid: sid.to_string(),
transport: transport.to_string(), transport: transport.to_string(),
@@ -85,4 +90,4 @@ impl SessionStoreTrait for InMemorySessionStore {
async fn exists(&self, sid: &str) -> Result<bool, SessionError> { async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
Ok(self.sessions.contains_key(sid)) Ok(self.sessions.contains_key(sid))
} }
} }
+3 -2
View File
@@ -28,7 +28,8 @@ pub struct SessionInfo {
#[async_trait] #[async_trait]
pub trait SessionStoreTrait: Send + Sync + 'static { pub trait SessionStoreTrait: Send + Sync + 'static {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError>; async fn create(&self, sid: &str, transport: &str, server_id: &str)
-> Result<(), SessionError>;
async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError>; async fn get(&self, sid: &str) -> Result<Option<SessionInfo>, SessionError>;
async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError>; async fn set_state(&self, sid: &str, state: &str) -> Result<(), SessionError>;
async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError>; async fn set_transport(&self, sid: &str, transport: &str) -> Result<(), SessionError>;
@@ -38,4 +39,4 @@ pub trait SessionStoreTrait: Send + Sync + 'static {
} }
pub use memory::InMemorySessionStore; pub use memory::InMemorySessionStore;
pub use redis::RedisSessionStore; pub use redis::RedisSessionStore;
+19 -6
View File
@@ -36,7 +36,12 @@ impl RedisSessionStore {
#[async_trait] #[async_trait]
impl SessionStoreTrait for RedisSessionStore { impl SessionStoreTrait for RedisSessionStore {
async fn create(&self, sid: &str, transport: &str, server_id: &str) -> Result<(), SessionError> { async fn create(
&self,
sid: &str,
transport: &str,
server_id: &str,
) -> Result<(), SessionError> {
let key = self.key(sid); let key = self.key(sid);
let now = now_millis(); let now = now_millis();
@@ -67,7 +72,8 @@ impl SessionStoreTrait for RedisSessionStore {
// Use hgetall directly — if the key doesn't exist Redis returns an empty map. // Use hgetall directly — if the key doesn't exist Redis returns an empty map.
// This avoids the TOCTOU race between EXISTS and HGETALL. // This avoids the TOCTOU race between EXISTS and HGETALL.
let values: std::collections::HashMap<String, String> = self.client let values: std::collections::HashMap<String, String> = self
.client
.hgetall::<std::collections::HashMap<String, String>, _>(&key) .hgetall::<std::collections::HashMap<String, String>, _>(&key)
.await .await
.map_err(|e| SessionError::Redis(e.to_string()))?; .map_err(|e| SessionError::Redis(e.to_string()))?;
@@ -81,8 +87,14 @@ impl SessionStoreTrait for RedisSessionStore {
transport: values.get("transport").cloned().unwrap_or_default(), transport: values.get("transport").cloned().unwrap_or_default(),
state: values.get("state").cloned().unwrap_or_default(), state: values.get("state").cloned().unwrap_or_default(),
server_id: values.get("server_id").cloned().unwrap_or_default(), server_id: values.get("server_id").cloned().unwrap_or_default(),
created_at: values.get("created_at").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0), created_at: values
last_ping: values.get("last_ping").and_then(|v| v.parse::<u64>().ok()).unwrap_or(0), .get("created_at")
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0),
last_ping: values
.get("last_ping")
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0),
}; };
Ok(Some(info)) Ok(Some(info))
@@ -154,11 +166,12 @@ impl SessionStoreTrait for RedisSessionStore {
async fn exists(&self, sid: &str) -> Result<bool, SessionError> { async fn exists(&self, sid: &str) -> Result<bool, SessionError> {
let key = self.key(sid); let key = self.key(sid);
let exists: bool = self.client let exists: bool = self
.client
.exists::<bool, _>(&key) .exists::<bool, _>(&key)
.await .await
.map_err(|e| SessionError::Redis(e.to_string()))?; .map_err(|e| SessionError::Redis(e.to_string()))?;
Ok(exists) Ok(exists)
} }
} }
+27 -2
View File
@@ -1,6 +1,8 @@
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use uuid::Uuid;
use crate::socket::packet::Packet; use crate::socket::packet::Packet;
@@ -8,10 +10,13 @@ pub struct Socket {
pub sid: String, pub sid: String,
pub namespace: String, pub namespace: String,
pub engine_sid: String, pub engine_sid: String,
/// Authenticated user ID, set once during `on_connect`.
user_id: OnceLock<Uuid>,
ack_id: AtomicU64, ack_id: AtomicU64,
tx: mpsc::Sender<Packet>, tx: mpsc::Sender<Packet>,
} }
#[allow(clippy::result_large_err)]
impl Socket { impl Socket {
pub fn new( pub fn new(
sid: String, sid: String,
@@ -24,10 +29,22 @@ impl Socket {
namespace, namespace,
engine_sid, engine_sid,
ack_id: AtomicU64::new(0), ack_id: AtomicU64::new(0),
user_id: OnceLock::new(),
tx, tx,
} }
} }
/// Set the authenticated user ID after JWT verification.
/// Safe to call once; subsequent calls are ignored.
pub fn set_user_id(&self, id: Uuid) {
let _ = self.user_id.set(id);
}
/// Get the authenticated user ID, if set.
pub fn user_id(&self) -> Option<Uuid> {
self.user_id.get().copied()
}
pub fn next_ack_id(&self) -> u64 { pub fn next_ack_id(&self) -> u64 {
self.ack_id.fetch_add(1, Ordering::SeqCst) self.ack_id.fetch_add(1, Ordering::SeqCst)
} }
@@ -36,7 +53,11 @@ impl Socket {
self.tx.try_send(packet.clone()) self.tx.try_send(packet.clone())
} }
pub fn emit(&self, event: impl Into<String>, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> { pub fn emit(
&self,
event: impl Into<String>,
data: serde_json::Value,
) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::event( let packet = Packet::event(
&self.namespace, &self.namespace,
serde_json::json!([event.into(), data]), serde_json::json!([event.into(), data]),
@@ -65,7 +86,11 @@ impl Socket {
self.send_packet(&packet) self.send_packet(&packet)
} }
pub fn send_ack(&self, id: u64, data: serde_json::Value) -> Result<(), mpsc::error::TrySendError<Packet>> { pub fn send_ack(
&self,
id: u64,
data: serde_json::Value,
) -> Result<(), mpsc::error::TrySendError<Packet>> {
let packet = Packet::ack(&self.namespace, data, id); let packet = Packet::ack(&self.namespace, data, id);
self.send_packet(&packet) self.send_packet(&packet)
} }
+221
View File
@@ -0,0 +1,221 @@
//! Forum article event handlers on `MessageService`.
//!
//! Articles are long-form posts in forum channels. Creating an article
//! creates both a `message` (with `message_type = "article"`) and a
//! `message_article` row linked to it.
use std::sync::Arc;
use uuid::Uuid;
use crate::ImksError;
use crate::models::message::MessageType;
use crate::models::message_article::ArticleSort;
use crate::repo::CreateMessageInput;
use crate::socket::socket::Socket;
use super::message::MessageService;
impl MessageService {
/// Handle `article:create` — create a forum article (message + article metadata).
pub async fn create_article(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
let title: String = Self::parse_field(arr, "title")?;
let body: String = Self::parse_field(arr, "body")?;
let summary: Option<String> = Self::parse_optional(arr, "summary")?;
let cover_url: Option<String> = Self::parse_optional(arr, "cover_url")?;
let tags: Option<serde_json::Value> = Self::parse_optional(arr, "tags")?;
self.validate_channel_write(&channel_id.to_string(), &user_id.to_string())
.await?;
// Create the message first (with article type)
let input = CreateMessageInput {
channel_id,
author_id: user_id,
thread_id: None,
reply_to_message_id: None,
message_type: MessageType::Article.as_str().into(),
body,
metadata: None,
system: false,
};
let message = self.repo.create(&input).await?;
// Create the article record
let article = self
.repo
.create_article(
message.id,
&title,
summary.as_deref(),
cover_url.as_deref(),
None,
None,
None,
tags.as_ref(),
)
.await?;
if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) {
ns.emit_to_room(
&channel_id.to_string(),
"article:created",
serde_json::json!({
"message": message,
"article": article,
}),
)
.await;
}
tracing::info!(article_id = %article.id, %channel_id, %user_id, "Article created");
Ok(())
}
/// Handle `article:update` — update article title, summary, cover, tags.
pub async fn update_article(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let message_id: Uuid = Self::parse_field(arr, "message_id")?;
let title: Option<String> = Self::parse_optional(arr, "title")?;
let summary: Option<String> = Self::parse_optional(arr, "summary")?;
let cover_url: Option<String> = Self::parse_optional(arr, "cover_url")?;
let cover_color: Option<String> = Self::parse_optional(arr, "cover_color")?;
let tags: Option<serde_json::Value> = Self::parse_optional(arr, "tags")?;
let existing = self
.repo
.get(message_id)
.await?
.ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?;
self.ensure_author_or_mod(
existing.author_id,
&existing.channel_id.to_string(),
user_id,
)
.await?;
// Update article body if provided
if let Ok(new_body) = Self::parse_field::<String>(arr, "body")
&& !new_body.is_empty()
{
let old_body = existing.body.clone();
self.repo.update_body(message_id, &new_body).await?;
self.repo
.record_edit(message_id, user_id, &old_body, &new_body)
.await?;
}
if let Some(updated) = self
.repo
.update_article(
message_id,
title.as_deref(),
summary.as_deref(),
cover_url.as_deref(),
cover_color.as_deref(),
tags.as_ref(),
)
.await?
&& let Some(ns) = self.namespaces.get_namespace(&socket.namespace)
{
ns.emit_to_room(
&existing.channel_id.to_string(),
"article:updated",
serde_json::to_value(&updated).unwrap_or_default(),
)
.await;
}
Ok(())
}
/// Handle `article:list` — list articles in a forum channel.
pub async fn list_articles(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
self.ensure_readable(&channel_id.to_string(), &user_id.to_string())
.await?;
self.ensure_member(&channel_id.to_string(), &user_id.to_string())
.await?;
let before: Option<(i64, Uuid)> = None;
let limit: Option<i64> = Self::parse_optional(arr, "limit")?;
let page = self
.repo
.list_articles(channel_id, ArticleSort::LatestActivity, before, limit)
.await?;
let _ = socket.emit(
"article:loaded",
serde_json::to_value(&page).unwrap_or_default(),
);
Ok(())
}
/// Handle `article:delete` — soft-delete an article.
pub async fn delete_article(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let message_id: Uuid = Self::parse_field(arr, "message_id")?;
let _channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
let existing = self
.repo
.get(message_id)
.await?
.ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?;
self.ensure_author_or_mod(
existing.author_id,
&existing.channel_id.to_string(),
user_id,
)
.await?;
self.repo.soft_delete(message_id).await?;
if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) {
ns.emit_to_room(&existing.channel_id.to_string(), "article:deleted",
serde_json::json!({"id": message_id.to_string(), "channel_id": existing.channel_id.to_string()}),
).await;
}
Ok(())
}
}
+75
View File
@@ -0,0 +1,75 @@
//! Bookmark event handlers on `MessageService`.
use std::sync::Arc;
use uuid::Uuid;
use crate::ImksError;
use crate::socket::socket::Socket;
use super::message::MessageService;
impl MessageService {
/// Handle `bookmark:add` — toggle (add/update) a bookmark.
pub async fn add_bookmark(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let message_id: Uuid = Self::parse_field(arr, "message_id")?;
let channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
let note: Option<String> = Self::parse_optional(arr, "note")?;
self.repo
.add_bookmark(message_id, channel_id, user_id, note.as_deref())
.await?;
Ok(())
}
/// Handle `bookmark:remove` — remove a bookmark.
pub async fn remove_bookmark(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let message_id: Uuid = Self::parse_field(arr, "message_id")?;
self.repo.remove_bookmark(message_id, user_id).await?;
Ok(())
}
/// Handle `bookmark:list` — list a user's bookmarks.
pub async fn list_bookmarks(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let before: Option<Uuid> = Self::parse_optional(arr, "before")?;
let limit: Option<i64> = Self::parse_optional(arr, "limit")?;
let page = self.repo.list_bookmarks(user_id, before, limit).await?;
let _ = socket.emit(
"bookmark:loaded",
serde_json::to_value(&page).unwrap_or_default(),
);
Ok(())
}
}
+98
View File
@@ -0,0 +1,98 @@
//! Interactive component event handlers on `MessageService`.
//!
//! Handles button clicks and select menu interactions on message components.
//! When a user clicks a button, the server updates the component state and
//! broadcasts the interaction to the channel.
use std::sync::Arc;
use uuid::Uuid;
use crate::ImksError;
use crate::socket::socket::Socket;
use super::message::MessageService;
impl MessageService {
/// Handle `component:interact` — a user clicked a button or selected from a menu.
pub async fn interact_component(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let component_id: Uuid = Self::parse_field(arr, "component_id")?;
let custom_id: String = Self::parse_field(arr, "custom_id")?;
let message_id: Uuid = Self::parse_field(arr, "message_id")?;
let channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
// Get current components to verify the interaction is valid
let components = self.repo.get_components(message_id).await?;
let component = components.iter().find(|c| c.id == component_id);
if component.is_none() {
return Err(ImksError::NotFound(format!("component {component_id}")));
}
// Broadcast the interaction event to all clients in the channel.
// The actual action (e.g., approve/deny) is handled by the bot/webhook
// that listens for this event. The server just relays and disables the
// component to prevent double-clicks.
if let Some(ns) = self.namespaces.get_namespace(&socket.namespace) {
ns.emit_to_room(
&channel_id.to_string(),
"component:interaction",
serde_json::json!({
"component_id": component_id.to_string(),
"custom_id": custom_id,
"message_id": message_id.to_string(),
"user_id": user_id.to_string(),
"channel_id": channel_id.to_string(),
}),
)
.await;
}
tracing::info!(%component_id, %user_id, %custom_id, "Component interaction");
Ok(())
}
/// Handle `component:update` — update a component's state (e.g., disable after interaction).
#[allow(dead_code)]
pub async fn update_component(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let component_id: Uuid = Self::parse_field(arr, "component_id")?;
let label: Option<String> = Self::parse_optional(arr, "label")?;
let disabled: bool = Self::parse_optional(arr, "disabled")?.unwrap_or(true);
if let Some(updated) = self
.repo
.update_component(component_id, label.as_deref(), disabled)
.await?
&& let Some(ns) = self.namespaces.get_namespace(&socket.namespace)
{
ns.emit_to_room(
&updated.message_id.to_string(),
"component:updated",
serde_json::to_value(&updated).unwrap_or_default(),
)
.await;
}
Ok(())
}
}
+62
View File
@@ -0,0 +1,62 @@
//! Server deployment configuration.
//!
//! Reads from environment variables to select adapter (local/redis/nats)
//! and WebTransport settings.
use std::env;
/// Adapter + message bus configuration for multi-node scale-out.
#[derive(Debug, Clone)]
pub struct DeployConfig {
/// "local" | "redis" | "nats"
pub adapter_mode: String,
/// Redis connection URL (used when adapter_mode = "redis").
pub redis_url: String,
/// NATS connection URL (used when adapter_mode = "nats").
pub nats_url: String,
/// Unique server ID for this node.
pub server_id: String,
/// Enable WebTransport server.
pub webtransport_enabled: bool,
/// WebTransport listen port.
pub webtransport_port: u16,
/// TLS certificate path (required for WebTransport).
pub cert_path: String,
/// TLS key path (required for WebTransport).
pub key_path: String,
}
impl DeployConfig {
pub fn from_env() -> Self {
let server_id = env::var("IMKS_SERVER_ID").unwrap_or_else(|_| hostname());
Self {
adapter_mode: env::var("IMKS_ADAPTER").unwrap_or_else(|_| "local".into()),
redis_url: env::var("IMKS_REDIS_URL")
.unwrap_or_else(|_| "redis://localhost:6379".into()),
nats_url: env::var("IMKS_NATS_URL").unwrap_or_else(|_| "nats://localhost:4222".into()),
server_id,
webtransport_enabled: env::var("IMKS_WT_ENABLED")
.map(|v| v == "true" || v == "1")
.unwrap_or(false),
webtransport_port: env::var("IMKS_WT_PORT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(3001),
cert_path: env::var("IMKS_WT_CERT_PATH").unwrap_or_default(),
key_path: env::var("IMKS_WT_KEY_PATH").unwrap_or_default(),
}
}
}
impl Default for DeployConfig {
fn default() -> Self {
Self::from_env()
}
}
fn hostname() -> String {
env::var("HOSTNAME")
.or_else(|_| env::var("HOST"))
.unwrap_or_else(|_| "imks-node-1".into())
}
+105
View File
@@ -0,0 +1,105 @@
//! Draft event handlers on `MessageService`.
//!
//! Drafts are per-user private data — no permission checks beyond auth needed.
use std::sync::Arc;
use uuid::Uuid;
use crate::ImksError;
use crate::socket::socket::Socket;
use super::message::MessageService;
impl MessageService {
/// Handle `draft:save` — upsert a draft.
pub async fn save_draft(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
let body: String = Self::parse_field(arr, "body")?;
let thread_id: Option<Uuid> = Self::parse_optional(arr, "thread_id")?;
let reply_to_message_id: Option<Uuid> = Self::parse_optional(arr, "reply_to_message_id")?;
let metadata: Option<serde_json::Value> = Self::parse_optional(arr, "metadata")?;
let channel_id_str = channel_id.to_string();
let user_id_str = user_id.to_string();
self.validate_channel_write(&channel_id_str, &user_id_str)
.await?;
self.repo
.upsert_draft(
channel_id,
user_id,
thread_id,
&body,
reply_to_message_id,
metadata,
)
.await?;
Ok(())
}
/// Handle `draft:get` — retrieve a draft and send it back to the requesting socket.
pub async fn get_draft(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
let thread_id: Option<Uuid> = Self::parse_optional(arr, "thread_id")?;
let channel_id_str = channel_id.to_string();
let user_id_str = user_id.to_string();
self.ensure_readable(&channel_id_str, &user_id_str).await?;
self.ensure_member(&channel_id_str, &user_id_str).await?;
let draft = self.repo.get_draft(channel_id, user_id, thread_id).await?;
if let Some(d) = draft {
let _ = socket.emit("draft:loaded", serde_json::to_value(&d).unwrap_or_default());
} else {
let _ = socket.emit("draft:loaded", serde_json::json!(null));
}
Ok(())
}
/// Handle `draft:delete` — delete a draft after sending.
pub async fn delete_draft(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> crate::ImksResult<()> {
let user_id = self.user_id(&socket)?;
let arr = data
.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))?;
let channel_id: Uuid = Self::parse_field(arr, "channel_id")?;
let thread_id: Option<Uuid> = Self::parse_optional(arr, "thread_id")?;
let channel_id_str = channel_id.to_string();
let user_id_str = user_id.to_string();
self.ensure_readable(&channel_id_str, &user_id_str).await?;
self.ensure_member(&channel_id_str, &user_id_str).await?;
self.repo
.delete_draft(channel_id, user_id, thread_id)
.await?;
Ok(())
}
}
+970
View File
@@ -0,0 +1,970 @@
//! Message service — the business logic layer connecting auth, permissions,
//! persistence, and real-time broadcast.
//!
//! Validates every operation through the gRPC permission chain before
//! touching the database or broadcasting to the room.
use std::sync::Arc;
use std::time::{Duration, Instant};
use chrono::Utc;
use dashmap::DashMap;
use tracing;
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::models::message::Message;
use crate::pb::im::{
CheckPermissionRequest, EnsureReadableRequest, ImPermission, IsMemberRequest,
ResolveChannelRequest,
};
use crate::repo::message_repo::MessageRepo;
use crate::rpc::AppksClients;
use crate::socket::namespace::NamespaceManager;
use crate::socket::socket::Socket;
use crate::{ImksError, ImksResult};
/// Central business-logic service for message operations.
///
/// Every mutating operation performs the following checks in order:
/// 1. JWT authentication (via `Authenticator`)
/// 2. Nonce deduplication (prevent duplicate sends)
/// 3. Rate limiting (per-user, per-channel sliding window)
/// 4. Message body size validation (max 100 KB)
/// 5. Channel readability (`PermissionService.EnsureReadable`)
/// 6. Channel status (`ResolveChannel` → read_only / archived)
/// 7. Channel membership (`MemberService.IsMember`)
/// 8. Operation-specific permission (`PermissionService.CheckPermission`)
/// 9. Ownership validation (for edit/delete of others' messages)
///
/// Only after all gates pass does the operation reach the database
/// and the broadcast adapter.
#[derive(Clone)]
pub struct MessageService {
pub(crate) repo: MessageRepo,
pub(crate) auth: Arc<Authenticator>,
pub(crate) clients: AppksClients,
pub(crate) namespaces: Arc<NamespaceManager>,
/// Rate limiter: stores timestamps of recent sends per (user, channel).
rate_limits: Arc<DashMap<(Uuid, Uuid), Vec<Instant>>>,
/// Nonce dedup cache: nonce → first-seen timestamp. Uses TTL eviction.
nonces: Arc<DashMap<String, Instant>>,
/// Max message body length in bytes.
max_body_size: usize,
/// Per-user, per-channel messages allowed in the rate window.
rate_limit_count: usize,
/// Rate-limiting window duration.
rate_window: Duration,
/// Nonce TTL before reuse is allowed again.
nonce_ttl: Duration,
}
impl MessageService {
/// Create a new message service.
///
/// Internally initializes the JWT `Authenticator` from the token client
/// so that `authenticate_socket()` can verify JWT tokens during `on_connect`.
pub async fn new(
repo: MessageRepo,
clients: AppksClients,
namespaces: Arc<NamespaceManager>,
) -> ImksResult<Self> {
let auth = Arc::new(Authenticator::new(clients.token.clone()).await?);
Ok(Self {
repo,
auth,
clients,
namespaces,
rate_limits: Arc::new(DashMap::new()),
nonces: Arc::new(DashMap::new()),
max_body_size: 100_000,
rate_limit_count: 10,
rate_window: Duration::from_secs(10),
nonce_ttl: Duration::from_secs(300),
})
}
pub fn namespaces(&self) -> &NamespaceManager {
&self.namespaces
}
/// Verify a JWT token and attach the authenticated `user_id` to the socket.
/// Call this from `on_connect` before registering any event handlers.
pub fn authenticate_socket(
&self,
socket: &Socket,
auth_data: Option<&serde_json::Value>,
) -> ImksResult<()> {
let token = auth_data
.and_then(|v| v.get("token"))
.and_then(|v| v.as_str())
.ok_or_else(|| ImksError::Auth("Missing auth token".into()))?;
let claims = self.auth.verify_local(token)?;
if !claims.has_scope("im:read") && !claims.has_scope("im:write") {
return Err(ImksError::Auth("Token lacks im scope".into()));
}
let user_id = Uuid::parse_str(&claims.sub)
.map_err(|_| ImksError::Auth(format!("Invalid user ID in token: {}", claims.sub)))?;
socket.set_user_id(user_id);
tracing::info!(user_id = %user_id, socket_sid = %socket.sid, "Socket authenticated");
Ok(())
}
/// Handle `channel:join` event by adding the socket to the channel room.
pub async fn join_channel(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> ImksResult<()> {
let user_id = self.user_id(&socket)?;
let payload = Self::first_payload(data)?;
let channel_id: Uuid = Self::parse_field(payload, "channel_id")?;
let channel_id_str = channel_id.to_string();
let user_id_str = user_id.to_string();
self.ensure_readable(&channel_id_str, &user_id_str).await?;
self.ensure_member(&channel_id_str, &user_id_str).await?;
let namespace = self
.namespaces
.get_namespace(&socket.namespace)
.ok_or_else(|| {
ImksError::Namespace(format!("namespace {} not found", socket.namespace))
})?;
namespace.join_room(&socket.sid, &channel_id_str).await?;
tracing::info!(socket_sid = %socket.sid, %channel_id, %user_id, "Socket joined channel room");
Ok(())
}
/// Handle `channel:leave` event by removing the socket from the channel room.
pub async fn leave_channel(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> ImksResult<()> {
let _user_id = self.user_id(&socket)?;
let payload = Self::first_payload(data)?;
let channel_id: Uuid = Self::parse_field(payload, "channel_id")?;
let channel_id_str = channel_id.to_string();
if let Some(namespace) = self.namespaces.get_namespace(&socket.namespace) {
namespace.leave_room(&socket.sid, &channel_id_str).await?;
}
tracing::info!(socket_sid = %socket.sid, %channel_id, "Socket left channel room");
Ok(())
}
/// Handle `message:send` event from a connected socket.
///
/// Expected `data` shape (Socket.IO event array tail):
/// ```json
/// [{
/// "channel_id": "...",
/// "body": "hello world",
/// "thread_id": null,
/// "reply_to_message_id": null,
/// "nonce": "optional-client-idempotency-key"
/// }]
/// ```
pub async fn send_message(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> ImksResult<Message> {
let user_id = socket
.user_id()
.ok_or_else(|| ImksError::Auth("Socket not authenticated".into()))?;
let payload = self.parse_send_payload(data)?;
let channel_id = payload.channel_id;
let channel_id_str = channel_id.to_string();
let user_id_str = user_id.to_string();
self.validate_send_payload(&payload.body, &payload.nonce, user_id, channel_id)?;
self.validate_channel_write(&channel_id_str, &user_id_str)
.await?;
let nonce_key = match payload.nonce.as_deref() {
Some(nonce) => Some(self.reserve_nonce(nonce, user_id, channel_id)?),
None => None,
};
let result = self
.create_and_dispatch(payload, user_id, &socket.namespace)
.await;
if result.is_err() {
self.release_nonce(nonce_key);
}
result
}
/// Handle `message:edit` event.
///
/// The author can edit their own message. Users with `MANAGE_MESSAGES`
/// permission can edit any message in the channel.
pub async fn edit_message(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> ImksResult<Message> {
let user_id = socket
.user_id()
.ok_or_else(|| ImksError::Auth("Socket not authenticated".into()))?;
let payload = Self::first_payload(data)?;
let message_id: Uuid = Self::parse_field(payload, "message_id")?;
let new_body: String = Self::parse_field(payload, "body")?;
let existing = self
.repo
.get(message_id)
.await?
.ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?;
self.ensure_author_or_mod(
existing.author_id,
&existing.channel_id.to_string(),
user_id,
)
.await?;
let old_body = existing.body.clone();
let updated = self.repo.update_body(message_id, &new_body).await?;
self.repo
.record_edit(message_id, user_id, &old_body, &new_body)
.await?;
let namespace = self.namespaces.get_namespace(&socket.namespace);
if let Some(ns) = namespace {
ns.emit_to_room(
&existing.channel_id.to_string(),
"message:updated",
serde_json::to_value(&updated).unwrap_or_default(),
)
.await;
}
tracing::info!(message_id = %message_id, user_id = %user_id, "Message edited");
Ok(updated)
}
/// Handle `message:delete` event (soft-delete).
///
/// Same ownership / moderator rules as edit.
pub async fn delete_message(
&self,
socket: Arc<Socket>,
data: &serde_json::Value,
) -> ImksResult<()> {
let user_id = socket
.user_id()
.ok_or_else(|| ImksError::Auth("Socket not authenticated".into()))?;
let payload = Self::first_payload(data)?;
let message_id: Uuid = Self::parse_field(payload, "message_id")?;
let existing = self
.repo
.get(message_id)
.await?
.ok_or_else(|| ImksError::NotFound(format!("message {message_id}")))?;
self.ensure_author_or_mod(
existing.author_id,
&existing.channel_id.to_string(),
user_id,
)
.await?;
self.repo.soft_delete(message_id).await?;
let namespace = self.namespaces.get_namespace(&socket.namespace);
if let Some(ns) = namespace {
ns.emit_to_room(
&existing.channel_id.to_string(),
"message:deleted",
serde_json::json!({ "id": message_id.to_string(), "channel_id": existing.channel_id.to_string() }),
)
.await;
}
tracing::info!(message_id = %message_id, user_id = %user_id, "Message deleted");
Ok(())
}
// Permission validation helpers
/// Full write-access gate: resolve channel + readability + membership + SEND_MESSAGE.
pub(crate) async fn validate_channel_write(
&self,
channel_id: &str,
user_id: &str,
) -> ImksResult<()> {
// Check read-only / archived first (fast gRPC, returns early)
let channel_info = self.resolve_channel(channel_id).await?;
if channel_info.read_only {
return Err(ImksError::Auth(format!(
"Channel {channel_id} is read-only"
)));
}
if channel_info.archived {
return Err(ImksError::Auth(format!("Channel {channel_id} is archived")));
}
self.ensure_readable(channel_id, user_id).await?;
self.ensure_member(channel_id, user_id).await?;
self.ensure_permission(channel_id, user_id, ImPermission::SendMessage)
.await?;
Ok(())
}
/// Verify the user can read this channel at all.
pub(crate) async fn ensure_readable(&self, channel_id: &str, user_id: &str) -> ImksResult<()> {
let mut client = self.clients.permission.clone();
let resp = client
.ensure_readable(EnsureReadableRequest {
channel_id: channel_id.to_string(),
user_id: user_id.to_string(),
})
.await?;
let inner = resp.into_inner();
if !inner.allowed {
return Err(ImksError::Auth(format!(
"User {user_id} cannot read channel {channel_id}"
)));
}
Ok(())
}
pub(crate) fn user_id(&self, socket: &Socket) -> crate::ImksResult<Uuid> {
socket
.user_id()
.ok_or_else(|| ImksError::Auth("Socket not authenticated".into()))
}
pub(crate) async fn ensure_member(
&self,
channel_id: &str,
user_id: &str,
) -> crate::ImksResult<()> {
let mut client = self.clients.member.clone();
let resp = client
.is_member(IsMemberRequest {
channel_id: channel_id.to_string(),
user_id: user_id.to_string(),
})
.await?;
let inner = resp.into_inner();
if !inner.is_member {
return Err(ImksError::Auth(format!(
"User {user_id} is not a member of channel {channel_id}"
)));
}
Ok(())
}
/// Verify a specific permission.
async fn ensure_permission(
&self,
channel_id: &str,
user_id: &str,
permission: ImPermission,
) -> ImksResult<()> {
let allowed = self
.check_permission(channel_id, user_id, permission)
.await?;
if !allowed {
return Err(ImksError::Auth(format!(
"User {user_id} lacks permission {permission:?} in channel {channel_id}"
)));
}
Ok(())
}
/// Verify the user is the message author or has ManageMessages permission.
pub(crate) async fn ensure_author_or_mod(
&self,
message_author_id: Uuid,
channel_id: &str,
user_id: Uuid,
) -> ImksResult<()> {
if message_author_id == user_id {
return Ok(());
}
let allowed = self
.check_permission(
channel_id,
&user_id.to_string(),
ImPermission::ManageMessages,
)
.await?;
if !allowed {
return Err(ImksError::Auth(
"Only the author or a moderator can modify this message".into(),
));
}
Ok(())
}
/// Low-level permission check returning a boolean.
pub(crate) async fn check_permission(
&self,
channel_id: &str,
user_id: &str,
permission: ImPermission,
) -> ImksResult<bool> {
let mut client = self.clients.permission.clone();
let resp = client
.check_permission(CheckPermissionRequest {
channel_id: channel_id.to_string(),
user_id: user_id.to_string(),
permission: permission as i32,
})
.await?;
Ok(resp.into_inner().allowed)
}
/// Resolve channel metadata via gRPC to check read_only / archived status.
async fn resolve_channel(
&self,
channel_id: &str,
) -> ImksResult<crate::pb::im::ResolveChannelResponse> {
let mut client = self.clients.permission.clone();
let resp = client
.resolve_channel(ResolveChannelRequest {
channel_id: channel_id.to_string(),
})
.await?;
Ok(resp.into_inner())
}
// Rate limiting & dedup
/// Reserve a nonce atomically. Nonces expire after `nonce_ttl`.
fn reserve_nonce(&self, nonce: &str, user_id: Uuid, channel_id: Uuid) -> ImksResult<String> {
let now = Instant::now();
let key = format!("{user_id}:{channel_id}:{nonce}");
// Cleanup expired nonces periodically (probabilistic, ~1/64 chance per check)
if rand::random::<u8>().is_multiple_of(64) {
self.nonces
.retain(|_, t| now.duration_since(*t) < self.nonce_ttl);
}
match self.nonces.entry(key.clone()) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
if now.duration_since(*entry.get()) < self.nonce_ttl {
return Err(ImksError::InvalidInput(
"Duplicate message: this nonce was already used".into(),
));
}
entry.insert(now);
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(now);
}
}
Ok(key)
}
fn release_nonce(&self, nonce_key: Option<String>) {
if let Some(key) = nonce_key {
self.nonces.remove(&key);
}
}
fn validate_body_size(&self, body: &str) -> ImksResult<()> {
if body.len() > self.max_body_size {
return Err(ImksError::InvalidInput(format!(
"Message body exceeds max size of {} bytes (got {})",
self.max_body_size,
body.len()
)));
}
Ok(())
}
/// Check per-user, per-channel rate limit using a sliding window.
fn check_rate_limit(&self, user_id: Uuid, channel_id: Uuid) -> ImksResult<()> {
let key = (user_id, channel_id);
let now = Instant::now();
let mut entry = self.rate_limits.entry(key).or_default();
// Evict timestamps outside the window
entry.retain(|t| now.duration_since(*t) < self.rate_window);
if entry.len() >= self.rate_limit_count {
return Err(ImksError::InvalidInput(format!(
"Rate limit exceeded: max {} messages per {}s",
self.rate_limit_count,
self.rate_window.as_secs()
)));
}
entry.push(now);
Ok(())
}
/// Combined safety checks for message sending: body size and rate limit.
fn validate_send_payload(
&self,
body: &str,
_nonce: &Option<String>,
user_id: Uuid,
channel_id: Uuid,
) -> ImksResult<()> {
self.validate_body_size(body)?;
self.check_rate_limit(user_id, channel_id)?;
Ok(())
}
// Payload parsers
/// Parse attachment inputs from a JSON payload.
fn parse_attachments(value: &serde_json::Value) -> Vec<AttachmentInput> {
value
.get("attachments")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|a| {
Some(AttachmentInput {
filename: a.get("filename")?.as_str()?.to_string(),
url: a.get("url")?.as_str()?.to_string(),
size: a.get("size")?.as_i64()?,
content_type: a
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
})
})
.collect()
})
.unwrap_or_default()
}
/// Parse embed inputs from a JSON payload.
fn parse_embeds(value: &serde_json::Value) -> Vec<EmbedInput> {
value
.get("embeds")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|e| {
let fields: Vec<(String, String, bool)> = e
.get("fields")
.and_then(|v| v.as_array())
.map(|fa| {
fa.iter()
.filter_map(|f| {
Some((
f.get("name")?.as_str()?.to_string(),
f.get("value")?.as_str()?.to_string(),
f.get("inline")
.and_then(|v| v.as_bool())
.unwrap_or(false),
))
})
.collect()
})
.unwrap_or_default();
Some(EmbedInput {
embed_type: e.get("embed_type")?.as_str()?.to_string(),
title: e.get("title").and_then(|v| v.as_str()).map(String::from),
description: e
.get("description")
.and_then(|v| v.as_str())
.map(String::from),
url: e.get("url").and_then(|v| v.as_str()).map(String::from),
image_url: e
.get("image_url")
.and_then(|v| v.as_str())
.map(String::from),
fields,
})
})
.collect()
})
.unwrap_or_default()
}
/// Parse a sticker input from a JSON payload (single object or null).
fn parse_sticker(value: &serde_json::Value) -> Option<StickerInput> {
value.get("sticker").and_then(|s| {
Some(StickerInput {
sticker_id: MessageService::parse_field(s, "sticker_id").ok()?,
name: MessageService::parse_field(s, "name").ok()?,
image_url: MessageService::parse_field(s, "image_url").ok()?,
format_type: s
.get("format_type")
.and_then(|v| v.as_str())
.unwrap_or("png")
.to_string(),
})
})
}
/// Parse a forward input from a JSON payload (single object or null).
fn parse_forward(value: &serde_json::Value) -> Option<ForwardInput> {
value.get("forward").and_then(|f| {
Some(ForwardInput {
source_message_id: MessageService::parse_field(f, "source_message_id").ok()?,
source_channel_id: MessageService::parse_field(f, "source_channel_id").ok()?,
})
})
}
/// Parse the `message:send` event payload including optional rich content.
fn parse_send_payload(&self, data: &serde_json::Value) -> ImksResult<SendPayload> {
let arr = data
.as_array()
.ok_or_else(|| ImksError::InvalidInput("Event data must be a JSON array".into()))?;
if arr.is_empty() {
return Err(ImksError::InvalidInput("Empty event data".into()));
}
let payload = &arr[0];
let channel_id: Uuid = Self::parse_field(payload, "channel_id")?;
let body: String = Self::parse_field(payload, "body")?;
let thread_id = Self::parse_optional(payload, "thread_id")?;
let reply_to_message_id = Self::parse_optional(payload, "reply_to_message_id")?;
let nonce: Option<String> = Self::parse_optional(payload, "nonce")?;
let mentioned_user_ids: Vec<Uuid> =
Self::parse_optional(payload, "mentioned_user_ids")?.unwrap_or_default();
let attachments = Self::parse_attachments(payload);
let embeds = Self::parse_embeds(payload);
let sticker = Self::parse_sticker(payload);
let forward = Self::parse_forward(payload);
Ok(SendPayload {
channel_id,
body,
thread_id,
reply_to_message_id,
nonce,
mentioned_user_ids,
attachments,
embeds,
sticker,
forward,
})
}
pub(crate) fn first_payload(data: &serde_json::Value) -> ImksResult<&serde_json::Value> {
data.as_array()
.and_then(|a| a.first())
.ok_or_else(|| ImksError::InvalidInput("Expected [payload] array".into()))
}
pub(crate) fn parse_field<T: serde::de::DeserializeOwned>(
value: &serde_json::Value,
field: &str,
) -> crate::ImksResult<T> {
let field_value = value
.get(field)
.ok_or_else(|| ImksError::InvalidInput(format!("Missing required field: {field}")))?;
serde_json::from_value(field_value.clone())
.map_err(|e| ImksError::InvalidInput(format!("Invalid field {field}: {e}")))
}
pub(crate) fn parse_optional<T: serde::de::DeserializeOwned>(
value: &serde_json::Value,
field: &str,
) -> ImksResult<Option<T>> {
match value.get(field) {
Some(v) if v.is_null() => Ok(None),
Some(v) => serde_json::from_value(v.clone())
.map(Some)
.map_err(|e| ImksError::InvalidInput(format!("Invalid field {field}: {e}"))),
None => Ok(None),
}
}
/// Create the message and all rich content in one database transaction,
/// then broadcast to the channel room after commit.
async fn create_and_dispatch(
&self,
payload: SendPayload,
user_id: Uuid,
namespace_path: &str,
) -> ImksResult<Message> {
let mut tx = self.repo.pool().begin().await?;
let message_id = Uuid::now_v7();
let now = Utc::now();
let message = sqlx::query_as::<_, Message>(
r#"
INSERT INTO message (
id, channel_id, author_id, thread_id, reply_to_message_id,
message_type, body, metadata, pinned, system,
edited_at, deleted_at, created_at, updated_at
) VALUES (
$1, $2, $3, $4, $5,
'text', $6, NULL, FALSE, FALSE,
NULL, NULL, $7, $7
)
RETURNING *
"#,
)
.bind(message_id)
.bind(payload.channel_id)
.bind(user_id)
.bind(payload.thread_id)
.bind(payload.reply_to_message_id)
.bind(&payload.body)
.bind(now)
.fetch_one(&mut *tx)
.await?;
if let Some(thread_id) = payload.thread_id {
sqlx::query(
r#"
UPDATE message_thread
SET replies_count = replies_count + 1,
last_reply_message_id = $1,
last_reply_at = $2,
updated_at = $2
WHERE id = $3
"#,
)
.bind(message_id)
.bind(now)
.bind(thread_id)
.execute(&mut *tx)
.await?;
sqlx::query(
r#"
INSERT INTO message_thread_participant (id, thread_id, user_id, joined_reason, joined_at)
VALUES ($1, $2, $3, 'reply', $4)
ON CONFLICT (thread_id, user_id) DO UPDATE SET joined_reason = EXCLUDED.joined_reason
"#,
)
.bind(Uuid::now_v7())
.bind(thread_id)
.bind(user_id)
.bind(now)
.execute(&mut *tx)
.await?;
sqlx::query(
"UPDATE message_thread SET participants_count = (SELECT COUNT(*) FROM message_thread_participant WHERE thread_id = $1) WHERE id = $1",
)
.bind(thread_id)
.execute(&mut *tx)
.await?;
}
for att in &payload.attachments {
sqlx::query(
r#"
INSERT INTO message_attachment (
id, message_id, filename, content_type, size, url,
storage_key, width, height, spoiler
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, FALSE)
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(&att.filename)
.bind(att.content_type.as_deref())
.bind(att.size)
.bind(&att.url)
.execute(&mut *tx)
.await?;
}
for emb in &payload.embeds {
let embed_id = Uuid::now_v7();
sqlx::query(
r#"
INSERT INTO message_embed (
id, message_id, embed_type, title, description, url, color,
image_url, author_name, author_url, footer_text, provider_name
) VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, NULL, NULL, NULL, NULL)
"#,
)
.bind(embed_id)
.bind(message_id)
.bind(&emb.embed_type)
.bind(emb.title.as_deref())
.bind(emb.description.as_deref())
.bind(emb.url.as_deref())
.bind(emb.image_url.as_deref())
.execute(&mut *tx)
.await?;
for (position, (name, value, inline)) in emb.fields.iter().enumerate() {
sqlx::query(
r#"
INSERT INTO message_embed_field (id, embed_id, name, value, inline, position)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
)
.bind(Uuid::now_v7())
.bind(embed_id)
.bind(name)
.bind(value)
.bind(*inline)
.bind(position as i32)
.execute(&mut *tx)
.await?;
}
}
if let Some(sticker) = &payload.sticker {
sqlx::query(
r#"
INSERT INTO message_sticker (id, message_id, sticker_id, name, image_url, format_type, pack_name, tags)
VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL)
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(sticker.sticker_id)
.bind(&sticker.name)
.bind(&sticker.image_url)
.bind(&sticker.format_type)
.execute(&mut *tx)
.await?;
}
if let Some(forward) = &payload.forward {
sqlx::query(
r#"
INSERT INTO message_forward (id, message_id, source_message_id, source_channel_id, forwarded_by)
VALUES ($1, $2, $3, $4, $5)
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(forward.source_message_id)
.bind(forward.source_channel_id)
.bind(user_id)
.execute(&mut *tx)
.await?;
}
for mentioned_id in &payload.mentioned_user_ids {
sqlx::query(
r#"
INSERT INTO message_mention (id, message_id, channel_id, mentioned_user_id, mentioned_by, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(payload.channel_id)
.bind(*mentioned_id)
.bind(user_id)
.bind(now)
.execute(&mut *tx)
.await?;
sqlx::query(
r#"
INSERT INTO message_notification (
id, message_id, channel_id, user_id, reason, status, delivery_channel, created_at
) VALUES ($1, $2, $3, $4, 'mention', 'pending', NULL, $5)
"#,
)
.bind(Uuid::now_v7())
.bind(message_id)
.bind(payload.channel_id)
.bind(*mentioned_id)
.bind(now)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
self.broadcast_new_message(&message, payload.channel_id, user_id, namespace_path)
.await;
Ok(message)
}
/// Broadcast a newly created message to the channel room and log.
async fn broadcast_new_message(
&self,
message: &Message,
channel_id: Uuid,
user_id: Uuid,
namespace_path: &str,
) {
let namespace = self.namespaces.get_namespace(namespace_path);
if let Some(ns) = namespace {
ns.emit_to_room(
&channel_id.to_string(),
"message:new",
serde_json::to_value(message).unwrap_or_default(),
)
.await;
}
tracing::info!(
message_id = %message.id,
channel_id = %channel_id,
user_id = %user_id,
"Message sent"
);
}
}
// Rich content input types for parsing
pub(crate) struct SendPayload {
channel_id: Uuid,
body: String,
thread_id: Option<Uuid>,
reply_to_message_id: Option<Uuid>,
nonce: Option<String>,
mentioned_user_ids: Vec<Uuid>,
attachments: Vec<AttachmentInput>,
embeds: Vec<EmbedInput>,
sticker: Option<StickerInput>,
forward: Option<ForwardInput>,
}
pub(crate) struct AttachmentInput {
filename: String,
url: String,
size: i64,
content_type: Option<String>,
}
pub(crate) struct EmbedInput {
embed_type: String,
title: Option<String>,
description: Option<String>,
url: Option<String>,
image_url: Option<String>,
fields: Vec<(String, String, bool)>,
}
pub(crate) struct StickerInput {
sticker_id: Uuid,
name: String,
image_url: String,
format_type: String,
}
pub(crate) struct ForwardInput {
source_message_id: Uuid,
source_channel_id: Uuid,
}
+19
View File
@@ -0,0 +1,19 @@
pub mod article;
pub mod bookmark;
pub mod component;
pub mod deploy;
pub mod draft;
pub mod message;
pub mod pin;
pub mod poll;
pub mod reaction;
pub mod read_state;
pub mod scheduled;
pub mod thread;
pub mod typing;
#[cfg(test)]
mod tests;
pub use deploy::DeployConfig;
pub use message::MessageService;

Some files were not shown because too many files have changed in this diff Show More