821537186e
- Reorganized import statements in adapter tests for better readability - Replaced or_insert_with(Vec::new) with or_default() in test closures - Updated Cargo.lock with new dependency versions and checksums - Added TLS features to tonic dependency configuration - Included sqlx, chrono, and uuid dependencies with specific features - Added jsonwebtoken and arc-swap as project dependencies - Reformatted assertion statements to comply with line length limits - Adjusted base64 import order in engine codec module - Updated protobuf include statement formatting
172 lines
5.6 KiB
Rust
172 lines
5.6 KiB
Rust
//! Signing key store with atomic reads and periodic background refresh.
|
|
//!
|
|
//! Fetches HS256 signing keys from appks via `GetSigningKeys` RPC,
|
|
//! caches them behind `ArcSwap` for lock-free reads, and schedules
|
|
//! re-fetch when `next_rotation_at` is reached.
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use arc_swap::ArcSwap;
|
|
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
|
|
use jsonwebtoken::DecodingKey;
|
|
use tokio::task::JoinHandle;
|
|
use tonic::transport::Channel;
|
|
|
|
use crate::pb::core::GetSigningKeysRequest;
|
|
use crate::pb::core::token_service_client::TokenServiceClient;
|
|
use crate::{ImksError, ImksResult};
|
|
|
|
/// A cached signing key entry with a pre-computed `DecodingKey`.
|
|
struct CachedKey {
|
|
kid: String,
|
|
decoding_key: DecodingKey,
|
|
/// Unix timestamp (seconds) when this key expires.
|
|
expires_at: i64,
|
|
/// Whether this is the current active signing key.
|
|
active: bool,
|
|
}
|
|
|
|
/// Thread-safe store of signing keys with periodic background refresh.
|
|
///
|
|
/// Reads via `get_key()` are lock-free (ArcSwap).
|
|
/// A background task re-fetches keys from appks at each rotation window.
|
|
pub struct SigningKeyStore {
|
|
keys: Arc<ArcSwap<HashMap<String, CachedKey>>>,
|
|
refresh_handle: Option<JoinHandle<()>>,
|
|
}
|
|
|
|
impl SigningKeyStore {
|
|
/// Fetch initial keys from appks and start the background refresh loop.
|
|
pub async fn init(mut client: TokenServiceClient<Channel>) -> ImksResult<Self> {
|
|
let (cached, next_rotation) = fetch_keys(&mut client).await?;
|
|
|
|
let map: HashMap<String, CachedKey> =
|
|
cached.into_iter().map(|k| (k.kid.clone(), k)).collect();
|
|
|
|
let keys = Arc::new(ArcSwap::from_pointee(map));
|
|
|
|
let keys_clone = keys.clone();
|
|
let client_clone = client;
|
|
let refresh_handle = tokio::spawn(async move {
|
|
refresh_loop(client_clone, keys_clone, next_rotation).await;
|
|
});
|
|
|
|
tracing::info!("SigningKeyStore initialized with background refresh");
|
|
|
|
Ok(Self {
|
|
keys,
|
|
refresh_handle: Some(refresh_handle),
|
|
})
|
|
}
|
|
|
|
/// Look up a decoding key by its `kid`. Returns `None` if unknown or expired.
|
|
///
|
|
/// Inactive keys (from a previous rotation window) are still served so they
|
|
/// can validate tokens signed before the rotation. Expired keys (past their
|
|
/// 3h window) are rejected as a local safety net even though the RPC should
|
|
/// not return them.
|
|
pub fn get_key(&self, kid: &str) -> Option<DecodingKey> {
|
|
let map = self.keys.load();
|
|
let cached = map.get(kid)?;
|
|
|
|
debug_assert_eq!(cached.kid, kid, "CachedKey kid must match its HashMap key");
|
|
|
|
let now = chrono::Utc::now().timestamp();
|
|
if cached.expires_at > 0 && now >= cached.expires_at {
|
|
tracing::warn!(
|
|
kid = %cached.kid,
|
|
expires_at = cached.expires_at,
|
|
"Rejecting expired signing key"
|
|
);
|
|
return None;
|
|
}
|
|
|
|
if !cached.active {
|
|
tracing::debug!(
|
|
kid = %cached.kid,
|
|
"Serving inactive signing key (previous rotation window)"
|
|
);
|
|
}
|
|
|
|
Some(cached.decoding_key.clone())
|
|
}
|
|
|
|
/// Stop the background refresh task.
|
|
pub async fn shutdown(mut self) {
|
|
if let Some(handle) = self.refresh_handle.take() {
|
|
handle.abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for SigningKeyStore {
|
|
fn drop(&mut self) {
|
|
if let Some(handle) = self.refresh_handle.take() {
|
|
handle.abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Fetch all active signing keys from appks.
|
|
async fn fetch_keys(client: &mut TokenServiceClient<Channel>) -> ImksResult<(Vec<CachedKey>, i64)> {
|
|
let resp = client
|
|
.get_signing_keys(GetSigningKeysRequest { kid: String::new() })
|
|
.await
|
|
.map_err(ImksError::GrpcStatus)?;
|
|
|
|
let inner = resp.into_inner();
|
|
let mut cached_keys = Vec::new();
|
|
|
|
for key in &inner.keys {
|
|
let secret = BASE64
|
|
.decode(&key.key_material)
|
|
.map_err(|e| ImksError::Auth(format!("Invalid key base64 for kid={}: {e}", key.kid)))?;
|
|
|
|
cached_keys.push(CachedKey {
|
|
kid: key.kid.clone(),
|
|
decoding_key: DecodingKey::from_secret(&secret),
|
|
expires_at: key.expires_at,
|
|
active: key.active,
|
|
});
|
|
}
|
|
|
|
tracing::info!(
|
|
key_count = cached_keys.len(),
|
|
next_rotation = inner.next_rotation_at,
|
|
"Fetched signing keys from appks"
|
|
);
|
|
|
|
Ok((cached_keys, inner.next_rotation_at))
|
|
}
|
|
|
|
/// Background loop: sleep until `next_rotation_at`, re-fetch, swap atomically.
|
|
async fn refresh_loop(
|
|
mut client: TokenServiceClient<Channel>,
|
|
keys: Arc<ArcSwap<HashMap<String, CachedKey>>>,
|
|
mut next_rotation_at: i64,
|
|
) {
|
|
loop {
|
|
let now_secs = chrono::Utc::now().timestamp();
|
|
let sleep_secs = (next_rotation_at - now_secs).max(60);
|
|
|
|
tracing::debug!(sleep_secs, "Key refresh sleeping");
|
|
tokio::time::sleep(Duration::from_secs(sleep_secs as u64)).await;
|
|
|
|
match fetch_keys(&mut client).await {
|
|
Ok((cached, new_rotation)) => {
|
|
let map: HashMap<String, CachedKey> =
|
|
cached.into_iter().map(|k| (k.kid.clone(), k)).collect();
|
|
keys.store(Arc::new(map));
|
|
next_rotation_at = new_rotation;
|
|
tracing::info!("Signing keys refreshed");
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(error = %e, "Failed to refresh signing keys, retrying in 60s");
|
|
next_rotation_at = now_secs + 60;
|
|
}
|
|
}
|
|
}
|
|
}
|