Files
appks/service/internal_auth.rs
zhenyi dbbfb747a4 feat(auth): replace internal auth with JWT token service
- Replace InternalAuthService with TokenService using JWT tokens
- Add support for token issuance, refresh, verification and revocation
- Implement automatic signing key rotation with Redis storage
- Add database migration checks for indexes and foreign key constraints
- Update gRPC endpoints to use token-based authentication
- Remove deprecated API key based authentication system
- Add JSON Web Token support with HMAC-SHA256 signing
- Implement refresh token handling with automatic rotation
- Add token revocation by JTI and user ID
- Update build configuration to include core proto files
- Migrate database schema to handle token-based authentication
- Add comprehensive token validation and verification logic
2026-06-11 15:08:13 +08:00

447 lines
16 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::sync::Arc;
use arc_swap::ArcSwap;
use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::cache::redis::AppRedis;
use crate::error::{AppError, AppResult};
/// Lifetime of one signing key before it is rotated: three hours, in seconds.
const KEY_WINDOW_SECS: i64 = 3 * 60 * 60;
/// Redis key holding the `kid` of the signing key currently used to issue tokens.
const ACTIVE_KEY: &str = "core:token:active_key";
/// Redis key prefix for stored signing keys; the full key is this prefix followed by the `kid`.
const KEY_PREFIX: &str = "core:token:key:";
/// Redis key prefix for refresh tokens; the full key is this prefix followed by the token.
const REFRESH_PREFIX: &str = "core:token:refresh:";
/// Redis key prefix for revocation markers; the full key is this prefix followed by the `jti`.
const REVOKED_PREFIX: &str = "core:token:revoked:";
// ── Types ────────────────────────────────────────────────────────────────
/// A signing key used for JWT issue/verify.
///
/// Persisted in Redis as JSON under `KEY_PREFIX` followed by `kid`.
/// `Debug` is implemented by hand (below) so the secret never reaches logs.
// NOTE(review): `key_material` is a symmetric HS256 secret; anything that can read
// it can also mint tokens. Confirm that whatever exposes these keys (e.g. a
// GetSigningKeys RPC) is only reachable by trusted services.
#[derive(Clone, Serialize, Deserialize)]
pub struct SigningKeyInfo {
    /// Key identifier; written into the JWT header `kid` and used as the Redis key suffix.
    pub kid: String,
    /// JWT algorithm name; keys generated by `TokenService` use `"HS256"`.
    pub algorithm: String,
    /// Base64-encoded raw secret (for HS256).
    pub key_material: String,
    /// Unix timestamp (seconds) when the key was generated.
    pub issued_at: i64,
    /// Unix timestamp (seconds) after which the key stops being listed as a current
    /// signing key and is no longer loaded as the active key.
    pub expires_at: i64,
    /// Whether this is the key currently used to issue new tokens.
    pub active: bool,
}

