dbbfb747a4
- Replace InternalAuthService with TokenService using JWT tokens - Add support for token issuance, refresh, verification and revocation - Implement automatic signing key rotation with Redis storage - Add database migration checks for indexes and foreign key constraints - Update gRPC endpoints to use token-based authentication - Remove deprecated API key based authentication system - Add JSON Web Token support with HMAC-SHA256 signing - Implement refresh token handling with automatic rotation - Add token revocation by JTI and user ID - Update build configuration to include core proto files - Migrate database schema to handle token-based authentication - Add comprehensive token validation and verification logic
447 lines
16 KiB
Rust
447 lines
16 KiB
Rust
use std::sync::Arc;
|
||
|
||
use arc_swap::ArcSwap;
|
||
use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
|
||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
|
||
use redis::AsyncCommands;
|
||
use serde::{Deserialize, Serialize};
|
||
use uuid::Uuid;
|
||
|
||
use crate::cache::redis::AppRedis;
|
||
use crate::error::{AppError, AppResult};
|
||
|
||
/// Signing window of a key, in seconds (three hours).
///
/// A key is used to sign new tokens for this long after `issued_at`; rotation is due once the
/// window has elapsed.
const KEY_WINDOW_SECS: i64 = 3 * 60 * 60;

/// Redis key whose value is the `kid` of the signing key currently used for issuance.
const ACTIVE_KEY: &str = "core:token:active_key";

/// Redis key prefix under which each signing key is stored as JSON, suffixed by its `kid`.
const KEY_PREFIX: &str = "core:token:key:";

/// Redis key prefix for refresh tokens; the suffix is the token, the value the user id.
const REFRESH_PREFIX: &str = "core:token:refresh:";

