821537186e
- Reorganized import statements in adapter tests for better readability - Replaced or_insert_with(Vec::new) with or_default() in test closures - Updated Cargo.lock with new dependency versions and checksums - Added TLS features to tonic dependency configuration - Included sqlx, chrono, and uuid dependencies with specific features - Added jsonwebtoken and arc-swap as project dependencies - Reformatted assertion statements to comply with line length limits - Adjusted base64 import order in engine codec module - Updated protobuf include statement formatting
397 lines
12 KiB
Rust
397 lines
12 KiB
Rust
//! Poll CRUD operations on `MessageRepo`.
//!
//! Handles poll creation (with options), poll-target resolution, voting (with
//! denormalized counts), result retrieval, and closing polls.
|
|
|
|
use chrono::Utc;
|
|
use sqlx::Row;
|
|
use uuid::Uuid;
|
|
|
|
use crate::ImksResult;
|
|
use crate::models::message_poll::{MessagePoll, MessagePollOption, MessagePollVote, PollResult};
|
|
|
|
use super::message_repo::MessageRepo;
|
|
|
|
/// Canonical poll target resolved from the database.
|
|
#[derive(Debug, Clone, Copy)]
|
|
pub struct PollTarget {
|
|
pub poll_id: Uuid,
|
|
pub message_id: Uuid,
|
|
pub channel_id: Uuid,
|
|
}
|
|
|
|
impl MessageRepo {
|
|
/// Resolve and validate a poll option's canonical message/channel target.
|
|
pub async fn get_poll_target(&self, poll_id: Uuid, option_id: Uuid) -> ImksResult<PollTarget> {
|
|
let row = sqlx::query(
|
|
r#"
|
|
SELECT p.id AS poll_id, p.message_id, m.channel_id, o.id AS option_id
|
|
FROM message_poll p
|
|
JOIN message m ON m.id = p.message_id
|
|
JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2
|
|
WHERE p.id = $1 AND m.deleted_at IS NULL
|
|
"#,
|
|
)
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.fetch_optional(self.pool())
|
|
.await?
|
|
.ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?;
|
|
|
|
Ok(PollTarget {
|
|
poll_id: row.get("poll_id"),
|
|
message_id: row.get("message_id"),
|
|
channel_id: row.get("channel_id"),
|
|
})
|
|
}
|
|
|
|
/// Create a poll with its options. Returns the poll (options fetched separately).
|
|
pub async fn create_poll(
|
|
&self,
|
|
message_id: Uuid,
|
|
question: &str,
|
|
allow_multiselect: bool,
|
|
max_selections: Option<i32>,
|
|
expires_at: Option<chrono::DateTime<Utc>>,
|
|
options: &[(String, Option<String>)],
|
|
) -> ImksResult<MessagePoll> {
|
|
let poll_id = Uuid::now_v7();
|
|
let now = Utc::now();
|
|
|
|
let poll = sqlx::query_as::<_, MessagePoll>(
|
|
r#"
|
|
INSERT INTO message_poll (
|
|
id, message_id, question, allow_multiselect,
|
|
max_selections, expires_at, total_votes, created_at, updated_at
|
|
) VALUES ($1, $2, $3, $4, $5, $6, 0, $7, $7)
|
|
RETURNING *
|
|
"#,
|
|
)
|
|
.bind(poll_id)
|
|
.bind(message_id)
|
|
.bind(question)
|
|
.bind(allow_multiselect)
|
|
.bind(max_selections)
|
|
.bind(expires_at)
|
|
.bind(now)
|
|
.fetch_one(self.pool())
|
|
.await?;
|
|
|
|
for (i, (text, emoji)) in options.iter().enumerate() {
|
|
let opt_id = Uuid::now_v7();
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO message_poll_option (id, poll_id, text, emoji, vote_count, position)
|
|
VALUES ($1, $2, $3, $4, 0, $5)
|
|
"#,
|
|
)
|
|
.bind(opt_id)
|
|
.bind(poll_id)
|
|
.bind(text)
|
|
.bind(emoji.as_deref())
|
|
.bind(i as i32)
|
|
.execute(self.pool())
|
|
.await?;
|
|
}
|
|
|
|
Ok(poll)
|
|
}
|
|
|
|
/// Cast a validated vote and return the canonical message/channel target.
|
|
pub async fn cast_vote_checked(
|
|
&self,
|
|
poll_id: Uuid,
|
|
option_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> ImksResult<PollTarget> {
|
|
let mut tx = self.pool().begin().await?;
|
|
let now = Utc::now();
|
|
|
|
let row = sqlx::query(
|
|
r#"
|
|
SELECT p.id AS poll_id, p.message_id, p.allow_multiselect, p.max_selections,
|
|
p.expires_at, m.channel_id, o.id AS option_id
|
|
FROM message_poll p
|
|
JOIN message m ON m.id = p.message_id
|
|
JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2
|
|
WHERE p.id = $1 AND m.deleted_at IS NULL
|
|
FOR UPDATE OF p
|
|
"#,
|
|
)
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.fetch_optional(&mut *tx)
|
|
.await?
|
|
.ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?;
|
|
|
|
let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
|
|
if expires_at.is_some_and(|exp| now >= exp) {
|
|
return Err(crate::ImksError::InvalidInput("Poll has expired".into()));
|
|
}
|
|
|
|
let allow_multiselect: bool = row.get("allow_multiselect");
|
|
let max_selections: Option<i32> = row.get("max_selections");
|
|
let current_votes: Vec<Uuid> = sqlx::query_scalar(
|
|
"SELECT option_id FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2",
|
|
)
|
|
.bind(poll_id)
|
|
.bind(user_id)
|
|
.fetch_all(&mut *tx)
|
|
.await?;
|
|
|
|
if current_votes.contains(&option_id) {
|
|
return Err(crate::ImksError::InvalidInput(
|
|
"Already voted for this option".into(),
|
|
));
|
|
}
|
|
if !allow_multiselect && !current_votes.is_empty() {
|
|
return Err(crate::ImksError::InvalidInput(
|
|
"Poll allows only one selection".into(),
|
|
));
|
|
}
|
|
if let Some(max) = max_selections
|
|
&& current_votes.len() >= max.max(1) as usize
|
|
{
|
|
return Err(crate::ImksError::InvalidInput(
|
|
"Poll selection limit exceeded".into(),
|
|
));
|
|
}
|
|
|
|
let vote_id = Uuid::now_v7();
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, created_at)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
"#,
|
|
)
|
|
.bind(vote_id)
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.bind(user_id)
|
|
.bind(now)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
sqlx::query("UPDATE message_poll_option SET vote_count = vote_count + 1 WHERE id = $1")
|
|
.bind(option_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
sqlx::query(
|
|
"UPDATE message_poll SET total_votes = total_votes + 1, updated_at = $1 WHERE id = $2",
|
|
)
|
|
.bind(now)
|
|
.bind(poll_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
tx.commit().await?;
|
|
Ok(PollTarget {
|
|
poll_id,
|
|
message_id: row.get("message_id"),
|
|
channel_id: row.get("channel_id"),
|
|
})
|
|
}
|
|
|
|
/// Remove a validated vote and return the canonical message/channel target.
|
|
pub async fn remove_vote_checked(
|
|
&self,
|
|
poll_id: Uuid,
|
|
option_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> ImksResult<Option<PollTarget>> {
|
|
let mut tx = self.pool().begin().await?;
|
|
let now = Utc::now();
|
|
|
|
let row = sqlx::query(
|
|
r#"
|
|
SELECT p.id AS poll_id, p.message_id, m.channel_id, o.id AS option_id
|
|
FROM message_poll p
|
|
JOIN message m ON m.id = p.message_id
|
|
JOIN message_poll_option o ON o.poll_id = p.id AND o.id = $2
|
|
WHERE p.id = $1 AND m.deleted_at IS NULL
|
|
FOR UPDATE OF p
|
|
"#,
|
|
)
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.fetch_optional(&mut *tx)
|
|
.await?
|
|
.ok_or_else(|| crate::ImksError::NotFound(format!("poll {poll_id} option {option_id}")))?;
|
|
|
|
let result = sqlx::query(
|
|
r#"
|
|
DELETE FROM message_poll_vote
|
|
WHERE poll_id = $1 AND option_id = $2 AND user_id = $3
|
|
"#,
|
|
)
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.bind(user_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
tx.commit().await?;
|
|
return Ok(None);
|
|
}
|
|
|
|
sqlx::query(
|
|
"UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1",
|
|
)
|
|
.bind(option_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
sqlx::query(
|
|
"UPDATE message_poll SET total_votes = GREATEST(total_votes - 1, 0), updated_at = $1 WHERE id = $2",
|
|
)
|
|
.bind(now)
|
|
.bind(poll_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
tx.commit().await?;
|
|
Ok(Some(PollTarget {
|
|
poll_id,
|
|
message_id: row.get("message_id"),
|
|
channel_id: row.get("channel_id"),
|
|
}))
|
|
}
|
|
|
|
/// Cast a vote. Increments denormalized counts atomically.
|
|
pub async fn vote(
|
|
&self,
|
|
poll_id: Uuid,
|
|
option_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> ImksResult<MessagePollVote> {
|
|
let id = Uuid::now_v7();
|
|
let now = Utc::now();
|
|
|
|
let vote = sqlx::query_as::<_, MessagePollVote>(
|
|
r#"
|
|
INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, created_at)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
ON CONFLICT (poll_id, user_id, option_id) DO NOTHING
|
|
RETURNING *
|
|
"#,
|
|
)
|
|
.bind(id)
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.bind(user_id)
|
|
.bind(now)
|
|
.fetch_optional(self.pool())
|
|
.await?
|
|
.ok_or_else(|| crate::ImksError::InvalidInput("Already voted for this option".into()))?;
|
|
|
|
sqlx::query("UPDATE message_poll_option SET vote_count = vote_count + 1 WHERE id = $1")
|
|
.bind(option_id)
|
|
.execute(self.pool())
|
|
.await?;
|
|
|
|
sqlx::query(
|
|
"UPDATE message_poll SET total_votes = total_votes + 1, updated_at = $1 WHERE id = $2",
|
|
)
|
|
.bind(now)
|
|
.bind(poll_id)
|
|
.execute(self.pool())
|
|
.await?;
|
|
|
|
Ok(vote)
|
|
}
|
|
|
|
/// Remove a vote. Decrements denormalized counts.
|
|
pub async fn remove_vote(
|
|
&self,
|
|
poll_id: Uuid,
|
|
option_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> ImksResult<bool> {
|
|
let result = sqlx::query(
|
|
r#"
|
|
DELETE FROM message_poll_vote
|
|
WHERE poll_id = $1 AND option_id = $2 AND user_id = $3
|
|
"#,
|
|
)
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.bind(user_id)
|
|
.execute(self.pool())
|
|
.await?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
return Ok(false);
|
|
}
|
|
|
|
sqlx::query(
|
|
"UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1",
|
|
)
|
|
.bind(option_id)
|
|
.execute(self.pool())
|
|
.await?;
|
|
|
|
sqlx::query(
|
|
"UPDATE message_poll SET total_votes = GREATEST(total_votes - 1, 0), updated_at = $1 WHERE id = $2",
|
|
)
|
|
.bind(Utc::now())
|
|
.bind(poll_id)
|
|
.execute(self.pool())
|
|
.await?;
|
|
|
|
Ok(true)
|
|
}
|
|
|
|
/// Get full poll results including options, vote counts, and the given user's votes.
|
|
pub async fn get_poll_result(
|
|
&self,
|
|
message_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> ImksResult<Option<PollResult>> {
|
|
let poll =
|
|
sqlx::query_as::<_, MessagePoll>("SELECT * FROM message_poll WHERE message_id = $1")
|
|
.bind(message_id)
|
|
.fetch_optional(self.pool())
|
|
.await?;
|
|
|
|
let Some(poll) = poll else {
|
|
return Ok(None);
|
|
};
|
|
|
|
let options: Vec<MessagePollOption> = sqlx::query_as(
|
|
"SELECT * FROM message_poll_option WHERE poll_id = $1 ORDER BY position",
|
|
)
|
|
.bind(poll.id)
|
|
.fetch_all(self.pool())
|
|
.await?;
|
|
|
|
let my_votes: Vec<Uuid> = sqlx::query_scalar(
|
|
"SELECT option_id FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2",
|
|
)
|
|
.bind(poll.id)
|
|
.bind(user_id)
|
|
.fetch_all(self.pool())
|
|
.await?;
|
|
|
|
Ok(Some(PollResult::from_poll(poll, options, my_votes)))
|
|
}
|
|
|
|
/// Close a poll by setting its expiration to now.
|
|
pub async fn close_poll(&self, message_id: Uuid) -> ImksResult<()> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
r#"
|
|
UPDATE message_poll
|
|
SET expires_at = $1, updated_at = $1
|
|
WHERE message_id = $2 AND (expires_at IS NULL OR expires_at > $1)
|
|
"#,
|
|
)
|
|
.bind(now)
|
|
.bind(message_id)
|
|
.execute(self.pool())
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
}
|