/// Redacting `Debug`: prints every field except the secret, so `{:?}` in logs,
/// panics or error messages cannot leak the HMAC key.
impl std::fmt::Debug for SigningKeyInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SigningKeyInfo")
            .field("kid", &self.kid)
            .field("algorithm", &self.algorithm)
            .field("key_material", &"<redacted>")
            .field("issued_at", &self.issued_at)
            .field("expires_at", &self.expires_at)
            .field("active", &self.active)
            .finish()
    }
}
/// Payload claims carried by every access token issued by `TokenService`.
///
/// Field order is kept stable because it determines the serialized JSON layout.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TokenClaims {
    /// Subject: the user id the token was issued for.
    pub sub: String,
    /// Issuer of the token.
    pub iss: String,
    /// Issued-at time, Unix seconds.
    pub iat: i64,
    /// Expiry time, Unix seconds.
    pub exp: i64,
    /// Unique token id; revocation markers are keyed on it.
    pub jti: String,
    /// Space-separated scopes; left out of the JSON when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub scope: String,
    /// Arbitrary extra string claims; left out of the JSON when empty.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub extra: std::collections::HashMap<String, String>,
}
/// Outcome of issuing (or refreshing into) a new access/refresh token pair.
pub struct IssuedTokens {
    /// Signed JWT access token.
    pub access_token: String,
    /// Opaque refresh token that can later be exchanged for a new pair.
    pub refresh_token: String,
    /// Access-token expiry, Unix seconds (equal to its `exp` claim).
    pub expires_at: i64,
    /// `kid` of the key that signed `access_token`.
    pub key_id: String,
}
// ── Service ──────────────────────────────────────────────────────────────
/// Issues, refreshes, verifies and revokes JWT access tokens, keeping keys,
/// refresh tokens and revocation markers in Redis.
///
/// Clones share one `ArcSwap` through an `Arc`, so a key rotation observed by
/// one clone is observed by all of them.
#[derive(Clone)]
pub struct TokenService {
    /// Signing key used for new tokens; swapped atomically when keys rotate.
    current_key: Arc<ArcSwap<SigningKeyInfo>>,
    /// Redis handle backing all persistent token state.
    redis: AppRedis,
}
impl TokenService {
    /// Issuer (`iss`) stamped into every access token and required on verification.
    const ISSUER: &'static str = "appks";
    /// Lifetime of a refresh token in Redis: 7 days, in seconds.
    const REFRESH_TTL_SECS: u64 = 7 * 24 * 3600;
    /// Redis key used as a mutex so only one instance rotates the signing key at a time.
    const ROTATION_LOCK_KEY: &'static str = "core:token:rotation_lock";
    /// Expiry of the rotation lock, so a crashed rotator cannot block rotation forever.
    const ROTATION_LOCK_TTL_SECS: u64 = 10;
    /// `COUNT` hint passed to each incremental SCAN call.
    const SCAN_COUNT: u64 = 500;

    /// Create a new TokenService.
    ///
    /// Loads the active signing key from Redis if a non-expired one exists,
    /// otherwise generates and stores a fresh key.
    ///
    /// # Errors
    /// Redis failures, or a stored active key that cannot be deserialized.
    pub async fn new(redis: AppRedis) -> AppResult<Self> {
        let svc = Self {
            redis,
            // Never used for signing: replaced below before `svc` is returned.
            current_key: Arc::new(ArcSwap::from_pointee(Self::placeholder_key())),
        };
        svc.load_or_create_active_key().await?;
        Ok(svc)
    }

    // ── Issue ────────────────────────────────────────────────────────────

    /// Issue a new access token (valid for `ttl_secs`) plus a refresh token for `user_id`.
    ///
    /// `scopes` are joined with spaces into the `scope` claim; `extra` is embedded
    /// verbatim as the `extra` claim.
    pub async fn issue_token(
        &self,
        user_id: &str,
        ttl_secs: i64,
        scopes: Vec<String>,
        extra: std::collections::HashMap<String, String>,
    ) -> AppResult<IssuedTokens> {
        let now = chrono::Utc::now().timestamp();
        // Take an owned snapshot rather than an arc_swap Guard: the key must not
        // change between signing and reporting `key_id`, and guards should not be
        // held across the `.await` below.
        let key = self.current_key.load_full();
        let claims = TokenClaims {
            sub: user_id.to_string(),
            iss: Self::ISSUER.to_string(),
            iat: now,
            exp: now.saturating_add(ttl_secs),
            jti: Uuid::now_v7().to_string(),
            scope: scopes.join(" "),
            extra,
        };
        let access_token = self.sign_jwt(&claims, &key)?;
        let refresh_token = self.create_refresh_token(user_id).await?;
        Ok(IssuedTokens {
            access_token,
            refresh_token,
            expires_at: claims.exp,
            key_id: key.kid.clone(),
        })
    }

    // ── Refresh ──────────────────────────────────────────────────────────

    /// Exchange a refresh token for a new token pair, consuming the old refresh token.
    ///
    /// The new access token carries no scopes and no extra claims, because only
    /// the user id is stored with the refresh token.
    ///
    /// # Errors
    /// `AppError::Unauthorized` if the refresh token is unknown, expired or already used.
    pub async fn refresh_token(
        &self,
        refresh_token: &str,
        access_ttl_secs: i64,
    ) -> AppResult<IssuedTokens> {
        let redis_key = format!("{REFRESH_PREFIX}{refresh_token}");
        let mut conn = self.redis.get_connection();
        // GET + DEL in a single MULTI/EXEC transaction: two concurrent refreshes of
        // the same token cannot both read the user id, so each token is single-use.
        let (user_id, _deleted): (Option<String>, i64) = redis::pipe()
            .atomic()
            .get(&redis_key)
            .del(&redis_key)
            .query_async(&mut conn)
            .await
            .map_err(AppError::Redis)?;
        let user_id = user_id.ok_or(AppError::Unauthorized)?;
        self.issue_token(&user_id, access_ttl_secs, vec![], Default::default())
            .await
    }

