use std::pin::Pin; use tokio::time::{self, Duration, Instant}; use tokio::sync::mpsc; use tokio_stream::{Stream, wrappers::ReceiverStream}; use tonic::{Request, Response, Status}; use tracing::warn; const STREAM_STATUS_POLL_INTERVAL: Duration = Duration::from_millis(300); /// Maximum lifetime of a single streaming batch-status RPC. /// Protects against leaked streams when jobs never reach a terminal state. const STREAM_STATUS_TIMEOUT: Duration = Duration::from_secs(10 * 60); use crate::{ error::QueueError, pb::email::v1::{ BatchSendEmailRequest, BatchSendEmailResponse, GetEmailStatusRequest, GetEmailStatusResponse, SendEmailRequest, SendEmailResponse, SendStatus, email_service_server::EmailService, }, queue::EmailQueue, status::JobStatusStore, }; #[derive(Clone)] pub struct EmailServiceImpl { queue: EmailQueue, store: JobStatusStore, } impl EmailServiceImpl { pub fn new(queue: EmailQueue, store: JobStatusStore) -> Self { Self { queue, store } } } fn map_queue_err(err: QueueError) -> Status { match err { QueueError::Closed => Status::unavailable("queue closed"), QueueError::Full => Status::resource_exhausted("queue full, try later"), QueueError::IdExhausted => Status::resource_exhausted("queue id space exhausted"), } } fn build_response(id: u64, status: SendStatus) -> SendEmailResponse { SendEmailResponse { message_id: id.to_string(), status: status.into(), provider: String::new(), sent_at: None, } } fn build_failed_response(id: Option, detail: String) -> SendEmailResponse { SendEmailResponse { message_id: id.map(|v| v.to_string()).unwrap_or_default(), status: SendStatus::Failed.into(), provider: detail, sent_at: None, } } #[tonic::async_trait] impl EmailService for EmailServiceImpl { async fn send_email( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let id = self.queue.enqueue(req).map_err(map_queue_err)?; Ok(Response::new(build_response(id, SendStatus::Queued))) } async fn batch_send_email( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let total = req.emails.len() as i32; let mut success = 0i32; let mut failures = 0i32; let mut results = Vec::with_capacity(total as usize); for (i, email) in req.emails.into_iter().enumerate() { match self.queue.enqueue(email) { Ok(id) => { success += 1; results.push(build_response(id, SendStatus::Queued)); } Err(e) => { failures += 1; warn!(%e, "batch enqueue failed for one email"); if req.fail_fast { // Count remaining unprocessed emails as failures too. failures += total - (i as i32) - 1; warn!( successful = success, failed = failures, "fail_fast triggered, returning partial results" ); break; } } } } Ok(Response::new(BatchSendEmailResponse { results, success_count: success, failure_count: failures, })) } async fn get_email_status( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let id: u64 = req .message_id .parse() .map_err(|_| Status::invalid_argument("message_id must be a valid u64"))?; let entry = self .store .get(id) .ok_or_else(|| Status::not_found(format!("message_id {id} not found")))?; Ok(Response::new(GetEmailStatusResponse { message_id: id.to_string(), status: entry.status.into(), error_detail: entry.error.unwrap_or_default(), updated_at: None, })) } type StreamBatchStatusStream = Pin> + Send>>; async fn stream_batch_status( &self, request: Request, ) -> Result, Status> { let req = request.into_inner(); let mut ids = Vec::with_capacity(req.emails.len()); let mut immediate_results = Vec::new(); for email in req.emails { match self.queue.enqueue(email) { Ok(id) => ids.push(id), Err(err) => { immediate_results.push(Ok(build_failed_response(None, err.to_string()))) } } } let id_set: std::collections::HashSet = ids.iter().copied().collect(); let store = self.store.clone(); let mut missing_streak: std::collections::HashMap = std::collections::HashMap::new(); let (tx, rx) = mpsc::channel(ids.len().saturating_add(immediate_results.len()).max(1)); tokio::spawn(async move { for result in immediate_results { if tx.send(result).await.is_err() { return; } } let mut interval = time::interval(STREAM_STATUS_POLL_INTERVAL); let deadline = Instant::now() + STREAM_STATUS_TIMEOUT; let mut reported = std::collections::HashSet::new(); loop { tokio::select! { _ = tx.closed() => return, _ = time::sleep_until(deadline) => { for id in id_set.difference(&reported) { let response = build_failed_response( Some(*id), "status stream timed out before terminal state".to_owned(), ); if tx.send(Ok(response)).await.is_err() { return; } } break; } _ = interval.tick() => { for id in &id_set { if reported.contains(id) { continue; } if let Some(entry) = store.get(*id) { missing_streak.remove(id); match entry.status { SendStatus::Sent => { if tx .send(Ok(build_response(*id, SendStatus::Sent))) .await .is_err() { return; } reported.insert(*id); } SendStatus::Failed => { let response = build_failed_response( Some(*id), entry.error.unwrap_or_else(|| "unknown".into()), ); if tx.send(Ok(response)).await.is_err() { return; } reported.insert(*id); } _ => {} } } else { // Status entry may have been evicted under memory pressure. // Report as failed after a few consecutive misses. let streak = missing_streak.entry(*id).and_modify(|c| *c += 1).or_insert(1); if *streak >= 5 { if tx.send(Ok(build_failed_response( Some(*id), "status entry evicted before terminal state".into(), ))).await.is_err() { return; } reported.insert(*id); missing_streak.remove(id); } } } if reported.len() == id_set.len() { break; } } } } }); let stream: Self::StreamBatchStatusStream = Box::pin(ReceiverStream::new(rx)); Ok(Response::new(stream)) } }