/// Redis key prefix for revocation markers; the suffix is the revoked token's `jti`.
const REVOKED_PREFIX: &str = "core:token:revoked:";
|
||
|
||
// ── Types ────────────────────────────────────────────────────────────────
|
||
|
||
/// A signing key used for JWT issue/verify.
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct SigningKeyInfo {
|
||
pub kid: String,
|
||
pub algorithm: String,
|
||
/// Base64-encoded raw secret (for HS256).
|
||
pub key_material: String,
|
||
pub issued_at: i64,
|
||
pub expires_at: i64,
|
||
pub active: bool,
|
||
}
|
||
|
||
/// JWT claims embedded in every access token.
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct TokenClaims {
|
||
pub sub: String,
|
||
pub iss: String,
|
||
pub iat: i64,
|
||
pub exp: i64,
|
||
pub jti: String,
|
||
#[serde(default, skip_serializing_if = "String::is_empty")]
|
||
pub scope: String,
|
||
#[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
|
||
pub extra: std::collections::HashMap<String, String>,
|
||
}
|
||
|
||
/// Token pair handed back by issuance and by refresh.
///
/// Deliberately not `Debug`: both fields are bearer credentials.
pub struct IssuedTokens {
    /// Signed access token (JWT).
    pub access_token: String,
    /// Opaque refresh token, redeemable once for a new pair.
    pub refresh_token: String,
    /// Expiry of `access_token` in Unix seconds (equal to its `exp` claim).
    pub expires_at: i64,
    /// `kid` of the signing key that signed `access_token`.
    pub key_id: String,
}
|
||
|
||
// ── Service ──────────────────────────────────────────────────────────────
|
||
|
||
#[derive(Clone)]
|
||
pub struct TokenService {
|
||
redis: AppRedis,
|
||
/// Current active signing key, swapped atomically on rotation.
|
||
current_key: Arc<ArcSwap<SigningKeyInfo>>,
|
||
}
|
||
|
||
impl TokenService {
|
||
/// Create a new TokenService.
|
||
/// Loads the active signing key from Redis if one exists, otherwise generates
|
||
/// and stores a fresh key.
|
||
pub async fn new(redis: AppRedis) -> AppResult<Self> {
|
||
let svc = Self {
|
||
redis,
|
||
current_key: Arc::new(ArcSwap::from_pointee(Self::placeholder_key())),
|
||
};
|
||
svc.load_or_create_active_key().await?;
|
||
Ok(svc)
|
||
}
|
||
|
||
// ── Issue ────────────────────────────────────────────────────────────
|
||
|
||
pub async fn issue_token(
|
||
&self,
|
||
user_id: &str,
|
||
ttl_secs: i64,
|
||
scopes: Vec<String>,
|
||
extra: std::collections::HashMap<String, String>,
|
||
) -> AppResult<IssuedTokens> {
|
||
let now = chrono::Utc::now().timestamp();
|
||
let key = self.current_key.load();
|
||
|
||
let claims = TokenClaims {
|
||
sub: user_id.to_string(),
|
||
iss: "appks".to_string(),
|
||
iat: now,
|
||
exp: now + ttl_secs,
|
||
jti: Uuid::now_v7().to_string(),
|
||
scope: scopes.join(" "),
|
||
extra,
|
||
};
|
||
|
||
let access_token = self.sign_jwt(&claims, &key)?;
|
||
let refresh_token = self.create_refresh_token(user_id).await?;
|
||
|
||
Ok(IssuedTokens {
|
||
access_token,
|
||
refresh_token,
|
||
expires_at: claims.exp,
|
||
key_id: key.kid.clone(),
|
||
})
|
||
}
|
||
|
||
// ── Refresh ──────────────────────────────────────────────────────────
|
||
|
||
pub async fn refresh_token(
|
||
&self,
|
||
refresh_token: &str,
|
||
access_ttl_secs: i64,
|
||
) -> AppResult<IssuedTokens> {
|
||
let mut conn = self.redis.get_connection();
|
||
|
||
// Look up user_id from refresh token
|
||
let user_id: Option<String> = conn
|
||
.get(format!("{REFRESH_PREFIX}{refresh_token}"))
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
|
||
let user_id = user_id.ok_or(AppError::Unauthorized)?;
|
||
|
||
// Revoke old refresh token (rotation)
|
||
let _: () = conn
|
||
.del(format!("{REFRESH_PREFIX}{refresh_token}"))
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
|
||
// Issue new token pair
|
||
self.issue_token(&user_id, access_ttl_secs, vec![], Default::default())
|
||
.await
|
||
}
|
||
|
||
// ── Revoke ───────────────────────────────────────────────────────────
|
||
|
||
/// Revoke a single token by its jti.
|
||
pub async fn revoke_by_jti(&self, jti: &str, ttl_secs: i64) -> AppResult<()> {
|
||
let mut conn = self.redis.get_connection();
|
||
let _: () = conn
|
||
.set_ex(format!("{REVOKED_PREFIX}{jti}"), "1", ttl_secs as u64)
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
Ok(())
|
||
}
|
||
|
||
/// Revoke all tokens for a user (deletes all their refresh tokens).
|
||
pub async fn revoke_user_tokens(&self, user_id: &str) -> AppResult<u32> {
|
||
let mut conn = self.redis.get_connection();
|
||
let pattern = format!("{REFRESH_PREFIX}*");
|
||
|
||
let keys: Vec<String> = redis::cmd("KEYS")
|
||
.arg(&pattern)
|
||
.query_async(&mut conn)
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
|
||
let mut count = 0u32;
|
||
for key in keys {
|
||
let stored_uid: Option<String> = conn.get(&key).await.map_err(AppError::Redis)?;
|
||
if stored_uid.as_deref() == Some(user_id) {
|
||
let _: () = conn.del(&key).await.map_err(AppError::Redis)?;
|
||
count += 1;
|
||
}
|
||
}
|
||
Ok(count)
|
||
}
|
||
|
||
// ── Verify ───────────────────────────────────────────────────────────
|
||
|
||
/// Verify a JWT access token. Returns the claims if valid, or a reason string if not.
|
||
pub async fn verify_token(
|
||
&self,
|
||
token: &str,
|
||
) -> AppResult<Result<TokenClaims, String>> {
|
||
// 1. Decode header to get kid
|
||
let header = match decode_header(token) {
|
||
Ok(h) => h,
|
||
Err(_) => return Ok(Err("invalid_token".to_string())),
|
||
};
|
||
|
||
let kid = match &header.kid {
|
||
Some(k) => k.clone(),
|
||
None => return Ok(Err("missing_kid".to_string())),
|
||
};
|
||
|
||
// 2. Find signing key by kid
|
||
let key_info = match self.find_key(&kid).await? {
|
||
Some(k) => k,
|
||
None => return Ok(Err("unknown_key".to_string())),
|
||
};
|
||
|
||
// 3. Decode + validate JWT
|
||
let mut validation = Validation::new(Algorithm::HS256);
|
||
validation.validate_exp = true;
|
||
validation.set_issuer(&["appks"]);
|
||
validation.required_spec_claims.clear();
|
||
|
||
let secret_bytes = B64
|
||
.decode(&key_info.key_material)
|
||
.map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
|
||
let decoding_key = DecodingKey::from_secret(&secret_bytes);
|
||
|
||
let token_data: TokenData<TokenClaims> = match decode(token, &decoding_key, &validation) {
|
||
Ok(td) => td,
|
||
Err(e) => {
|
||
let reason = match e.kind() {
|
||
jsonwebtoken::errors::ErrorKind::ExpiredSignature => "expired",
|
||
jsonwebtoken::errors::ErrorKind::InvalidSignature => "invalid_signature",
|
||
jsonwebtoken::errors::ErrorKind::InvalidIssuer => "invalid_issuer",
|
||
_ => "invalid",
|
||
};
|
||
return Ok(Err(reason.to_string()));
|
||
}
|
||
};
|
||
|
||
// 4. Check revocation
|
||
if self.is_revoked(&token_data.claims.jti).await? {
|
||
return Ok(Err("revoked".to_string()));
|
||
}
|
||
|
||
Ok(Ok(token_data.claims))
|
||
}
|
||
|
||
// ── Key management ───────────────────────────────────────────────────
|
||
|
||
/// Return all non-expired signing keys (for GetSigningKeys RPC).
|
||
pub async fn get_signing_keys(&self) -> AppResult<(Vec<SigningKeyInfo>, i64)> {
|
||
let mut conn = self.redis.get_connection();
|
||
|
||
let key_ids: Vec<String> = redis::cmd("KEYS")
|
||
.arg(format!("{KEY_PREFIX}*"))
|
||
.query_async(&mut conn)
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
|
||
let now = chrono::Utc::now().timestamp();
|
||
let mut keys = Vec::new();
|
||
|
||
for redis_key in key_ids {
|
||
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
|
||
if let Some(json) = json {
|
||
if let Ok(info) = serde_json::from_str::<SigningKeyInfo>(&json) {
|
||
if info.expires_at > now {
|
||
keys.push(info);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
let next_rotation_at = self
|
||
.current_key
|
||
.load()
|
||
.issued_at
|
||
+ KEY_WINDOW_SECS;
|
||
|
||
Ok((keys, next_rotation_at))
|
||
}
|
||
|
||
/// Rotate signing keys if the current key is past its window.
|
||
pub async fn rotate_if_needed(&self) -> AppResult<bool> {
|
||
let now = chrono::Utc::now().timestamp();
|
||
let current = self.current_key.load();
|
||
|
||
if now < current.issued_at + KEY_WINDOW_SECS {
|
||
return Ok(false);
|
||
}
|
||
|
||
// Double-check lock via Redis
|
||
let lock_key = "core:token:rotation_lock";
|
||
let mut conn = self.redis.get_connection();
|
||
let acquired: bool = redis::cmd("SET")
|
||
.arg(lock_key)
|
||
.arg("1")
|
||
.arg("NX")
|
||
.arg("EX")
|
||
.arg(10)
|
||
.query_async(&mut conn)
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
|
||
if !acquired {
|
||
// Another instance is rotating; reload from Redis
|
||
self.load_or_create_active_key().await?;
|
||
return Ok(false);
|
||
}
|
||
|
||
// Mark old key as inactive
|
||
let mut old: SigningKeyInfo = (**current).clone();
|
||
old.active = false;
|
||
self.store_key(&old).await?;
|
||
|
||
// Generate new active key
|
||
let new_key = Self::generate_key(true);
|
||
self.store_key(&new_key).await?;
|
||
self.current_key.store(Arc::new(new_key));
|
||
|
||
let _: () = conn.del(lock_key).await.map_err(AppError::Redis)?;
|
||
|
||
Ok(true)
|
||
}
|
||
|
||
// ── Internal helpers ─────────────────────────────────────────────────
|
||
|
||
fn generate_key(active: bool) -> SigningKeyInfo {
|
||
use rand::RngCore;
|
||
let mut secret = [0u8; 32];
|
||
rand::thread_rng().fill_bytes(&mut secret);
|
||
|
||
let now = chrono::Utc::now().timestamp();
|
||
SigningKeyInfo {
|
||
kid: Uuid::now_v7().to_string(),
|
||
algorithm: "HS256".to_string(),
|
||
key_material: B64.encode(secret),
|
||
issued_at: now,
|
||
expires_at: now + KEY_WINDOW_SECS,
|
||
active,
|
||
}
|
||
}
|
||
|
||
fn placeholder_key() -> SigningKeyInfo {
|
||
SigningKeyInfo {
|
||
kid: "placeholder".to_string(),
|
||
algorithm: "HS256".to_string(),
|
||
key_material: String::new(),
|
||
issued_at: 0,
|
||
expires_at: 0,
|
||
active: false,
|
||
}
|
||
}
|
||
|
||
async fn load_or_create_active_key(&self) -> AppResult<()> {
|
||
let mut conn = self.redis.get_connection();
|
||
|
||
// Try loading the active key pointer from Redis
|
||
let active_kid: Option<String> = conn.get(ACTIVE_KEY).await.map_err(AppError::Redis)?;
|
||
|
||
if let Some(kid) = active_kid {
|
||
let redis_key = format!("{KEY_PREFIX}{kid}");
|
||
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
|
||
if let Some(json) = json {
|
||
let info: SigningKeyInfo = serde_json::from_str(&json)?;
|
||
let now = chrono::Utc::now().timestamp();
|
||
if info.expires_at > now {
|
||
self.current_key.store(Arc::new(info));
|
||
return Ok(());
|
||
}
|
||
// Expired — fall through to generate new key
|
||
}
|
||
}
|
||
|
||
// No valid active key — generate one
|
||
let new_key = Self::generate_key(true);
|
||
self.store_key(&new_key).await?;
|
||
self.current_key.store(Arc::new(new_key));
|
||
Ok(())
|
||
}
|
||
|
||
async fn store_key(&self, info: &SigningKeyInfo) -> AppResult<()> {
|
||
let mut conn = self.redis.get_connection();
|
||
let redis_key = format!("{KEY_PREFIX}{}", info.kid);
|
||
let json = serde_json::to_string(info)?;
|
||
|
||
// Key lives in Redis for 2× the window (overlap for verification of older tokens)
|
||
let _: () = conn
|
||
.set_ex(&redis_key, &json, (KEY_WINDOW_SECS * 2) as u64)
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
|
||
if info.active {
|
||
let _: () = conn
|
||
.set_ex(ACTIVE_KEY, &info.kid, (KEY_WINDOW_SECS * 2) as u64)
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
async fn find_key(&self, kid: &str) -> AppResult<Option<SigningKeyInfo>> {
|
||
// Fast path: check current active key
|
||
let current = self.current_key.load();
|
||
if current.kid == kid {
|
||
return Ok(Some((**current).clone()));
|
||
}
|
||
|
||
// Slow path: look up from Redis
|
||
let mut conn = self.redis.get_connection();
|
||
let redis_key = format!("{KEY_PREFIX}{kid}");
|
||
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
|
||
|
||
match json {
|
||
Some(j) => Ok(Some(serde_json::from_str(&j)?)),
|
||
None => Ok(None),
|
||
}
|
||
}
|
||
|
||
fn sign_jwt(&self, claims: &TokenClaims, key: &SigningKeyInfo) -> AppResult<String> {
|
||
let secret_bytes = B64
|
||
.decode(&key.key_material)
|
||
.map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
|
||
let encoding_key = EncodingKey::from_secret(&secret_bytes);
|
||
|
||
let mut header = Header::new(Algorithm::HS256);
|
||
header.kid = Some(key.kid.clone());
|
||
|
||
encode(&header, claims, &encoding_key)
|
||
.map_err(|e| AppError::InternalServerError(format!("JWT encode error: {e}")))
|
||
}
|
||
|
||
async fn create_refresh_token(&self, user_id: &str) -> AppResult<String> {
|
||
let token = format!("rt_{}", Uuid::now_v7());
|
||
let key = format!("{REFRESH_PREFIX}{token}");
|
||
let mut conn = self.redis.get_connection();
|
||
|
||
// Refresh tokens live for 7 days
|
||
let _: () = conn
|
||
.set_ex(&key, user_id, 86400 * 7)
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
|
||
Ok(token)
|
||
}
|
||
|
||
async fn is_revoked(&self, jti: &str) -> AppResult<bool> {
|
||
let mut conn = self.redis.get_connection();
|
||
let exists: bool = conn
|
||
.exists(format!("{REVOKED_PREFIX}{jti}"))
|
||
.await
|
||
.map_err(AppError::Redis)?;
|
||
Ok(exists)
|
||
}
|
||
}
|
||
|
||
// ── Helpers ──────────────────────────────────────────────────────────────
|
||
|
||
/// Parse only the JOSE header of `token`, without checking the signature, so the caller can
/// read `kid` before choosing a verification key.
fn decode_header(token: &str) -> jsonwebtoken::errors::Result<Header> {
    jsonwebtoken::decode_header(token)
}
|