    // ── Revoke ───────────────────────────────────────────────────────────

    /// Revoke a single token by its jti, keeping the marker for `ttl_secs`
    /// (ideally the token's remaining lifetime).
    pub async fn revoke_by_jti(&self, jti: &str, ttl_secs: i64) -> AppResult<()> {
        // SETEX rejects a zero expiry, and casting a negative i64 would wrap to a
        // huge u64, so keep the marker for at least one second.
        let ttl = ttl_secs.max(1).unsigned_abs();
        let mut conn = self.redis.get_connection();
        let _: () = conn
            .set_ex(format!("{REVOKED_PREFIX}{jti}"), "1", ttl)
            .await
            .map_err(AppError::Redis)?;
        Ok(())
    }

    /// Delete every refresh token that belongs to `user_id`.
    ///
    /// Access tokens already issued are NOT affected: they stay valid until their
    /// `exp` unless revoked individually with `revoke_by_jti`.
    ///
    /// Returns how many refresh tokens were actually deleted.
    pub async fn revoke_user_tokens(&self, user_id: &str) -> AppResult<u32> {
        let keys = self.scan_keys(&format!("{REFRESH_PREFIX}*")).await?;
        let mut conn = self.redis.get_connection();
        let mut count = 0u32;
        for key in &keys {
            let stored_uid: Option<String> = conn.get(key).await.map_err(AppError::Redis)?;
            if stored_uid.as_deref() == Some(user_id) {
                // DEL yields 0 if the token was consumed concurrently; count real deletions only.
                let deleted: u32 = conn.del(key).await.map_err(AppError::Redis)?;
                count = count.saturating_add(deleted);
            }
        }
        Ok(count)
    }

    // ── Verify ───────────────────────────────────────────────────────────

    /// Verify a JWT access token.
    ///
    /// Returns `Ok(Ok(claims))` for a valid token. Returns `Ok(Err(reason))` for a
    /// rejected one, where `reason` is one of `invalid_token`, `missing_kid`,
    /// `unknown_key`, `expired`, `invalid_signature`, `invalid_issuer`, `invalid`
    /// or `revoked`. The outer `Err` is reserved for infrastructure failures
    /// (Redis, corrupt key material).
    pub async fn verify_token(
        &self,
        token: &str,
    ) -> AppResult<Result<TokenClaims, String>> {
        // 1. The (still unverified) header names the signing key.
        let header = match decode_header(token) {
            Ok(h) => h,
            Err(_) => return Ok(Err("invalid_token".to_string())),
        };
        let kid = match header.kid {
            Some(k) => k,
            None => return Ok(Err("missing_kid".to_string())),
        };
        // 2. Find signing key by kid.
        let key_info = match self.find_key(&kid).await? {
            Some(k) => k,
            None => return Ok(Err("unknown_key".to_string())),
        };
        // 3. Decode + validate the JWT: HS256 only, correct issuer, not expired.
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = true;
        validation.set_issuer(&[Self::ISSUER]);
        validation.set_required_spec_claims(&["exp", "iss", "sub"]);
        let secret_bytes = B64
            .decode(&key_info.key_material)
            .map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
        let decoding_key = DecodingKey::from_secret(&secret_bytes);
        let token_data: TokenData<TokenClaims> = match decode(token, &decoding_key, &validation) {
            Ok(td) => td,
            Err(e) => {
                let reason = match e.kind() {
                    jsonwebtoken::errors::ErrorKind::ExpiredSignature => "expired",
                    jsonwebtoken::errors::ErrorKind::InvalidSignature => "invalid_signature",
                    jsonwebtoken::errors::ErrorKind::InvalidIssuer => "invalid_issuer",
                    _ => "invalid",
                };
                return Ok(Err(reason.to_string()));
            }
        };
        // 4. Check revocation.
        if self.is_revoked(&token_data.claims.jti).await? {
            return Ok(Err("revoked".to_string()));
        }
        Ok(Ok(token_data.claims))
    }

