// Source file: gitks/service/im/polls.rs
// Viewer metadata: 2026-06-07 11:30:56 +08:00 · 373 lines · 12 KiB · Rust

use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::AppError;
use crate::immediate::{PollAction, PollEvent};
use crate::models::channels::{MessagePoll, MessagePollOption, MessagePollVote};
use crate::models::common::PollLayout;
use crate::service::ImService;
use crate::service::im::events::ImEvent;
use super::session::ImSession;
use super::util::*;
/// Request payload accepted by `ImService::poll_create`.
#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)]
pub struct CreatePollParams {
/// Poll question; validated with `required_text` before the poll is stored.
pub question: String,
/// Optional free-form description, stored as given.
pub description: Option<String>,
/// Answer options; `poll_create` requires between 1 and `MAX_POLL_OPTIONS` entries.
pub options: Vec<CreatePollOptionParams>,
/// Layout name, parsed into `PollLayout` via `parse_enum`
/// (presumably falls back to `PollLayout::Default` when omitted — confirm in `parse_enum`).
pub layout: Option<String>,
/// Whether a voter may pick several options; treated as `false` when omitted.
pub allow_multiselect: Option<bool>,
/// Poll lifetime in hours; when set, the end time is the creation time plus this many hours.
pub duration_hours: Option<i32>,
}
/// One answer option inside a `CreatePollParams` payload.
#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)]
pub struct CreatePollOptionParams {
/// Option label; validated with `required_text` and rejected when its
/// `len()` exceeds `MAX_POLL_OPTION_TEXT`.
pub text: String,
/// Optional emoji identifier, stored as given.
pub emoji_id: Option<String>,
/// Optional emoji name, stored as given.
pub emoji_name: Option<String>,
}
/// Request payload accepted by `ImService::poll_vote`.
#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)]
pub struct VoteParams {
/// Options the caller votes for. More than one entry is rejected unless the
/// poll allows multiselect; the caller's previous votes on the poll are
/// removed before these are recorded.
pub option_ids: Vec<Uuid>,
}
impl ImService {
pub async fn poll_create(
&self,
ctx: &ImSession,
_wk_name: &str,
channel_id: Uuid,
message_id: Uuid,
params: CreatePollParams,
) -> Result<MessagePoll, AppError> {
let user_uid = ctx.user;
let channel = self.resolve_channel(channel_id).await?;
self.ensure_channel_readable(user_uid, &channel).await?;
self.resolve_message(message_id, channel_id).await?;
let question = required_text(params.question, "question")?;
if params.options.is_empty() || params.options.len() > MAX_POLL_OPTIONS {
return Err(AppError::BadRequest(format!(
"poll must have between 1 and {MAX_POLL_OPTIONS} options"
)));
}
let layout = parse_enum(
params.layout,
PollLayout::Default,
PollLayout::Unknown,
"layout",
)?;
let now = chrono::Utc::now();
let poll_id = Uuid::now_v7();
let ends_at = params
.duration_hours
.map(|h| now + chrono::Duration::hours(h as i64));
let validated_options: Vec<(String, Option<String>, Option<String>)> = params
.options
.iter()
.map(|opt| {
let text = required_text(opt.text.clone(), "option text")?;
if text.len() > MAX_POLL_OPTION_TEXT {
return Err(AppError::BadRequest("poll option text too long".into()));
}
Ok((text, opt.emoji_id.clone(), opt.emoji_name.clone()))
})
.collect::<Result<Vec<_>, AppError>>()?;
let mut txn = self
.ctx
.db
.writer()
.begin()
.await
.map_err(|_| AppError::TxnError)?;
sqlx::query("SET LOCAL app.current_user_id = $1")
.bind(user_uid)
.execute(&mut *txn)
.await
.map_err(AppError::Database)?;
let poll = sqlx::query_as::<_, MessagePoll>(
"INSERT INTO message_poll \
(id, message_id, channel_id, question, description, layout, \
allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
created_at, updated_at) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 0, NULL, $10, $10) \
RETURNING id, message_id, channel_id, question, description, layout, \
allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
created_at, updated_at",
)
.bind(poll_id)
.bind(message_id)
.bind(channel_id)
.bind(&question)
.bind(params.description.as_deref())
.bind(layout)
.bind(params.allow_multiselect.unwrap_or(false))
.bind(params.duration_hours)
.bind(ends_at)
.bind(now)
.fetch_one(&mut *txn)
.await
.map_err(AppError::Database)?;
for (i, (text, emoji_id, emoji_name)) in validated_options.iter().enumerate() {
sqlx::query(
"INSERT INTO message_poll_option \
(id, poll_id, position, text, emoji_id, emoji_name, vote_count, created_at) \
VALUES ($1, $2, $3, $4, $5, $6, 0, $7)",
)
.bind(Uuid::now_v7())
.bind(poll_id)
.bind(i as i32)
.bind(text)
.bind(emoji_id.as_deref())
.bind(emoji_name.as_deref())
.bind(now)
.execute(&mut *txn)
.await
.map_err(AppError::Database)?;
}
txn.commit().await.map_err(|_| AppError::TxnError)?;
tracing::info!(poll_id = %poll_id, "Poll created");
let request_id = Uuid::nil();
let event = PollEvent {
channel_id,
poll_id,
action: PollAction::Created,
};
self.publish(&format!("im.poll.{}", channel_id), request_id, &event)
.await;
self.emit_event(ImEvent::Poll {
request_id,
data: event,
});
Ok(poll)
}
/// Loads a poll of the given channel together with its options.
///
/// The caller must pass `ensure_channel_readable` for the channel. The poll is
/// looked up by id *and* channel id; options are returned ordered by
/// `position` ascending.
///
/// # Errors
/// * `AppError::NotFound` when no poll with that id exists in the channel.
/// * `AppError::Database` when either query fails.
pub async fn poll_get(
    &self,
    ctx: &ImSession,
    _wk_name: &str,
    channel_id: Uuid,
    poll_id: Uuid,
) -> Result<(MessagePoll, Vec<MessagePollOption>), AppError> {
    const POLL_SQL: &str = "SELECT id, message_id, channel_id, question, description, layout, \
        allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
        created_at, updated_at \
        FROM message_poll WHERE id = $1 AND channel_id = $2";
    const OPTIONS_SQL: &str =
        "SELECT id, poll_id, position, text, emoji_id, emoji_name, vote_count, created_at \
        FROM message_poll_option WHERE poll_id = $1 ORDER BY position ASC";

    let viewer = ctx.user;
    let channel = self.resolve_channel(channel_id).await?;
    self.ensure_channel_readable(viewer, &channel).await?;

    let lookup = sqlx::query_as::<_, MessagePoll>(POLL_SQL)
        .bind(poll_id)
        .bind(channel_id)
        .fetch_optional(self.ctx.db.reader())
        .await
        .map_err(AppError::Database)?;
    let Some(poll) = lookup else {
        return Err(AppError::NotFound("poll not found".into()));
    };

    let options = sqlx::query_as::<_, MessagePollOption>(OPTIONS_SQL)
        .bind(poll_id)
        .fetch_all(self.ctx.db.reader())
        .await
        .map_err(AppError::Database)?;

    Ok((poll, options))
}
/// Replaces the caller's vote on a poll with the given set of options.
///
/// Inside one transaction the caller's previous votes are deleted (and their
/// option counters decremented), the new votes are inserted, and the poll's
/// `total_votes` is adjusted by the net difference. An empty `option_ids`
/// therefore just retracts the caller's vote. A `PollAction::Voted` event is
/// published and emitted after the commit.
///
/// Counters are only touched for votes that were actually recorded, so
/// duplicate ids in the request or rows skipped by `ON CONFLICT DO NOTHING`
/// cannot inflate them.
///
/// # Errors
/// * `AppError::NotFound` when the poll does not exist in the channel.
/// * `AppError::BadRequest` when the poll has ended, when several options are
///   sent to a single-select poll, or when an option does not belong to the
///   poll (the whole vote is rolled back in that case).
/// * `AppError::TxnError` when the transaction cannot be started or committed.
/// * `AppError::Database` for any failing statement.
pub async fn poll_vote(
    &self,
    ctx: &ImSession,
    _wk_name: &str,
    channel_id: Uuid,
    poll_id: Uuid,
    params: VoteParams,
) -> Result<(), AppError> {
    let user_uid = ctx.user;
    let channel = self.resolve_channel(channel_id).await?;
    self.ensure_channel_readable(user_uid, &channel).await?;

    let poll = sqlx::query_as::<_, MessagePoll>(
        "SELECT id, message_id, channel_id, question, description, layout, \
         allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
         created_at, updated_at \
         FROM message_poll WHERE id = $1 AND channel_id = $2",
    )
    .bind(poll_id)
    .bind(channel_id)
    .fetch_optional(self.ctx.db.reader())
    .await
    .map_err(AppError::Database)?
    .ok_or_else(|| AppError::NotFound("poll not found".into()))?;

    let now = chrono::Utc::now();
    if let Some(ends) = poll.ends_at
        && now > ends
    {
        return Err(AppError::BadRequest("poll has ended".into()));
    }

    // Collapse repeated ids so one option can never be counted twice. Sorting
    // also makes concurrent voters touch option rows in the same order.
    let mut option_ids = params.option_ids;
    option_ids.sort_unstable();
    option_ids.dedup();

    if !poll.allow_multiselect && option_ids.len() > 1 {
        return Err(AppError::BadRequest("multiselect not allowed".into()));
    }

    let mut txn = self
        .ctx
        .db
        .writer()
        .begin()
        .await
        .map_err(|_| AppError::TxnError)?;

    // PostgreSQL's `SET` cannot take bind parameters; `set_config(.., true)`
    // is the parameterizable equivalent of `SET LOCAL`.
    sqlx::query("SELECT set_config('app.current_user_id', $1::text, true)")
        .bind(user_uid)
        .execute(&mut *txn)
        .await
        .map_err(AppError::Database)?;

    // Remove the caller's previous votes, remembering which options they hit.
    let old_option_ids: Vec<Uuid> = sqlx::query_scalar(
        "DELETE FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2 RETURNING option_id",
    )
    .bind(poll_id)
    .bind(user_uid)
    .fetch_all(&mut *txn)
    .await
    .map_err(AppError::Database)?;
    let removed = old_option_ids.len() as i32;

    // Give the removed votes back to their options.
    for opt_id in &old_option_ids {
        sqlx::query(
            "UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1",
        )
        .bind(opt_id)
        .execute(&mut *txn)
        .await
        .map_err(AppError::Database)?;
    }

    // Record the new votes, counting only rows that were really written.
    let mut new_count = 0i32;
    for option_id in &option_ids {
        let inserted = sqlx::query(
            "INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, voted_at) \
             VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING",
        )
        .bind(Uuid::now_v7())
        .bind(poll_id)
        .bind(option_id)
        .bind(user_uid)
        .bind(now)
        .execute(&mut *txn)
        .await
        .map_err(AppError::Database)?
        .rows_affected();
        if inserted == 0 {
            // The insert was skipped by ON CONFLICT, so no vote was added and
            // no counter may move.
            continue;
        }

        let updated = sqlx::query(
            "UPDATE message_poll_option SET vote_count = vote_count + 1 \
             WHERE id = $1 AND poll_id = $2",
        )
        .bind(option_id)
        .bind(poll_id)
        .execute(&mut *txn)
        .await
        .map_err(AppError::Database)?
        .rows_affected();
        if updated == 0 {
            // The option is not part of this poll. Returning here drops `txn`
            // uncommitted, which rolls back every change made above.
            return Err(AppError::BadRequest(
                "option does not belong to this poll".into(),
            ));
        }
        new_count += 1;
    }

    // Apply the net change; clamp at zero like the per-option decrement does.
    let delta = new_count - removed;
    sqlx::query(
        "UPDATE message_poll SET total_votes = GREATEST(total_votes + $1, 0), updated_at = $2 \
         WHERE id = $3",
    )
    .bind(delta)
    .bind(now)
    .bind(poll_id)
    .execute(&mut *txn)
    .await
    .map_err(AppError::Database)?;

    txn.commit().await.map_err(|_| AppError::TxnError)?;

    let request_id = Uuid::nil();
    let event = PollEvent {
        channel_id,
        poll_id,
        action: PollAction::Voted,
    };
    self.publish(&format!("im.poll.{}", channel_id), request_id, &event)
        .await;
    self.emit_event(ImEvent::Poll {
        request_id,
        data: event,
    });
    Ok(())
}
/// Lists every vote cast on a poll, ordered by `voted_at` ascending.
///
/// The caller must pass `ensure_channel_readable` for `channel_id`, and the
/// poll must belong to that channel. Without the second check, read access to
/// any one channel would expose the votes of every poll whose id is known.
///
/// # Errors
/// * `AppError::NotFound` when no poll with that id exists in the channel.
/// * `AppError::Database` when a query fails.
pub async fn poll_results(
    &self,
    ctx: &ImSession,
    _wk_name: &str,
    channel_id: Uuid,
    poll_id: Uuid,
) -> Result<Vec<MessagePollVote>, AppError> {
    let user_uid = ctx.user;
    let channel = self.resolve_channel(channel_id).await?;
    self.ensure_channel_readable(user_uid, &channel).await?;

    // Scope the lookup to the authorized channel, as poll_get and poll_vote do.
    let poll_in_channel: bool = sqlx::query_scalar(
        "SELECT EXISTS(SELECT 1 FROM message_poll WHERE id = $1 AND channel_id = $2)",
    )
    .bind(poll_id)
    .bind(channel_id)
    .fetch_one(self.ctx.db.reader())
    .await
    .map_err(AppError::Database)?;
    if !poll_in_channel {
        return Err(AppError::NotFound("poll not found".into()));
    }

    sqlx::query_as::<_, MessagePollVote>(
        "SELECT id, poll_id, option_id, user_id, voted_at \
         FROM message_poll_vote WHERE poll_id = $1 ORDER BY voted_at ASC",
    )
    .bind(poll_id)
    .fetch_all(self.ctx.db.reader())
    .await
    .map_err(AppError::Database)
}
/// Deletes a poll from a channel and announces the deletion.
///
/// The caller must pass `ensure_channel_editable` for the channel. The delete
/// is scoped to both the poll id and the channel id; on success a
/// `PollAction::Deleted` event is published and emitted.
///
/// # Errors
/// * Whatever `ensure_affected` returns when no row was deleted
///   (message: "poll not found").
/// * `AppError::Database` when the delete statement fails.
pub async fn poll_delete(
    &self,
    ctx: &ImSession,
    _wk_name: &str,
    channel_id: Uuid,
    poll_id: Uuid,
) -> Result<(), AppError> {
    let actor = ctx.user;
    let channel = self.resolve_channel(channel_id).await?;
    self.ensure_channel_editable(actor, &channel).await?;

    let deleted_rows = sqlx::query("DELETE FROM message_poll WHERE id = $1 AND channel_id = $2")
        .bind(poll_id)
        .bind(channel_id)
        .execute(self.ctx.db.writer())
        .await
        .map_err(AppError::Database)?
        .rows_affected();
    ensure_affected(deleted_rows, "poll not found")?;

    let topic = format!("im.poll.{}", channel_id);
    let request_id = Uuid::nil();
    let event = PollEvent {
        channel_id,
        poll_id,
        action: PollAction::Deleted,
    };
    self.publish(&topic, request_id, &event).await;
    self.emit_event(ImEvent::Poll {
        request_id,
        data: event,
    });
    Ok(())
}
}