use std::sync::Arc; use arc_swap::ArcSwap; use base64::{Engine as _, engine::general_purpose::STANDARD as B64}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::cache::redis::AppRedis; use crate::error::{AppError, AppResult}; /// 3-hour key validity window. const KEY_WINDOW_SECS: i64 = 3 * 3600; /// Redis key for the currently active signing key. const ACTIVE_KEY: &str = "core:token:active_key"; /// Redis prefix for all signing keys (by kid). const KEY_PREFIX: &str = "core:token:key:"; /// Redis prefix for refresh tokens. const REFRESH_PREFIX: &str = "core:token:refresh:"; /// Redis prefix for revoked token IDs (jti). const REVOKED_PREFIX: &str = "core:token:revoked:"; // ── Types ──────────────────────────────────────────────────────────────── /// A signing key used for JWT issue/verify. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SigningKeyInfo { pub kid: String, pub algorithm: String, /// Base64-encoded raw secret (for HS256). pub key_material: String, pub issued_at: i64, pub expires_at: i64, pub active: bool, } /// JWT claims embedded in every access token. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenClaims { pub sub: String, pub iss: String, pub iat: i64, pub exp: i64, pub jti: String, #[serde(default, skip_serializing_if = "String::is_empty")] pub scope: String, #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")] pub extra: std::collections::HashMap, } /// Result of issuing or refreshing a token pair. pub struct IssuedTokens { pub access_token: String, pub refresh_token: String, pub expires_at: i64, pub key_id: String, } // ── Service ────────────────────────────────────────────────────────────── #[derive(Clone)] pub struct TokenService { redis: AppRedis, /// Current active signing key, swapped atomically on rotation. current_key: Arc>, } impl TokenService { /// Create a new TokenService. /// Loads the active signing key from Redis if one exists, otherwise generates /// and stores a fresh key. pub async fn new(redis: AppRedis) -> AppResult { let svc = Self { redis, current_key: Arc::new(ArcSwap::from_pointee(Self::placeholder_key())), }; svc.load_or_create_active_key().await?; Ok(svc) } // ── Issue ──────────────────────────────────────────────────────────── pub async fn issue_token( &self, user_id: &str, ttl_secs: i64, scopes: Vec, extra: std::collections::HashMap, ) -> AppResult { let now = chrono::Utc::now().timestamp(); let key = self.current_key.load(); let claims = TokenClaims { sub: user_id.to_string(), iss: "appks".to_string(), iat: now, exp: now + ttl_secs, jti: Uuid::now_v7().to_string(), scope: scopes.join(" "), extra, }; let access_token = self.sign_jwt(&claims, &key)?; let refresh_token = self.create_refresh_token(user_id).await?; Ok(IssuedTokens { access_token, refresh_token, expires_at: claims.exp, key_id: key.kid.clone(), }) } // ── Refresh ────────────────────────────────────────────────────────── pub async fn refresh_token( &self, refresh_token: &str, access_ttl_secs: i64, ) -> AppResult { let mut conn = self.redis.get_connection(); // Look up user_id from refresh token let user_id: Option = conn .get(format!("{REFRESH_PREFIX}{refresh_token}")) .await .map_err(AppError::Redis)?; let user_id = user_id.ok_or(AppError::Unauthorized)?; // Revoke old refresh token (rotation) let _: () = conn .del(format!("{REFRESH_PREFIX}{refresh_token}")) .await .map_err(AppError::Redis)?; // Issue new token pair self.issue_token(&user_id, access_ttl_secs, vec![], Default::default()) .await } // ── Revoke ─────────────────────────────────────────────────────────── /// Revoke a single token by its jti. pub async fn revoke_by_jti(&self, jti: &str, ttl_secs: i64) -> AppResult<()> { let mut conn = self.redis.get_connection(); let _: () = conn .set_ex(format!("{REVOKED_PREFIX}{jti}"), "1", ttl_secs as u64) .await .map_err(AppError::Redis)?; Ok(()) } /// Revoke all tokens for a user (deletes all their refresh tokens). pub async fn revoke_user_tokens(&self, user_id: &str) -> AppResult { let mut conn = self.redis.get_connection(); let pattern = format!("{REFRESH_PREFIX}*"); let keys: Vec = redis::cmd("KEYS") .arg(&pattern) .query_async(&mut conn) .await .map_err(AppError::Redis)?; let mut count = 0u32; for key in keys { let stored_uid: Option = conn.get(&key).await.map_err(AppError::Redis)?; if stored_uid.as_deref() == Some(user_id) { let _: () = conn.del(&key).await.map_err(AppError::Redis)?; count += 1; } } Ok(count) } // ── Verify ─────────────────────────────────────────────────────────── /// Verify a JWT access token. Returns the claims if valid, or a reason string if not. pub async fn verify_token( &self, token: &str, ) -> AppResult> { // 1. Decode header to get kid let header = match decode_header(token) { Ok(h) => h, Err(_) => return Ok(Err("invalid_token".to_string())), }; let kid = match &header.kid { Some(k) => k.clone(), None => return Ok(Err("missing_kid".to_string())), }; // 2. Find signing key by kid let key_info = match self.find_key(&kid).await? { Some(k) => k, None => return Ok(Err("unknown_key".to_string())), }; // 3. Decode + validate JWT let mut validation = Validation::new(Algorithm::HS256); validation.validate_exp = true; validation.set_issuer(&["appks"]); validation.required_spec_claims.clear(); let secret_bytes = B64 .decode(&key_info.key_material) .map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?; let decoding_key = DecodingKey::from_secret(&secret_bytes); let token_data: TokenData = match decode(token, &decoding_key, &validation) { Ok(td) => td, Err(e) => { let reason = match e.kind() { jsonwebtoken::errors::ErrorKind::ExpiredSignature => "expired", jsonwebtoken::errors::ErrorKind::InvalidSignature => "invalid_signature", jsonwebtoken::errors::ErrorKind::InvalidIssuer => "invalid_issuer", _ => "invalid", }; return Ok(Err(reason.to_string())); } }; // 4. Check revocation if self.is_revoked(&token_data.claims.jti).await? { return Ok(Err("revoked".to_string())); } Ok(Ok(token_data.claims)) } // ── Key management ─────────────────────────────────────────────────── /// Return all non-expired signing keys (for GetSigningKeys RPC). pub async fn get_signing_keys(&self) -> AppResult<(Vec, i64)> { let mut conn = self.redis.get_connection(); let key_ids: Vec = redis::cmd("KEYS") .arg(format!("{KEY_PREFIX}*")) .query_async(&mut conn) .await .map_err(AppError::Redis)?; let now = chrono::Utc::now().timestamp(); let mut keys = Vec::new(); for redis_key in key_ids { let json: Option = conn.get(&redis_key).await.map_err(AppError::Redis)?; if let Some(json) = json { if let Ok(info) = serde_json::from_str::(&json) { if info.expires_at > now { keys.push(info); } } } } let next_rotation_at = self .current_key .load() .issued_at + KEY_WINDOW_SECS; Ok((keys, next_rotation_at)) } /// Rotate signing keys if the current key is past its window. pub async fn rotate_if_needed(&self) -> AppResult { let now = chrono::Utc::now().timestamp(); let current = self.current_key.load(); if now < current.issued_at + KEY_WINDOW_SECS { return Ok(false); } // Double-check lock via Redis let lock_key = "core:token:rotation_lock"; let mut conn = self.redis.get_connection(); let acquired: bool = redis::cmd("SET") .arg(lock_key) .arg("1") .arg("NX") .arg("EX") .arg(10) .query_async(&mut conn) .await .map_err(AppError::Redis)?; if !acquired { // Another instance is rotating; reload from Redis self.load_or_create_active_key().await?; return Ok(false); } // Mark old key as inactive let mut old: SigningKeyInfo = (**current).clone(); old.active = false; self.store_key(&old).await?; // Generate new active key let new_key = Self::generate_key(true); self.store_key(&new_key).await?; self.current_key.store(Arc::new(new_key)); let _: () = conn.del(lock_key).await.map_err(AppError::Redis)?; Ok(true) } // ── Internal helpers ───────────────────────────────────────────────── fn generate_key(active: bool) -> SigningKeyInfo { use rand::RngCore; let mut secret = [0u8; 32]; rand::thread_rng().fill_bytes(&mut secret); let now = chrono::Utc::now().timestamp(); SigningKeyInfo { kid: Uuid::now_v7().to_string(), algorithm: "HS256".to_string(), key_material: B64.encode(secret), issued_at: now, expires_at: now + KEY_WINDOW_SECS, active, } } fn placeholder_key() -> SigningKeyInfo { SigningKeyInfo { kid: "placeholder".to_string(), algorithm: "HS256".to_string(), key_material: String::new(), issued_at: 0, expires_at: 0, active: false, } } async fn load_or_create_active_key(&self) -> AppResult<()> { let mut conn = self.redis.get_connection(); // Try loading the active key pointer from Redis let active_kid: Option = conn.get(ACTIVE_KEY).await.map_err(AppError::Redis)?; if let Some(kid) = active_kid { let redis_key = format!("{KEY_PREFIX}{kid}"); let json: Option = conn.get(&redis_key).await.map_err(AppError::Redis)?; if let Some(json) = json { let info: SigningKeyInfo = serde_json::from_str(&json)?; let now = chrono::Utc::now().timestamp(); if info.expires_at > now { self.current_key.store(Arc::new(info)); return Ok(()); } // Expired — fall through to generate new key } } // No valid active key — generate one let new_key = Self::generate_key(true); self.store_key(&new_key).await?; self.current_key.store(Arc::new(new_key)); Ok(()) } async fn store_key(&self, info: &SigningKeyInfo) -> AppResult<()> { let mut conn = self.redis.get_connection(); let redis_key = format!("{KEY_PREFIX}{}", info.kid); let json = serde_json::to_string(info)?; // Key lives in Redis for 2× the window (overlap for verification of older tokens) let _: () = conn .set_ex(&redis_key, &json, (KEY_WINDOW_SECS * 2) as u64) .await .map_err(AppError::Redis)?; if info.active { let _: () = conn .set_ex(ACTIVE_KEY, &info.kid, (KEY_WINDOW_SECS * 2) as u64) .await .map_err(AppError::Redis)?; } Ok(()) } async fn find_key(&self, kid: &str) -> AppResult> { // Fast path: check current active key let current = self.current_key.load(); if current.kid == kid { return Ok(Some((**current).clone())); } // Slow path: look up from Redis let mut conn = self.redis.get_connection(); let redis_key = format!("{KEY_PREFIX}{kid}"); let json: Option = conn.get(&redis_key).await.map_err(AppError::Redis)?; match json { Some(j) => Ok(Some(serde_json::from_str(&j)?)), None => Ok(None), } } fn sign_jwt(&self, claims: &TokenClaims, key: &SigningKeyInfo) -> AppResult { let secret_bytes = B64 .decode(&key.key_material) .map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?; let encoding_key = EncodingKey::from_secret(&secret_bytes); let mut header = Header::new(Algorithm::HS256); header.kid = Some(key.kid.clone()); encode(&header, claims, &encoding_key) .map_err(|e| AppError::InternalServerError(format!("JWT encode error: {e}"))) } async fn create_refresh_token(&self, user_id: &str) -> AppResult { let token = format!("rt_{}", Uuid::now_v7()); let key = format!("{REFRESH_PREFIX}{token}"); let mut conn = self.redis.get_connection(); // Refresh tokens live for 7 days let _: () = conn .set_ex(&key, user_id, 86400 * 7) .await .map_err(AppError::Redis)?; Ok(token) } async fn is_revoked(&self, jti: &str) -> AppResult { let mut conn = self.redis.get_connection(); let exists: bool = conn .exists(format!("{REVOKED_PREFIX}{jti}")) .await .map_err(AppError::Redis)?; Ok(exists) } } // ── Helpers ────────────────────────────────────────────────────────────── fn decode_header(token: &str) -> Result { jsonwebtoken::decode_header(token) }