    // ── Key management ───────────────────────────────────────────────────

    /// Return all stored signing keys whose `expires_at` is still in the future
    /// (for GetSigningKeys RPC), plus the time the in-memory active key is due
    /// for rotation.
    // NOTE(review): `find_key` does not check `expires_at`, so `verify_token`
    // still accepts keys that are past `expires_at` but retained in Redis
    // (`store_key` uses 2× the window). Those keys are omitted here; confirm that
    // external verifiers relying on this list are expected to reject such tokens.
    pub async fn get_signing_keys(&self) -> AppResult<(Vec<SigningKeyInfo>, i64)> {
        let redis_keys = self.scan_keys(&format!("{KEY_PREFIX}*")).await?;
        let now = chrono::Utc::now().timestamp();
        let mut conn = self.redis.get_connection();
        let mut keys = Vec::with_capacity(redis_keys.len());
        for redis_key in &redis_keys {
            let json: Option<String> = conn.get(redis_key).await.map_err(AppError::Redis)?;
            // Entries may expire between SCAN and GET; undecodable entries are skipped.
            if let Some(info) = json.and_then(|j| serde_json::from_str::<SigningKeyInfo>(&j).ok()) {
                if info.expires_at > now {
                    keys.push(info);
                }
            }
        }
        let next_rotation_at = self.current_key.load().issued_at + KEY_WINDOW_SECS;
        Ok((keys, next_rotation_at))
    }

    /// Rotate signing keys if the current key is past its window.
    ///
    /// Safe to call from several instances: a Redis lock serializes rotation, and
    /// the lock holder re-reads the published key before rotating. Returns `true`
    /// only when this call generated and published a new key.
    pub async fn rotate_if_needed(&self) -> AppResult<bool> {
        let now = chrono::Utc::now().timestamp();
        if !Self::needs_rotation(&self.current_key.load(), now) {
            return Ok(false);
        }
        // Unique lock value so we only ever release a lock we still own.
        let lock_token = Uuid::now_v7().to_string();
        let mut conn = self.redis.get_connection();
        let acquired: bool = redis::cmd("SET")
            .arg(Self::ROTATION_LOCK_KEY)
            .arg(&lock_token)
            .arg("NX")
            .arg("EX")
            .arg(Self::ROTATION_LOCK_TTL_SECS)
            .query_async(&mut conn)
            .await
            .map_err(AppError::Redis)?;
        if !acquired {
            // Another instance is rotating. Adopt its key once it has been published,
            // but never generate one here: that would race the lock holder.
            if let Some(active) = self.read_active_key().await? {
                if !Self::needs_rotation(&active, now) {
                    self.current_key.store(Arc::new(active));
                }
            }
            return Ok(false);
        }
        let outcome = self.rotate_locked(now).await;
        // Release the lock even if rotation failed; the rotation error wins.
        let released = self.release_rotation_lock(&lock_token).await;
        let rotated = outcome?;
        released?;
        Ok(rotated)
    }

    // ── Internal helpers ─────────────────────────────────────────────────

    /// Whether `key` has reached the end of its rotation window at time `now`.
    fn needs_rotation(key: &SigningKeyInfo, now: i64) -> bool {
        now >= key.issued_at + KEY_WINDOW_SECS
    }

    /// Perform the rotation; must only be called while holding the rotation lock.
    ///
    /// Re-reads the active key first. Another instance may have rotated between
    /// our in-memory check and acquiring the lock; in that case we adopt its key
    /// instead of rotating again.
    async fn rotate_locked(&self, now: i64) -> AppResult<bool> {
        let stored = self.read_active_key().await?;
        if let Some(active) = &stored {
            if !Self::needs_rotation(active, now) {
                self.current_key.store(Arc::new(active.clone()));
                return Ok(false);
            }
        }
        // Retire the outgoing key: the one published in Redis, falling back to ours.
        // Re-storing it also refreshes its TTL, so tokens it signed stay verifiable.
        let mut old = match stored {
            Some(key) => key,
            None => SigningKeyInfo::clone(&self.current_key.load_full()),
        };
        old.active = false;
        self.store_key(&old).await?;
        // Generate new active key.
        let new_key = Self::generate_key(true);
        self.store_key(&new_key).await?;
        self.current_key.store(Arc::new(new_key));
        Ok(true)
    }

