373 lines
12 KiB
Rust
373 lines
12 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::AppError;
|
|
use crate::immediate::{PollAction, PollEvent};
|
|
use crate::models::channels::{MessagePoll, MessagePollOption, MessagePollVote};
|
|
use crate::models::common::PollLayout;
|
|
use crate::service::ImService;
|
|
use crate::service::im::events::ImEvent;
|
|
|
|
use super::session::ImSession;
|
|
use super::util::*;
|
|
|
|
#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)]
|
|
pub struct CreatePollParams {
|
|
pub question: String,
|
|
pub description: Option<String>,
|
|
pub options: Vec<CreatePollOptionParams>,
|
|
pub layout: Option<String>,
|
|
pub allow_multiselect: Option<bool>,
|
|
pub duration_hours: Option<i32>,
|
|
}
|
|
|
|
#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)]
|
|
pub struct CreatePollOptionParams {
|
|
pub text: String,
|
|
pub emoji_id: Option<String>,
|
|
pub emoji_name: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize, Serialize, Clone, Debug, utoipa::ToSchema)]
|
|
pub struct VoteParams {
|
|
pub option_ids: Vec<Uuid>,
|
|
}
|
|
|
|
impl ImService {
|
|
pub async fn poll_create(
|
|
&self,
|
|
ctx: &ImSession,
|
|
_wk_name: &str,
|
|
channel_id: Uuid,
|
|
message_id: Uuid,
|
|
params: CreatePollParams,
|
|
) -> Result<MessagePoll, AppError> {
|
|
let user_uid = ctx.user;
|
|
let channel = self.resolve_channel(channel_id).await?;
|
|
self.ensure_channel_readable(user_uid, &channel).await?;
|
|
self.resolve_message(message_id, channel_id).await?;
|
|
|
|
let question = required_text(params.question, "question")?;
|
|
if params.options.is_empty() || params.options.len() > MAX_POLL_OPTIONS {
|
|
return Err(AppError::BadRequest(format!(
|
|
"poll must have between 1 and {MAX_POLL_OPTIONS} options"
|
|
)));
|
|
}
|
|
|
|
let layout = parse_enum(
|
|
params.layout,
|
|
PollLayout::Default,
|
|
PollLayout::Unknown,
|
|
"layout",
|
|
)?;
|
|
|
|
let now = chrono::Utc::now();
|
|
let poll_id = Uuid::now_v7();
|
|
let ends_at = params
|
|
.duration_hours
|
|
.map(|h| now + chrono::Duration::hours(h as i64));
|
|
|
|
let validated_options: Vec<(String, Option<String>, Option<String>)> = params
|
|
.options
|
|
.iter()
|
|
.map(|opt| {
|
|
let text = required_text(opt.text.clone(), "option text")?;
|
|
if text.len() > MAX_POLL_OPTION_TEXT {
|
|
return Err(AppError::BadRequest("poll option text too long".into()));
|
|
}
|
|
Ok((text, opt.emoji_id.clone(), opt.emoji_name.clone()))
|
|
})
|
|
.collect::<Result<Vec<_>, AppError>>()?;
|
|
|
|
let mut txn = self
|
|
.ctx
|
|
.db
|
|
.writer()
|
|
.begin()
|
|
.await
|
|
.map_err(|_| AppError::TxnError)?;
|
|
sqlx::query("SET LOCAL app.current_user_id = $1")
|
|
.bind(user_uid)
|
|
.execute(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
let poll = sqlx::query_as::<_, MessagePoll>(
|
|
"INSERT INTO message_poll \
|
|
(id, message_id, channel_id, question, description, layout, \
|
|
allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
|
|
created_at, updated_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, 0, NULL, $10, $10) \
|
|
RETURNING id, message_id, channel_id, question, description, layout, \
|
|
allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
|
|
created_at, updated_at",
|
|
)
|
|
.bind(poll_id)
|
|
.bind(message_id)
|
|
.bind(channel_id)
|
|
.bind(&question)
|
|
.bind(params.description.as_deref())
|
|
.bind(layout)
|
|
.bind(params.allow_multiselect.unwrap_or(false))
|
|
.bind(params.duration_hours)
|
|
.bind(ends_at)
|
|
.bind(now)
|
|
.fetch_one(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
for (i, (text, emoji_id, emoji_name)) in validated_options.iter().enumerate() {
|
|
sqlx::query(
|
|
"INSERT INTO message_poll_option \
|
|
(id, poll_id, position, text, emoji_id, emoji_name, vote_count, created_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, 0, $7)",
|
|
)
|
|
.bind(Uuid::now_v7())
|
|
.bind(poll_id)
|
|
.bind(i as i32)
|
|
.bind(text)
|
|
.bind(emoji_id.as_deref())
|
|
.bind(emoji_name.as_deref())
|
|
.bind(now)
|
|
.execute(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
}
|
|
|
|
txn.commit().await.map_err(|_| AppError::TxnError)?;
|
|
tracing::info!(poll_id = %poll_id, "Poll created");
|
|
let request_id = Uuid::nil();
|
|
let event = PollEvent {
|
|
channel_id,
|
|
poll_id,
|
|
action: PollAction::Created,
|
|
};
|
|
self.publish(&format!("im.poll.{}", channel_id), request_id, &event)
|
|
.await;
|
|
self.emit_event(ImEvent::Poll {
|
|
request_id,
|
|
data: event,
|
|
});
|
|
Ok(poll)
|
|
}
|
|
|
|
pub async fn poll_get(
|
|
&self,
|
|
ctx: &ImSession,
|
|
_wk_name: &str,
|
|
channel_id: Uuid,
|
|
poll_id: Uuid,
|
|
) -> Result<(MessagePoll, Vec<MessagePollOption>), AppError> {
|
|
let user_uid = ctx.user;
|
|
let channel = self.resolve_channel(channel_id).await?;
|
|
self.ensure_channel_readable(user_uid, &channel).await?;
|
|
|
|
let poll = sqlx::query_as::<_, MessagePoll>(
|
|
"SELECT id, message_id, channel_id, question, description, layout, \
|
|
allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
|
|
created_at, updated_at \
|
|
FROM message_poll WHERE id = $1 AND channel_id = $2",
|
|
)
|
|
.bind(poll_id)
|
|
.bind(channel_id)
|
|
.fetch_optional(self.ctx.db.reader())
|
|
.await
|
|
.map_err(AppError::Database)?
|
|
.ok_or(AppError::NotFound("poll not found".into()))?;
|
|
|
|
let options = sqlx::query_as::<_, MessagePollOption>(
|
|
"SELECT id, poll_id, position, text, emoji_id, emoji_name, vote_count, created_at \
|
|
FROM message_poll_option WHERE poll_id = $1 ORDER BY position ASC",
|
|
)
|
|
.bind(poll_id)
|
|
.fetch_all(self.ctx.db.reader())
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
Ok((poll, options))
|
|
}
|
|
|
|
pub async fn poll_vote(
|
|
&self,
|
|
ctx: &ImSession,
|
|
_wk_name: &str,
|
|
channel_id: Uuid,
|
|
poll_id: Uuid,
|
|
params: VoteParams,
|
|
) -> Result<(), AppError> {
|
|
let user_uid = ctx.user;
|
|
let channel = self.resolve_channel(channel_id).await?;
|
|
self.ensure_channel_readable(user_uid, &channel).await?;
|
|
|
|
let poll = sqlx::query_as::<_, MessagePoll>(
|
|
"SELECT id, message_id, channel_id, question, description, layout, \
|
|
allow_multiselect, duration_hours, ends_at, total_votes, metadata, \
|
|
created_at, updated_at \
|
|
FROM message_poll WHERE id = $1 AND channel_id = $2",
|
|
)
|
|
.bind(poll_id)
|
|
.bind(channel_id)
|
|
.fetch_optional(self.ctx.db.reader())
|
|
.await
|
|
.map_err(AppError::Database)?
|
|
.ok_or(AppError::NotFound("poll not found".into()))?;
|
|
|
|
if let Some(ends) = poll.ends_at
|
|
&& chrono::Utc::now() > ends
|
|
{
|
|
return Err(AppError::BadRequest("poll has ended".into()));
|
|
}
|
|
|
|
if !poll.allow_multiselect && params.option_ids.len() > 1 {
|
|
return Err(AppError::BadRequest("multiselect not allowed".into()));
|
|
}
|
|
|
|
let now = chrono::Utc::now();
|
|
let mut txn = self
|
|
.ctx
|
|
.db
|
|
.writer()
|
|
.begin()
|
|
.await
|
|
.map_err(|_| AppError::TxnError)?;
|
|
sqlx::query("SET LOCAL app.current_user_id = $1")
|
|
.bind(user_uid)
|
|
.execute(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
// Collect old option_ids before deleting
|
|
let old_option_ids: Vec<Uuid> = sqlx::query_scalar(
|
|
"DELETE FROM message_poll_vote WHERE poll_id = $1 AND user_id = $2 RETURNING option_id",
|
|
)
|
|
.bind(poll_id)
|
|
.bind(user_uid)
|
|
.fetch_all(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
let removed = old_option_ids.len() as i32;
|
|
|
|
// Decrement old vote counts
|
|
for opt_id in &old_option_ids {
|
|
sqlx::query(
|
|
"UPDATE message_poll_option SET vote_count = GREATEST(vote_count - 1, 0) WHERE id = $1",
|
|
)
|
|
.bind(opt_id)
|
|
.execute(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
}
|
|
|
|
// Insert new votes
|
|
let mut new_count = 0i32;
|
|
for option_id in ¶ms.option_ids {
|
|
sqlx::query(
|
|
"INSERT INTO message_poll_vote (id, poll_id, option_id, user_id, voted_at) \
|
|
VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING",
|
|
)
|
|
.bind(Uuid::now_v7())
|
|
.bind(poll_id)
|
|
.bind(option_id)
|
|
.bind(user_uid)
|
|
.bind(now)
|
|
.execute(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
sqlx::query(
|
|
"UPDATE message_poll_option SET vote_count = vote_count + 1 \
|
|
WHERE id = $1 AND poll_id = $2",
|
|
)
|
|
.bind(option_id)
|
|
.bind(poll_id)
|
|
.execute(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
new_count += 1;
|
|
}
|
|
|
|
let delta = new_count - removed;
|
|
sqlx::query(
|
|
"UPDATE message_poll SET total_votes = total_votes + $1, updated_at = $2 WHERE id = $3",
|
|
)
|
|
.bind(delta)
|
|
.bind(now)
|
|
.bind(poll_id)
|
|
.execute(&mut *txn)
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
txn.commit().await.map_err(|_| AppError::TxnError)?;
|
|
let request_id = Uuid::nil();
|
|
let event = PollEvent {
|
|
channel_id,
|
|
poll_id,
|
|
action: PollAction::Voted,
|
|
};
|
|
self.publish(&format!("im.poll.{}", channel_id), request_id, &event)
|
|
.await;
|
|
self.emit_event(ImEvent::Poll {
|
|
request_id,
|
|
data: event,
|
|
});
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn poll_results(
|
|
&self,
|
|
ctx: &ImSession,
|
|
_wk_name: &str,
|
|
channel_id: Uuid,
|
|
poll_id: Uuid,
|
|
) -> Result<Vec<MessagePollVote>, AppError> {
|
|
let user_uid = ctx.user;
|
|
let channel = self.resolve_channel(channel_id).await?;
|
|
self.ensure_channel_readable(user_uid, &channel).await?;
|
|
|
|
sqlx::query_as::<_, MessagePollVote>(
|
|
"SELECT id, poll_id, option_id, user_id, voted_at \
|
|
FROM message_poll_vote WHERE poll_id = $1 ORDER BY voted_at ASC",
|
|
)
|
|
.bind(poll_id)
|
|
.fetch_all(self.ctx.db.reader())
|
|
.await
|
|
.map_err(AppError::Database)
|
|
}
|
|
|
|
pub async fn poll_delete(
|
|
&self,
|
|
ctx: &ImSession,
|
|
_wk_name: &str,
|
|
channel_id: Uuid,
|
|
poll_id: Uuid,
|
|
) -> Result<(), AppError> {
|
|
let user_uid = ctx.user;
|
|
let channel = self.resolve_channel(channel_id).await?;
|
|
self.ensure_channel_editable(user_uid, &channel).await?;
|
|
|
|
let result = sqlx::query("DELETE FROM message_poll WHERE id = $1 AND channel_id = $2")
|
|
.bind(poll_id)
|
|
.bind(channel_id)
|
|
.execute(self.ctx.db.writer())
|
|
.await
|
|
.map_err(AppError::Database)?;
|
|
|
|
ensure_affected(result.rows_affected(), "poll not found")?;
|
|
let request_id = Uuid::nil();
|
|
let event = PollEvent {
|
|
channel_id,
|
|
poll_id,
|
|
action: PollAction::Deleted,
|
|
};
|
|
self.publish(&format!("im.poll.{}", channel_id), request_id, &event)
|
|
.await;
|
|
self.emit_event(ImEvent::Poll {
|
|
request_id,
|
|
data: event,
|
|
});
|
|
Ok(())
|
|
}
|
|
}
|