feat(auth): replace internal auth with JWT token service
- Replace InternalAuthService with TokenService using JWT tokens - Add support for token issuance, refresh, verification and revocation - Implement automatic signing key rotation with Redis storage - Add database migration checks for indexes and foreign key constraints - Update gRPC endpoints to use token-based authentication - Remove deprecated API key based authentication system - Add JSON Web Token support with HMAC-SHA256 signing - Implement refresh token handling with automatic rotation - Add token revocation by JTI and user ID - Update build configuration to include core proto files - Migrate database schema to handle token-based authentication - Add comprehensive token validation and verification logic
This commit is contained in:
+401
-52
@@ -1,97 +1,446 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode};
|
||||
use redis::AsyncCommands;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::cache::redis::AppRedis;
|
||||
use crate::error::{AppError, AppResult};
|
||||
|
||||
const API_KEY_PREFIX: &str = "internal:auth:";
|
||||
const DEFAULT_TTL_SECS: u64 = 86400 * 30;
|
||||
/// 3-hour key validity window.
|
||||
const KEY_WINDOW_SECS: i64 = 3 * 3600;
|
||||
/// Redis key for the currently active signing key.
|
||||
const ACTIVE_KEY: &str = "core:token:active_key";
|
||||
/// Redis prefix for all signing keys (by kid).
|
||||
const KEY_PREFIX: &str = "core:token:key:";
|
||||
/// Redis prefix for refresh tokens.
|
||||
const REFRESH_PREFIX: &str = "core:token:refresh:";
|
||||
/// Redis prefix for revoked token IDs (jti).
|
||||
const REVOKED_PREFIX: &str = "core:token:revoked:";
|
||||
|
||||
// ── Types ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A signing key used for JWT issue/verify.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServiceIdentity {
|
||||
pub service_name: String,
|
||||
pub service_id: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub struct SigningKeyInfo {
|
||||
pub kid: String,
|
||||
pub algorithm: String,
|
||||
/// Base64-encoded raw secret (for HS256).
|
||||
pub key_material: String,
|
||||
pub issued_at: i64,
|
||||
pub expires_at: i64,
|
||||
pub active: bool,
|
||||
}
|
||||
|
||||
/// JWT claims embedded in every access token.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenClaims {
|
||||
pub sub: String,
|
||||
pub iss: String,
|
||||
pub iat: i64,
|
||||
pub exp: i64,
|
||||
pub jti: String,
|
||||
#[serde(default, skip_serializing_if = "String::is_empty")]
|
||||
pub scope: String,
|
||||
#[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
|
||||
pub extra: std::collections::HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Result of issuing or refreshing a token pair.
|
||||
pub struct IssuedTokens {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub expires_at: i64,
|
||||
pub key_id: String,
|
||||
}
|
||||
|
||||
// ── Service ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct InternalAuthService {
|
||||
pub struct TokenService {
|
||||
redis: AppRedis,
|
||||
/// Current active signing key, swapped atomically on rotation.
|
||||
current_key: Arc<ArcSwap<SigningKeyInfo>>,
|
||||
}
|
||||
|
||||
impl InternalAuthService {
|
||||
pub fn new(redis: AppRedis) -> Self {
|
||||
Self { redis }
|
||||
impl TokenService {
|
||||
/// Create a new TokenService.
|
||||
/// Loads the active signing key from Redis if one exists, otherwise generates
|
||||
/// and stores a fresh key.
|
||||
pub async fn new(redis: AppRedis) -> AppResult<Self> {
|
||||
let svc = Self {
|
||||
redis,
|
||||
current_key: Arc::new(ArcSwap::from_pointee(Self::placeholder_key())),
|
||||
};
|
||||
svc.load_or_create_active_key().await?;
|
||||
Ok(svc)
|
||||
}
|
||||
|
||||
pub async fn issue_api_key(
|
||||
&self,
|
||||
service_name: &str,
|
||||
scopes: Vec<String>,
|
||||
ttl_secs: Option<u64>,
|
||||
) -> AppResult<(String, ServiceIdentity)> {
|
||||
let ttl = ttl_secs.unwrap_or(DEFAULT_TTL_SECS);
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let expires_at = now + ttl as i64;
|
||||
// ── Issue ────────────────────────────────────────────────────────────
|
||||
|
||||
let identity = ServiceIdentity {
|
||||
service_name: service_name.to_string(),
|
||||
service_id: Uuid::now_v7().to_string(),
|
||||
scopes,
|
||||
issued_at: now,
|
||||
expires_at,
|
||||
pub async fn issue_token(
|
||||
&self,
|
||||
user_id: &str,
|
||||
ttl_secs: i64,
|
||||
scopes: Vec<String>,
|
||||
extra: std::collections::HashMap<String, String>,
|
||||
) -> AppResult<IssuedTokens> {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let key = self.current_key.load();
|
||||
|
||||
let claims = TokenClaims {
|
||||
sub: user_id.to_string(),
|
||||
iss: "appks".to_string(),
|
||||
iat: now,
|
||||
exp: now + ttl_secs,
|
||||
jti: Uuid::now_v7().to_string(),
|
||||
scope: scopes.join(" "),
|
||||
extra,
|
||||
};
|
||||
|
||||
let api_key = format!("im_{}", Uuid::now_v7());
|
||||
let key = format!("{API_KEY_PREFIX}{api_key}");
|
||||
let json = serde_json::to_string(&identity)?;
|
||||
let access_token = self.sign_jwt(&claims, &key)?;
|
||||
let refresh_token = self.create_refresh_token(user_id).await?;
|
||||
|
||||
Ok(IssuedTokens {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at: claims.exp,
|
||||
key_id: key.kid.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
// ── Refresh ──────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn refresh_token(
|
||||
&self,
|
||||
refresh_token: &str,
|
||||
access_ttl_secs: i64,
|
||||
) -> AppResult<IssuedTokens> {
|
||||
let mut conn = self.redis.get_connection();
|
||||
redis::Cmd::new()
|
||||
.arg("SETEX")
|
||||
.arg(&key)
|
||||
.arg(ttl)
|
||||
.arg(&json)
|
||||
.query_async::<()>(&mut conn)
|
||||
|
||||
// Look up user_id from refresh token
|
||||
let user_id: Option<String> = conn
|
||||
.get(format!("{REFRESH_PREFIX}{refresh_token}"))
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
Ok((api_key, identity))
|
||||
let user_id = user_id.ok_or(AppError::Unauthorized)?;
|
||||
|
||||
// Revoke old refresh token (rotation)
|
||||
let _: () = conn
|
||||
.del(format!("{REFRESH_PREFIX}{refresh_token}"))
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
// Issue new token pair
|
||||
self.issue_token(&user_id, access_ttl_secs, vec![], Default::default())
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn verify_api_key(&self, api_key: &str) -> AppResult<Option<ServiceIdentity>> {
|
||||
let key = format!("{API_KEY_PREFIX}{api_key}");
|
||||
let mut conn = self.redis.get_connection();
|
||||
// ── Revoke ───────────────────────────────────────────────────────────
|
||||
|
||||
let json: Option<String> = redis::Cmd::new()
|
||||
.arg("GET")
|
||||
.arg(&key)
|
||||
/// Revoke a single token by its jti.
|
||||
pub async fn revoke_by_jti(&self, jti: &str, ttl_secs: i64) -> AppResult<()> {
|
||||
let mut conn = self.redis.get_connection();
|
||||
let _: () = conn
|
||||
.set_ex(format!("{REVOKED_PREFIX}{jti}"), "1", ttl_secs as u64)
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Revoke all tokens for a user (deletes all their refresh tokens).
|
||||
pub async fn revoke_user_tokens(&self, user_id: &str) -> AppResult<u32> {
|
||||
let mut conn = self.redis.get_connection();
|
||||
let pattern = format!("{REFRESH_PREFIX}*");
|
||||
|
||||
let keys: Vec<String> = redis::cmd("KEYS")
|
||||
.arg(&pattern)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
match json {
|
||||
Some(j) => {
|
||||
let identity: ServiceIdentity = serde_json::from_str(&j)?;
|
||||
Ok(Some(identity))
|
||||
let mut count = 0u32;
|
||||
for key in keys {
|
||||
let stored_uid: Option<String> = conn.get(&key).await.map_err(AppError::Redis)?;
|
||||
if stored_uid.as_deref() == Some(user_id) {
|
||||
let _: () = conn.del(&key).await.map_err(AppError::Redis)?;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
// ── Verify ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Verify a JWT access token. Returns the claims if valid, or a reason string if not.
|
||||
pub async fn verify_token(
|
||||
&self,
|
||||
token: &str,
|
||||
) -> AppResult<Result<TokenClaims, String>> {
|
||||
// 1. Decode header to get kid
|
||||
let header = match decode_header(token) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Ok(Err("invalid_token".to_string())),
|
||||
};
|
||||
|
||||
let kid = match &header.kid {
|
||||
Some(k) => k.clone(),
|
||||
None => return Ok(Err("missing_kid".to_string())),
|
||||
};
|
||||
|
||||
// 2. Find signing key by kid
|
||||
let key_info = match self.find_key(&kid).await? {
|
||||
Some(k) => k,
|
||||
None => return Ok(Err("unknown_key".to_string())),
|
||||
};
|
||||
|
||||
// 3. Decode + validate JWT
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
validation.validate_exp = true;
|
||||
validation.set_issuer(&["appks"]);
|
||||
validation.required_spec_claims.clear();
|
||||
|
||||
let secret_bytes = B64
|
||||
.decode(&key_info.key_material)
|
||||
.map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
|
||||
let decoding_key = DecodingKey::from_secret(&secret_bytes);
|
||||
|
||||
let token_data: TokenData<TokenClaims> = match decode(token, &decoding_key, &validation) {
|
||||
Ok(td) => td,
|
||||
Err(e) => {
|
||||
let reason = match e.kind() {
|
||||
jsonwebtoken::errors::ErrorKind::ExpiredSignature => "expired",
|
||||
jsonwebtoken::errors::ErrorKind::InvalidSignature => "invalid_signature",
|
||||
jsonwebtoken::errors::ErrorKind::InvalidIssuer => "invalid_issuer",
|
||||
_ => "invalid",
|
||||
};
|
||||
return Ok(Err(reason.to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Check revocation
|
||||
if self.is_revoked(&token_data.claims.jti).await? {
|
||||
return Ok(Err("revoked".to_string()));
|
||||
}
|
||||
|
||||
Ok(Ok(token_data.claims))
|
||||
}
|
||||
|
||||
// ── Key management ───────────────────────────────────────────────────
|
||||
|
||||
/// Return all non-expired signing keys (for GetSigningKeys RPC).
|
||||
pub async fn get_signing_keys(&self) -> AppResult<(Vec<SigningKeyInfo>, i64)> {
|
||||
let mut conn = self.redis.get_connection();
|
||||
|
||||
let key_ids: Vec<String> = redis::cmd("KEYS")
|
||||
.arg(format!("{KEY_PREFIX}*"))
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let mut keys = Vec::new();
|
||||
|
||||
for redis_key in key_ids {
|
||||
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
|
||||
if let Some(json) = json {
|
||||
if let Ok(info) = serde_json::from_str::<SigningKeyInfo>(&json) {
|
||||
if info.expires_at > now {
|
||||
keys.push(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let next_rotation_at = self
|
||||
.current_key
|
||||
.load()
|
||||
.issued_at
|
||||
+ KEY_WINDOW_SECS;
|
||||
|
||||
Ok((keys, next_rotation_at))
|
||||
}
|
||||
|
||||
/// Rotate signing keys if the current key is past its window.
|
||||
pub async fn rotate_if_needed(&self) -> AppResult<bool> {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
let current = self.current_key.load();
|
||||
|
||||
if now < current.issued_at + KEY_WINDOW_SECS {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Double-check lock via Redis
|
||||
let lock_key = "core:token:rotation_lock";
|
||||
let mut conn = self.redis.get_connection();
|
||||
let acquired: bool = redis::cmd("SET")
|
||||
.arg(lock_key)
|
||||
.arg("1")
|
||||
.arg("NX")
|
||||
.arg("EX")
|
||||
.arg(10)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
if !acquired {
|
||||
// Another instance is rotating; reload from Redis
|
||||
self.load_or_create_active_key().await?;
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Mark old key as inactive
|
||||
let mut old: SigningKeyInfo = (**current).clone();
|
||||
old.active = false;
|
||||
self.store_key(&old).await?;
|
||||
|
||||
// Generate new active key
|
||||
let new_key = Self::generate_key(true);
|
||||
self.store_key(&new_key).await?;
|
||||
self.current_key.store(Arc::new(new_key));
|
||||
|
||||
let _: () = conn.del(lock_key).await.map_err(AppError::Redis)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
// ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
fn generate_key(active: bool) -> SigningKeyInfo {
|
||||
use rand::RngCore;
|
||||
let mut secret = [0u8; 32];
|
||||
rand::thread_rng().fill_bytes(&mut secret);
|
||||
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
SigningKeyInfo {
|
||||
kid: Uuid::now_v7().to_string(),
|
||||
algorithm: "HS256".to_string(),
|
||||
key_material: B64.encode(secret),
|
||||
issued_at: now,
|
||||
expires_at: now + KEY_WINDOW_SECS,
|
||||
active,
|
||||
}
|
||||
}
|
||||
|
||||
fn placeholder_key() -> SigningKeyInfo {
|
||||
SigningKeyInfo {
|
||||
kid: "placeholder".to_string(),
|
||||
algorithm: "HS256".to_string(),
|
||||
key_material: String::new(),
|
||||
issued_at: 0,
|
||||
expires_at: 0,
|
||||
active: false,
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_or_create_active_key(&self) -> AppResult<()> {
|
||||
let mut conn = self.redis.get_connection();
|
||||
|
||||
// Try loading the active key pointer from Redis
|
||||
let active_kid: Option<String> = conn.get(ACTIVE_KEY).await.map_err(AppError::Redis)?;
|
||||
|
||||
if let Some(kid) = active_kid {
|
||||
let redis_key = format!("{KEY_PREFIX}{kid}");
|
||||
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
|
||||
if let Some(json) = json {
|
||||
let info: SigningKeyInfo = serde_json::from_str(&json)?;
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
if info.expires_at > now {
|
||||
self.current_key.store(Arc::new(info));
|
||||
return Ok(());
|
||||
}
|
||||
// Expired — fall through to generate new key
|
||||
}
|
||||
}
|
||||
|
||||
// No valid active key — generate one
|
||||
let new_key = Self::generate_key(true);
|
||||
self.store_key(&new_key).await?;
|
||||
self.current_key.store(Arc::new(new_key));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn store_key(&self, info: &SigningKeyInfo) -> AppResult<()> {
|
||||
let mut conn = self.redis.get_connection();
|
||||
let redis_key = format!("{KEY_PREFIX}{}", info.kid);
|
||||
let json = serde_json::to_string(info)?;
|
||||
|
||||
// Key lives in Redis for 2× the window (overlap for verification of older tokens)
|
||||
let _: () = conn
|
||||
.set_ex(&redis_key, &json, (KEY_WINDOW_SECS * 2) as u64)
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
if info.active {
|
||||
let _: () = conn
|
||||
.set_ex(ACTIVE_KEY, &info.kid, (KEY_WINDOW_SECS * 2) as u64)
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn find_key(&self, kid: &str) -> AppResult<Option<SigningKeyInfo>> {
|
||||
// Fast path: check current active key
|
||||
let current = self.current_key.load();
|
||||
if current.kid == kid {
|
||||
return Ok(Some((**current).clone()));
|
||||
}
|
||||
|
||||
// Slow path: look up from Redis
|
||||
let mut conn = self.redis.get_connection();
|
||||
let redis_key = format!("{KEY_PREFIX}{kid}");
|
||||
let json: Option<String> = conn.get(&redis_key).await.map_err(AppError::Redis)?;
|
||||
|
||||
match json {
|
||||
Some(j) => Ok(Some(serde_json::from_str(&j)?)),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn revoke_api_key(&self, api_key: &str) -> AppResult<()> {
|
||||
let key = format!("{API_KEY_PREFIX}{api_key}");
|
||||
fn sign_jwt(&self, claims: &TokenClaims, key: &SigningKeyInfo) -> AppResult<String> {
|
||||
let secret_bytes = B64
|
||||
.decode(&key.key_material)
|
||||
.map_err(|e| AppError::InternalServerError(format!("bad key material: {e}")))?;
|
||||
let encoding_key = EncodingKey::from_secret(&secret_bytes);
|
||||
|
||||
let mut header = Header::new(Algorithm::HS256);
|
||||
header.kid = Some(key.kid.clone());
|
||||
|
||||
encode(&header, claims, &encoding_key)
|
||||
.map_err(|e| AppError::InternalServerError(format!("JWT encode error: {e}")))
|
||||
}
|
||||
|
||||
async fn create_refresh_token(&self, user_id: &str) -> AppResult<String> {
|
||||
let token = format!("rt_{}", Uuid::now_v7());
|
||||
let key = format!("{REFRESH_PREFIX}{token}");
|
||||
let mut conn = self.redis.get_connection();
|
||||
|
||||
redis::Cmd::new()
|
||||
.arg("DEL")
|
||||
.arg(&key)
|
||||
.query_async::<()>(&mut conn)
|
||||
// Refresh tokens live for 7 days
|
||||
let _: () = conn
|
||||
.set_ex(&key, user_id, 86400 * 7)
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
|
||||
Ok(())
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
async fn is_revoked(&self, jti: &str) -> AppResult<bool> {
|
||||
let mut conn = self.redis.get_connection();
|
||||
let exists: bool = conn
|
||||
.exists(format!("{REVOKED_PREFIX}{jti}"))
|
||||
.await
|
||||
.map_err(AppError::Redis)?;
|
||||
Ok(exists)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
fn decode_header(token: &str) -> Result<Header, jsonwebtoken::errors::Error> {
|
||||
jsonwebtoken::decode_header(token)
|
||||
}
|
||||
|
||||
+7
-5
@@ -63,7 +63,7 @@ pub struct NotificationService {
|
||||
}
|
||||
|
||||
pub use im::ImService;
|
||||
pub use internal_auth::InternalAuthService;
|
||||
pub use internal_auth::TokenService;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppService {
|
||||
@@ -75,13 +75,13 @@ pub struct AppService {
|
||||
pub pr: PrService,
|
||||
pub notify: NotificationService,
|
||||
pub im: ImService,
|
||||
pub internal_auth: InternalAuthService,
|
||||
pub internal_auth: TokenService,
|
||||
pub ctx: Arc<ServiceContext>,
|
||||
}
|
||||
|
||||
impl AppService {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
pub async fn new(
|
||||
version: String,
|
||||
db: AppDatabase,
|
||||
redis: AppRedis,
|
||||
@@ -91,7 +91,9 @@ impl AppService {
|
||||
registry: Arc<EtcdRegistry>,
|
||||
nats: Arc<NatsQueue>,
|
||||
) -> Self {
|
||||
let internal_auth = InternalAuthService::new(redis.clone());
|
||||
let token_service = TokenService::new(redis.clone())
|
||||
.await
|
||||
.expect("failed to initialize TokenService");
|
||||
|
||||
let ctx = Arc::new(ServiceContext {
|
||||
version,
|
||||
@@ -114,7 +116,7 @@ impl AppService {
|
||||
pr: PrService { ctx: ctx.clone() },
|
||||
notify: NotificationService { ctx: ctx.clone() },
|
||||
im: ImService { ctx: ctx.clone() },
|
||||
internal_auth,
|
||||
internal_auth: token_service,
|
||||
ctx,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user