use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::error::AppError; use crate::immediate::{PollAction, PollEvent}; use crate::models::channels::{MessagePoll, MessagePollOption, MessagePollVote}; use crate::models::common::PollLayout; use crate::service::ImService; use crate::service::im::events::ImEvent; use super::session::ImSession; use super::util::*; #[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] pub struct CreatePollParams { pub question: String, pub description: Option, pub options: Vec, pub layout: Option, pub allow_multiselect: Option, pub duration_hours: Option, } #[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] pub struct CreatePollOptionParams { pub text: String, pub emoji_id: Option, pub emoji_name: Option, } #[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)] pub struct VoteParams { pub option_ids: Vec, } impl ImService { pub async fn poll_create( &self, ctx: &ImSession, _wk_name: &str, channel_id: Uuid, message_id: Uuid, params: CreatePollParams, ) -> Result { let user_uid = ctx.user; let channel = self.resolve_channel(channel_id).await?; self.ensure_channel_readable(user_uid, &channel).await?; self.resolve_message(message_id, channel_id).await?; let question = required_text(params.question, "question")?; if params.options.is_empty() || params.options.len() > MAX_POLL_OPTIONS { return Err(AppError::BadRequest(format!( "poll must have between 1 and {MAX_POLL_OPTIONS} options" ))); } let layout = parse_enum( params.layout, PollLayout::Default, PollLayout::Unknown, "layout", )?; let now = chrono::Utc::now(); let poll_id = Uuid::now_v7(); let ends_at = params .duration_hours .map(|h| now + chrono::Duration::hours(h as i64)); let validated_options: Vec<(String, Option, Option)> = params .options .iter() .map(|opt| { let text = required_text(opt.text.clone(), "option text")?; if text.len() > MAX_POLL_OPTION_TEXT { return Err(AppError::BadRequest("poll option text too long".into())); } Ok((text, opt.emoji_id.clone(), opt.emoji_name.clone())) }) .collect::, AppError>>()?; let mut txn = self .ctx .db .writer() .begin() .await .map_err(|_| AppError::TxnError)?; sqlx::query("SET LOCAL app.current_user_id = $1") .bind(user_uid) .execute(&mut *txn) .await .map_err(AppError::Database)?; let poll = sqlx::query_as::<_, MessagePoll>( "INSERT INTO message_poll \ (id, message_id, channel_id, question, description, layout, \ allow_multiselect, duration_hours, ends_at, total_votes, metadata, \ created_at, updated_at) \ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 0, NULL, $10, $10) \ RETURNING id, message_id, channel_id, question, description, layout, \ allow_multiselect, duration_hours, ends_at, total_votes, metadata, \ created_at, updated_at", ) .bind(poll_id) .bind(message_id) .bind(channel_id) .bind(&question) .bind(params.description.as_deref()) .bind(layout) .bind(params.allow_multiselect.unwrap_or(false)) .bind(params.duration_hours) .bind(ends_at) .bind(now) .fetch_one(&mut *txn) .await .map_err(AppError::Database)?; for (i, (text, emoji_id, emoji_name)) in validated_options.iter().enumerate() { sqlx::query( "INSERT INTO message_poll_option \ (id, poll_id, position, text, emoji_id, emoji_name, vote_count, created_at) \ VALUES ($1, $2, $3, $4, $5, $6, 0, $7)", ) .bind(Uuid::now_v7()) .bind(poll_id) .bind(i as i32) .bind(text) .bind(emoji_id.as_deref()) .bind(emoji_name.as_deref()) .bind(now) .execute(&mut *txn) .await .map_err(AppError::Database)?; } txn.commit().await.map_err(|_| AppError::TxnError)?; tracing::info!(poll_id = %poll_id, "Poll created"); let request_id = Uuid::nil(); let event = PollEvent { channel_id, poll_id, action: PollAction::Created, }; self.publish(&format!("im.poll.{}", channel_id), request_id, &event) .await; self.emit_event(ImEvent::Poll { request_id, data: event, }); Ok(poll) } pub async fn poll_get( &self, ctx: &ImSession, _wk_name: &str, channel_id: Uuid, poll_id: Uuid, ) -> Result<(MessagePoll, Vec), AppError> { let user_uid = ctx.user; let channel = self.resolve_channel(channel_id).await?; self.ensure_channel_readable(user_uid, &channel).await?; let poll = sqlx::query_as::<_, MessagePoll>( "SELECT id, message_id, channel_id, question, description, layout, \ allow_multiselect, duration_hours, ends_at, total_votes, metadata, \ created_at, updated_at \ FROM message_poll WHERE id = $1 AND channel_id = $2", ) .bind(poll_id) .bind(channel_id) .fetch_optional(self.ctx.db.reader()) .await .map_err(AppError::Database)? .ok_or(AppError::NotFound("poll not found".into()))?; let options = sqlx::query_as::<_, MessagePollOption>( "SELECT id, poll_id, position, text, emoji_id, emoji_name, vote_count, created_at \ FROM message_poll_option WHERE poll_id = $1 ORDER BY position ASC", ) .bind(poll_id) .fetch_all(self.ctx.db.reader()) .await .map_err(AppError::Database)?; Ok((poll, options)) } pub async fn poll_vote( &self, ctx: &ImSession, _wk_name: &str, channel_id: Uuid, poll_id: Uuid, params: VoteParams, ) -> Result<(), AppError> { let user_uid = ctx.user; let channel = self.resolve_channel(channel_id).await?; self.ensure_channel_readable(user_uid, &channel).await?; let poll = sqlx::query_as::<_, MessagePoll>( "SELECT id, message_id, channel_id, question, description, layout, \ allow_multiselect, duration_hours, ends_at, total_votes, metadata, \ created_at, updated_at \ FROM message_poll WHERE id = $1 AND channel_id = $2", ) .bind(poll_id) .bind(channel_id) .fetch_optional(self.ctx.db.reader()) .await .map_err(AppError::Database)? .ok_or(AppError::NotFound("poll not found".into()))?; if let Some(ends) = poll.ends_at && chrono::Utc::now() > ends { return Err(AppError::BadRequest("poll has ended".into())); } if !poll.allow_multiselect && params.option_ids.len() > 1 { return Err(AppError::BadRequest("multiselect not allowed".into())); } let now = chrono::Utc::now(); let mut txn = self .ctx .db .writer() .begin() .await .map_err(|_| AppError::TxnError)?; sqlx::query("SET LOCAL app.current_user_id = $1") .bind(user_uid) .execute(&mut *txn) .await .map_err(AppError::Database)?; // Collect old option_ids before deleting let old_option_ids: Vec = sqlx::query_scalar( "DELETE FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2 RETURNING option_id", ) .bind(poll_id) .bind(user_uid) .fetch_all(&mut *txn) .await .map_err(AppError::Database)?; let removed = old_option_ids.len() as i32; // Decrement old vote counts for opt_id in &old_option_ids { sqlx::query( "UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1", ) .bind(opt_id) .execute(&mut *txn) .await .map_err(AppError::Database)?; } // Insert new votes let mut new_count = 0i32; for option_id in ¶ms.option_ids { sqlx::query( "INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, voted_at) \ VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING", ) .bind(Uuid::now_v7()) .bind(poll_id) .bind(option_id) .bind(user_uid) .bind(now) .execute(&mut *txn) .await .map_err(AppError::Database)?; sqlx::query( "UPDATE message_poll_option SET vote_count = vote_count + 1 \ WHERE id = $1 AND poll_id = $2", ) .bind(option_id) .bind(poll_id) .execute(&mut *txn) .await .map_err(AppError::Database)?; new_count += 1; } let delta = new_count - removed; sqlx::query( "UPDATE message_poll SET total_votes = total_votes + $1, updated_at = $2 WHERE id = $3", ) .bind(delta) .bind(now) .bind(poll_id) .execute(&mut *txn) .await .map_err(AppError::Database)?; txn.commit().await.map_err(|_| AppError::TxnError)?; let request_id = Uuid::nil(); let event = PollEvent { channel_id, poll_id, action: PollAction::Voted, }; self.publish(&format!("im.poll.{}", channel_id), request_id, &event) .await; self.emit_event(ImEvent::Poll { request_id, data: event, }); Ok(()) } pub async fn poll_results( &self, ctx: &ImSession, _wk_name: &str, channel_id: Uuid, poll_id: Uuid, ) -> Result, AppError> { let user_uid = ctx.user; let channel = self.resolve_channel(channel_id).await?; self.ensure_channel_readable(user_uid, &channel).await?; sqlx::query_as::<_, MessagePollVote>( "SELECT id, poll_id, option_id, user_id, voted_at \ FROM message_poll_vote WHERE poll_id = $1 ORDER BY voted_at ASC", ) .bind(poll_id) .fetch_all(self.ctx.db.reader()) .await .map_err(AppError::Database) } pub async fn poll_delete( &self, ctx: &ImSession, _wk_name: &str, channel_id: Uuid, poll_id: Uuid, ) -> Result<(), AppError> { let user_uid = ctx.user; let channel = self.resolve_channel(channel_id).await?; self.ensure_channel_editable(user_uid, &channel).await?; let result = sqlx::query("DELETE FROM message_poll WHERE id = $1 AND channel_id = $2") .bind(poll_id) .bind(channel_id) .execute(self.ctx.db.writer()) .await .map_err(AppError::Database)?; ensure_affected(result.rows_affected(), "poll not found")?; let request_id = Uuid::nil(); let event = PollEvent { channel_id, poll_id, action: PollAction::Deleted, }; self.publish(&format!("im.poll.{}", channel_id), request_id, &event) .await; self.emit_event(ImEvent::Poll { request_id, data: event, }); Ok(()) } }