    /// Delete the rotation lock only if it still holds `token`.
    ///
    /// Compare-and-delete runs as one Lua script, so it is atomic on the server.
    async fn release_rotation_lock(&self, token: &str) -> AppResult<()> {
        const RELEASE_SCRIPT: &str = "if redis.call('GET', KEYS[1]) == ARGV[1] then \
                                      return redis.call('DEL', KEYS[1]) else return 0 end";
        let mut conn = self.redis.get_connection();
        let _: i64 = redis::cmd("EVAL")
            .arg(RELEASE_SCRIPT)
            .arg(1)
            .arg(Self::ROTATION_LOCK_KEY)
            .arg(token)
            .query_async(&mut conn)
            .await
            .map_err(AppError::Redis)?;
        Ok(())
    }

    /// Collect every key matching `pattern` using incremental SCAN.
    ///
    /// Unlike KEYS, this does not block the Redis server for a full keyspace walk.
    /// SCAN may return a key more than once, so the result is sorted and deduplicated.
    async fn scan_keys(&self, pattern: &str) -> AppResult<Vec<String>> {
        let mut conn = self.redis.get_connection();
        let mut cursor: u64 = 0;
        let mut keys = Vec::new();
        loop {
            let (next, batch): (u64, Vec<String>) = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(pattern)
                .arg("COUNT")
                .arg(Self::SCAN_COUNT)
                .query_async(&mut conn)
                .await
                .map_err(AppError::Redis)?;
            keys.extend(batch);
            if next == 0 {
                break;
            }
            cursor = next;
        }
        keys.sort_unstable();
        keys.dedup();
        Ok(keys)
    }

