//! Signing key store with atomic reads and periodic background refresh. //! //! Fetches HS256 signing keys from appks via `GetSigningKeys` RPC, //! caches them behind `ArcSwap` for lock-free reads, and schedules //! re-fetch when `next_rotation_at` is reached. use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use arc_swap::ArcSwap; use base64::{Engine, engine::general_purpose::STANDARD as BASE64}; use jsonwebtoken::DecodingKey; use tokio::task::JoinHandle; use tonic::transport::Channel; use crate::pb::core::GetSigningKeysRequest; use crate::pb::core::token_service_client::TokenServiceClient; use crate::{ImksError, ImksResult}; /// A cached signing key entry with a pre-computed `DecodingKey`. struct CachedKey { kid: String, decoding_key: DecodingKey, /// Unix timestamp (seconds) when this key expires. expires_at: i64, /// Whether this is the current active signing key. active: bool, } /// Thread-safe store of signing keys with periodic background refresh. /// /// Reads via `get_key()` are lock-free (ArcSwap). /// A background task re-fetches keys from appks at each rotation window. pub struct SigningKeyStore { keys: Arc>>, refresh_handle: Option>, } impl SigningKeyStore { /// Fetch initial keys from appks and start the background refresh loop. pub async fn init(mut client: TokenServiceClient) -> ImksResult { let (cached, next_rotation) = fetch_keys(&mut client).await?; let map: HashMap = cached.into_iter().map(|k| (k.kid.clone(), k)).collect(); let keys = Arc::new(ArcSwap::from_pointee(map)); let keys_clone = keys.clone(); let client_clone = client; let refresh_handle = tokio::spawn(async move { refresh_loop(client_clone, keys_clone, next_rotation).await; }); tracing::info!("SigningKeyStore initialized with background refresh"); Ok(Self { keys, refresh_handle: Some(refresh_handle), }) } /// Look up a decoding key by its `kid`. Returns `None` if unknown or expired. /// /// Inactive keys (from a previous rotation window) are still served so they /// can validate tokens signed before the rotation. Expired keys (past their /// 3h window) are rejected as a local safety net even though the RPC should /// not return them. pub fn get_key(&self, kid: &str) -> Option { let map = self.keys.load(); let cached = map.get(kid)?; debug_assert_eq!(cached.kid, kid, "CachedKey kid must match its HashMap key"); let now = chrono::Utc::now().timestamp(); if cached.expires_at > 0 && now >= cached.expires_at { tracing::warn!( kid = %cached.kid, expires_at = cached.expires_at, "Rejecting expired signing key" ); return None; } if !cached.active { tracing::debug!( kid = %cached.kid, "Serving inactive signing key (previous rotation window)" ); } Some(cached.decoding_key.clone()) } /// Stop the background refresh task. pub async fn shutdown(mut self) { if let Some(handle) = self.refresh_handle.take() { handle.abort(); } } } impl Drop for SigningKeyStore { fn drop(&mut self) { if let Some(handle) = self.refresh_handle.take() { handle.abort(); } } } /// Fetch all active signing keys from appks. async fn fetch_keys(client: &mut TokenServiceClient) -> ImksResult<(Vec, i64)> { let resp = client .get_signing_keys(GetSigningKeysRequest { kid: String::new() }) .await .map_err(ImksError::GrpcStatus)?; let inner = resp.into_inner(); let mut cached_keys = Vec::new(); for key in &inner.keys { let secret = BASE64 .decode(&key.key_material) .map_err(|e| ImksError::Auth(format!("Invalid key base64 for kid={}: {e}", key.kid)))?; cached_keys.push(CachedKey { kid: key.kid.clone(), decoding_key: DecodingKey::from_secret(&secret), expires_at: key.expires_at, active: key.active, }); } tracing::info!( key_count = cached_keys.len(), next_rotation = inner.next_rotation_at, "Fetched signing keys from appks" ); Ok((cached_keys, inner.next_rotation_at)) } /// Background loop: sleep until `next_rotation_at`, re-fetch, swap atomically. async fn refresh_loop( mut client: TokenServiceClient, keys: Arc>>, mut next_rotation_at: i64, ) { loop { let now_secs = chrono::Utc::now().timestamp(); let sleep_secs = (next_rotation_at - now_secs).max(60); tracing::debug!(sleep_secs, "Key refresh sleeping"); tokio::time::sleep(Duration::from_secs(sleep_secs as u64)).await; match fetch_keys(&mut client).await { Ok((cached, new_rotation)) => { let map: HashMap = cached.into_iter().map(|k| (k.kid.clone(), k)).collect(); keys.store(Arc::new(map)); next_rotation_at = new_rotation; tracing::info!("Signing keys refreshed"); } Err(e) => { tracing::error!(error = %e, "Failed to refresh signing keys, retrying in 60s"); next_rotation_at = now_secs + 60; } } } }