feat(auth): replace internal auth with JWT token service

- Replace InternalAuthService with TokenService using JWT tokens
- Add support for token issuance, refresh, verification and revocation
- Implement automatic signing key rotation with Redis storage
- Add database migration checks for indexes and foreign key constraints
- Update gRPC endpoints to use token-based authentication
- Remove deprecated API key based authentication system
- Add JSON Web Token support with HMAC-SHA256 signing
- Implement refresh token handling with automatic rotation
- Add token revocation by JTI and user ID
- Update build configuration to include core proto files
- Migrate database schema to handle token-based authentication
- Add comprehensive token validation and verification logic
This commit is contained in:
zhenyi
2026-06-11 15:08:13 +08:00
parent a0bea36041
commit dbbfb747a4
16 changed files with 833 additions and 186 deletions
+401 -52
View File
@@ -1,97 +1,446 @@
use std::sync::Arc;
use arc_swap::ArcSwap;
use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::{AppError, AppResult};
const API_KEY_PREFIX: &str = "internal:auth:";
const DEFAULT_TTL_SECS: u64 = 86400 * 30;
/// 3-hour key validity window.
const KEY_WINDOW_SECS: i64 = 3 * 3600;
/// Redis key for the currently active signing key.
const ACTIVE_KEY: &str = "core:token:active_key";
/// Redis prefix for all signing keys (by kid).
const KEY_PREFIX: &str = "core:token:key:";
/// Redis prefix for refresh tokens.
const REFRESH_PREFIX: &str = "core:token:refresh:";
/// Redis prefix for revoked token IDs (jti).
const REVOKED_PREFIX: &str = "core:token:revoked:";
// ── Types ────────────────────────────────────────────────────────────────
/// A signing key used for JWT issue/verify.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceIdentity {
pub service_name: String,
pub service_id: String,
pub scopes: Vec<String>,
pub struct SigningKeyInfo {
pub kid: String,
pub algorithm: String,
/// Base64-encoded raw secret (for HS256).
pub key_material: String,
pub issued_at: i64,
pub expires_at: i64,
pub active: bool,
}
/// JWT claims embedded in every access token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
pub sub: String,
pub iss: String,
pub iat: i64,
pub exp: i64,
pub jti: String,
#[serde(default, skip_serializing_if = "String::is_empty")]
pub scope: String,
#[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
pub extra: std::collections::HashMap<String, String>,
}
/// Result of issuing or refreshing a token pair.
pub struct IssuedTokens {
pub access_token: String,
pub refresh_token: String,
pub expires_at: i64,
pub key_id: String,
}
// ── Service ──────────────────────────────────────────────────────────────
#[derive(Clone)]
pub struct InternalAuthService {
pub struct TokenService {
redis: AppRedis,
/// Current active signing key, swapped atomically on rotation.
current_key: Arc<ArcSwap<SigningKeyInfo>>,
}
impl InternalAuthService {
pub fn new(redis: AppRedis) -> Self {
Self { redis }
impl TokenService {
/// Create a new TokenService.
/// Loads the active signing key from Redis if one exists, otherwise generates
/// and stores a fresh key.
pub async fn new(redis: AppRedis) -> AppResult<Self> {
let svc = Self {
redis,
current_key: Arc::new(ArcSwap::from_pointee(Self::placeholder_key())),
};
svc.load_or_create_active_key().await?;
Ok(svc)
}
pub async fn issue_api_key(
&self,
service_name: &str,
scopes: Vec<String>,
ttl_secs: Option<u64>,
) -> AppResult<(String, ServiceIdentity)> {
let ttl = ttl_secs.unwrap_or(DEFAULT_TTL_SECS);
let now = chrono::Utc::now().timestamp();
let expires_at = now + ttl as i64;
// ── Issue ────────────────────────────────────────────────────────────
let identity = ServiceIdentity {
service_name: service_name.to_string(),
service_id: Uuid::now_v7().to_string(),
scopes,
issued_at: now,
expires_at,
pub async fn issue_token(
&self,
user_id: &str,
ttl_secs: i64,
scopes: Vec<String>,
extra: std::collections::HashMap<String, String>,
) -> AppResult<IssuedTokens> {
let now = chrono::Utc::now().timestamp();
let key = self.current_key.load();
let claims = TokenClaims {
sub: user_id.to_string(),
iss: "appks".to_string(),
iat: now,
exp: now + ttl_secs,
jti: Uuid::now_v7().to_string(),
scope: scopes.join(" "),
extra,
};
let api_key = format!("im_{}", Uuid::now_v7());
let key = format!("{API_KEY_PREFIX}{api_key}");
let json = serde_json::to_string(&identity)?;
let access_token = self.sign_jwt(&claims, &key)?;
let refresh_token = self.create_refresh_token(user_id).await?;
Ok(IssuedTokens {
access_token,
refresh_token,
expires_at: claims.exp,
key_id: key.kid.clone(),
})
}
// ── Refresh ──────────────────────────────────────────────────────────
pub async fn refresh_token(
&self,
refresh_token: &str,
access_ttl_secs: i64,
) -> AppResult<IssuedTokens> {
let mut conn = self.redis.get_connection();
redis::Cmd::new()
.arg("SETEX")
.arg(&key)
.arg(ttl)
.arg(&json)
.query_async::<()>(&mut conn)
// Look up user_id from refresh token
let user_id: Option<String> = conn
.get(format!("{REFRESH_PREFIX}{refresh_token}"))
.await
.map_err(AppError::Redis)?;
Ok((api_key, identity))
let user_id = user_id.ok_or(AppError::Unauthorized)?;
// Revoke old refresh token (rotation)
let _: () = conn
.del(format!("{REFRESH_PREFIX}{refresh_token}"))
.await
.map_err(AppError::Redis)?;
// Issue new token pair
self.issue_token(&user_id, access_ttl_secs, vec![], Default::default())
.await
}
pub async fn verify_api_key(&self, api_key: &str) -> AppResult<Option<ServiceIdentity>> {
let key = format!("{API_KEY_PREFIX}{api_key}");
let mut conn = self.redis.get_connection();
// ── Revoke ───────────────────────────────────────────────────────────
let json: Option<String> = redis::Cmd::new()
.arg("GET")
.arg(&key)
/// Revoke a single token by its jti.
pub async fn revoke_by_jti(&self, jti: &str, ttl_secs: i64) -> AppResult<()> {
let mut conn = self.redis.get_connection();
let _: () = conn
.set_ex(format!("{REVOKED_PREFIX}{jti}"), "1", ttl_secs as u64)
.await
.map_err(AppError::Redis)?;
Ok(())
}
/// Revoke all tokens for a user (deletes all their refresh tokens).
pub async fn revoke_user_tokens(&self, user_id: &str) -> AppResult<u32> {
let mut conn = self.redis.get_connection();
let pattern = format!("{REFRESH_PREFIX}*");
let keys: Vec<String> = redis::cmd("KEYS")
.arg(&pattern)
.query_async(&mut conn)
.await
.map_err(AppError::Redis)?;
match json {
Some(j) => {
let identity: ServiceIdentity = serde_json::from_str(&j)?;
Ok(Some(identity))
let mut count = 0u32;
for key in keys {
let stored_uid: Option<String> = conn.get(&key).await.map_err(AppError::Redis)?;
if stored_uid.as_deref() == Some(user_id) {
let _: () = conn.del(&key).await.map_err(AppError::Redis)?;
count += 1;
}
}
Ok(count)
}
// ── Verify ───────────────────────────────────────────────────────────
/// Verify a JWT access token. Returns the claims if valid, or a reason string if not.
pub async fn verify_token(
&self,
token: &str,
) -> AppResult<Result<TokenClaims, String>> {
// 1. Decode header to get kid
let header = match decode_header(token) {
Ok(h) => h,
Err(_) => return Ok(Err("invalid_token".to_string())),
};
let kid = match &header.kid {
Some(k) => k.clone(),
None => return Ok(Err("missing_kid".to_string())),
};
// 2. Find signing key by kid
let key_info = match self.find_key(&kid).await? {
Some(k) => k,
None => return Ok(Err("unknown_key".to_string())),
};
// 3. Decode + validate JWT
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = true;
validation.set_issuer(&["appks"]);
validation.required_spec_claims.clear();
let secret_bytes = B64
.decode(&key_info.key_material)
.map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
let decoding_key = DecodingKey::from_secret(&secret_bytes);
let token_data: TokenData<TokenClaims> = match decode(token, &decoding_key, &validation) {
Ok(td) => td,
Err(e) => {
let reason = match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => "expired",
jsonwebtoken::errors::ErrorKind::InvalidSignature => "invalid_signature",
jsonwebtoken::errors::ErrorKind::InvalidIssuer => "invalid_issuer",
_ => "invalid",
};
return Ok(Err(reason.to_string()));
}
};
// 4. Check revocation
if self.is_revoked(&token_data.claims.jti).await? {
return Ok(Err("revoked".to_string()));
}
Ok(Ok(token_data.claims))
}
// ── Key management ───────────────────────────────────────────────────
/// Return all non-expired signing keys (for GetSigningKeys RPC).
pub async fn get_signing_keys(&self) -> AppResult<(Vec<SigningKeyInfo>, i64)> {
let mut conn = self.redis.get_connection();
let key_ids: Vec<String> = redis::cmd("KEYS")
.arg(format!("{KEY_PREFIX}*"))
.query_async(&mut conn)
.await
.map_err(AppError::Redis)?;
let now = chrono::Utc::now().timestamp();
let mut keys = Vec::new();
for redis_key in key_ids {
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
if let Some(json) = json {
if let Ok(info) = serde_json::from_str::<SigningKeyInfo>(&json) {
if info.expires_at > now {
keys.push(info);
}
}
}
}
let next_rotation_at = self
.current_key
.load()
.issued_at
+ KEY_WINDOW_SECS;
Ok((keys, next_rotation_at))
}
/// Rotate signing keys if the current key is past its window.
pub async fn rotate_if_needed(&self) -> AppResult<bool> {
let now = chrono::Utc::now().timestamp();
let current = self.current_key.load();
if now < current.issued_at + KEY_WINDOW_SECS {
return Ok(false);
}
// Double-check lock via Redis
let lock_key = "core:token:rotation_lock";
let mut conn = self.redis.get_connection();
let acquired: bool = redis::cmd("SET")
.arg(lock_key)
.arg("1")
.arg("NX")
.arg("EX")
.arg(10)
.query_async(&mut conn)
.await
.map_err(AppError::Redis)?;
if !acquired {
// Another instance is rotating; reload from Redis
self.load_or_create_active_key().await?;
return Ok(false);
}
// Mark old key as inactive
let mut old: SigningKeyInfo = (**current).clone();
old.active = false;
self.store_key(&old).await?;
// Generate new active key
let new_key = Self::generate_key(true);
self.store_key(&new_key).await?;
self.current_key.store(Arc::new(new_key));
let _: () = conn.del(lock_key).await.map_err(AppError::Redis)?;
Ok(true)
}
// ── Internal helpers ─────────────────────────────────────────────────
fn generate_key(active: bool) -> SigningKeyInfo {
use rand::RngCore;
let mut secret = [0u8; 32];
rand::thread_rng().fill_bytes(&mut secret);
let now = chrono::Utc::now().timestamp();
SigningKeyInfo {
kid: Uuid::now_v7().to_string(),
algorithm: "HS256".to_string(),
key_material: B64.encode(secret),
issued_at: now,
expires_at: now + KEY_WINDOW_SECS,
active,
}
}
fn placeholder_key() -> SigningKeyInfo {
SigningKeyInfo {
kid: "placeholder".to_string(),
algorithm: "HS256".to_string(),
key_material: String::new(),
issued_at: 0,
expires_at: 0,
active: false,
}
}
async fn load_or_create_active_key(&self) -> AppResult<()> {
let mut conn = self.redis.get_connection();
// Try loading the active key pointer from Redis
let active_kid: Option<String> = conn.get(ACTIVE_KEY).await.map_err(AppError::Redis)?;
if let Some(kid) = active_kid {
let redis_key = format!("{KEY_PREFIX}{kid}");
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
if let Some(json) = json {
let info: SigningKeyInfo = serde_json::from_str(&json)?;
let now = chrono::Utc::now().timestamp();
if info.expires_at > now {
self.current_key.store(Arc::new(info));
return Ok(());
}
// Expired — fall through to generate new key
}
}
// No valid active key — generate one
let new_key = Self::generate_key(true);
self.store_key(&new_key).await?;
self.current_key.store(Arc::new(new_key));
Ok(())
}
async fn store_key(&self, info: &SigningKeyInfo) -> AppResult<()> {
let mut conn = self.redis.get_connection();
let redis_key = format!("{KEY_PREFIX}{}", info.kid);
let json = serde_json::to_string(info)?;
// Key lives in Redis for 2× the window (overlap for verification of older tokens)
let _: () = conn
.set_ex(&redis_key, &json, (KEY_WINDOW_SECS * 2) as u64)
.await
.map_err(AppError::Redis)?;
if info.active {
let _: () = conn
.set_ex(ACTIVE_KEY, &info.kid, (KEY_WINDOW_SECS * 2) as u64)
.await
.map_err(AppError::Redis)?;
}
Ok(())
}
async fn find_key(&self, kid: &str) -> AppResult<Option<SigningKeyInfo>> {
// Fast path: check current active key
let current = self.current_key.load();
if current.kid == kid {
return Ok(Some((**current).clone()));
}
// Slow path: look up from Redis
let mut conn = self.redis.get_connection();
let redis_key = format!("{KEY_PREFIX}{kid}");
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
match json {
Some(j) => Ok(Some(serde_json::from_str(&j)?)),
None => Ok(None),
}
}
pub async fn revoke_api_key(&self, api_key: &str) -> AppResult<()> {
let key = format!("{API_KEY_PREFIX}{api_key}");
fn sign_jwt(&self, claims: &TokenClaims, key: &SigningKeyInfo) -> AppResult<String> {
let secret_bytes = B64
.decode(&key.key_material)
.map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
let encoding_key = EncodingKey::from_secret(&secret_bytes);
let mut header = Header::new(Algorithm::HS256);
header.kid = Some(key.kid.clone());
encode(&header, claims, &encoding_key)
.map_err(|e| AppError::InternalServerError(format!("JWT encode error: {e}")))
}
async fn create_refresh_token(&self, user_id: &str) -> AppResult<String> {
let token = format!("rt_{}", Uuid::now_v7());
let key = format!("{REFRESH_PREFIX}{token}");
let mut conn = self.redis.get_connection();
redis::Cmd::new()
.arg("DEL")
.arg(&key)
.query_async::<()>(&mut conn)
// Refresh tokens live for 7 days
let _: () = conn
.set_ex(&key, user_id, 86400 * 7)
.await
.map_err(AppError::Redis)?;
Ok(())
Ok(token)
}
async fn is_revoked(&self, jti: &str) -> AppResult<bool> {
let mut conn = self.redis.get_connection();
let exists: bool = conn
.exists(format!("{REVOKED_PREFIX}{jti}"))
.await
.map_err(AppError::Redis)?;
Ok(exists)
}
}
// ── Helpers ──────────────────────────────────────────────────────────────
fn decode_header(token: &str) -> Result<Header, jsonwebtoken::errors::Error> {
jsonwebtoken::decode_header(token)
}