    /// Generate a fresh HS256 key with a 256-bit random secret, valid for one window.
    fn generate_key(active: bool) -> SigningKeyInfo {
        use rand::RngCore;
        let mut secret = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut secret);
        let now = chrono::Utc::now().timestamp();
        SigningKeyInfo {
            kid: Uuid::now_v7().to_string(),
            algorithm: "HS256".to_string(),
            key_material: B64.encode(secret),
            issued_at: now,
            expires_at: now + KEY_WINDOW_SECS,
            active,
        }
    }

    /// Inert stand-in key (empty secret) used only until `new` loads a real key.
    fn placeholder_key() -> SigningKeyInfo {
        SigningKeyInfo {
            kid: "placeholder".to_string(),
            algorithm: "HS256".to_string(),
            key_material: String::new(),
            issued_at: 0,
            expires_at: 0,
            active: false,
        }
    }

    /// Read the key that `ACTIVE_KEY` points to, whether or not it has expired.
    ///
    /// Returns `None` when the pointer or the key it names is missing.
    async fn read_active_key(&self) -> AppResult<Option<SigningKeyInfo>> {
        let mut conn = self.redis.get_connection();
        let active_kid: Option<String> = conn.get(ACTIVE_KEY).await.map_err(AppError::Redis)?;
        let kid = match active_kid {
            Some(kid) => kid,
            None => return Ok(None),
        };
        let json: Option<String> = conn
            .get(format!("{KEY_PREFIX}{kid}"))
            .await
            .map_err(AppError::Redis)?;
        match json {
            Some(json) => Ok(Some(serde_json::from_str(&json)?)),
            None => Ok(None),
        }
    }

    /// Adopt the active key from Redis if it is still within its window,
    /// otherwise generate, store and adopt a new one.
    async fn load_or_create_active_key(&self) -> AppResult<()> {
        if let Some(info) = self.read_active_key().await? {
            let now = chrono::Utc::now().timestamp();
            if info.expires_at > now {
                self.current_key.store(Arc::new(info));
                return Ok(());
            }
            // Expired — fall through to generate a new key.
        }
        let new_key = Self::generate_key(true);
        self.store_key(&new_key).await?;
        self.current_key.store(Arc::new(new_key));
        Ok(())
    }

    /// Persist `info` under `KEY_PREFIX`+kid; if it is active, also point `ACTIVE_KEY` at it.
    async fn store_key(&self, info: &SigningKeyInfo) -> AppResult<()> {
        let mut conn = self.redis.get_connection();
        let redis_key = format!("{KEY_PREFIX}{}", info.kid);
        let json = serde_json::to_string(info)?;
        // Key lives in Redis for 2× the window (overlap for verification of older tokens).
        let _: () = conn
            .set_ex(&redis_key, &json, (KEY_WINDOW_SECS * 2) as u64)
            .await
            .map_err(AppError::Redis)?;
        if info.active {
            let _: () = conn
                .set_ex(ACTIVE_KEY, &info.kid, (KEY_WINDOW_SECS * 2) as u64)
                .await
                .map_err(AppError::Redis)?;
        }
        Ok(())
    }

    /// Look up a signing key by `kid`: the in-memory active key first, then Redis.
    ///
    /// `kid` comes from an unverified token header, so it is attacker-controlled.
    /// It is only ever used as a suffix of a GET key under `KEY_PREFIX`.
    async fn find_key(&self, kid: &str) -> AppResult<Option<SigningKeyInfo>> {
        // Fast path; the guard is scoped so it is dropped before any `.await`.
        {
            let current = self.current_key.load();
            if current.kid == kid {
                return Ok(Some((**current).clone()));
            }
        }
        // Slow path: look up from Redis.
        let mut conn = self.redis.get_connection();
        let redis_key = format!("{KEY_PREFIX}{kid}");
        let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
        match json {
            Some(j) => Ok(Some(serde_json::from_str(&j)?)),
            None => Ok(None),
        }
    }

    /// Sign `claims` as an HS256 JWT with `key`, putting `key.kid` in the header.
    fn sign_jwt(&self, claims: &TokenClaims, key: &SigningKeyInfo) -> AppResult<String> {
        let secret_bytes = B64
            .decode(&key.key_material)
            .map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
        let encoding_key = EncodingKey::from_secret(&secret_bytes);
        let mut header = Header::new(Algorithm::HS256);
        header.kid = Some(key.kid.clone());
        encode(&header, claims, &encoding_key)
            .map_err(|e| AppError::InternalServerError(format!("JWT encode error: {e}")))
    }

    /// Create and store a refresh token mapping to `user_id` for `REFRESH_TTL_SECS`.
    ///
    /// The token is `rt_` followed by 64 hex chars (256 bits from the thread-local
    /// CSPRNG). A UUIDv7 is time-ordered and, depending on the uuid crate version,
    /// partly counter-based, so it is a weak bearer credential.
    async fn create_refresh_token(&self, user_id: &str) -> AppResult<String> {
        use rand::RngCore;
        use std::fmt::Write as _;
        let mut secret = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut secret);
        let mut token = String::with_capacity(3 + secret.len() * 2);
        token.push_str("rt_");
        for byte in secret {
            // Writing into a String cannot fail.
            let _ = write!(token, "{byte:02x}");
        }
        let key = format!("{REFRESH_PREFIX}{token}");
        let mut conn = self.redis.get_connection();
        let _: () = conn
            .set_ex(&key, user_id, Self::REFRESH_TTL_SECS)
            .await
            .map_err(AppError::Redis)?;
        Ok(token)
    }

    /// Whether a revocation marker exists for `jti`.
    async fn is_revoked(&self, jti: &str) -> AppResult<bool> {
        let mut conn = self.redis.get_connection();
        let exists: bool = conn
            .exists(format!("{REVOKED_PREFIX}{jti}"))
            .await
            .map_err(AppError::Redis)?;
        Ok(exists)
    }
}
// ── Helpers ──────────────────────────────────────────────────────────────
/// Parse the JOSE header of `token` without checking its signature.
///
/// Thin wrapper over `jsonwebtoken::decode_header`. The returned header is
/// untrusted until the token is verified with the key it names.
fn decode_header(token: &str) -> Result<Header, jsonwebtoken::errors::Error> {
    let header = jsonwebtoken::decode_header(token)?;
    Ok(header)
}