Add: multi token management
This commit is contained in:
@@ -7,7 +7,7 @@ use axum::Router;
|
||||
use sprimo_protocol::v1::{
|
||||
CommandEnvelope, ErrorResponse, FrontendStateSnapshot, HealthResponse,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -19,7 +19,7 @@ use uuid::Uuid;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ApiConfig {
|
||||
pub bind_addr: SocketAddr,
|
||||
pub auth_token: String,
|
||||
pub auth_tokens: Vec<String>,
|
||||
pub app_version: String,
|
||||
pub app_build: String,
|
||||
pub dedupe_capacity: usize,
|
||||
@@ -29,9 +29,14 @@ pub struct ApiConfig {
|
||||
impl ApiConfig {
|
||||
#[must_use]
|
||||
pub fn default_with_token(auth_token: String) -> Self {
|
||||
Self::default_with_tokens(vec![auth_token])
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn default_with_tokens(auth_tokens: Vec<String>) -> Self {
|
||||
Self {
|
||||
bind_addr: SocketAddr::from(([127, 0, 0, 1], 32_145)),
|
||||
auth_token,
|
||||
auth_tokens,
|
||||
app_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
app_build: "dev".to_string(),
|
||||
dedupe_capacity: 5_000,
|
||||
@@ -43,7 +48,7 @@ impl ApiConfig {
|
||||
#[derive(Debug)]
|
||||
pub struct ApiState {
|
||||
start_at: Instant,
|
||||
auth_token: String,
|
||||
auth_tokens: Arc<RwLock<HashSet<String>>>,
|
||||
app_version: String,
|
||||
app_build: String,
|
||||
dedupe_capacity: usize,
|
||||
@@ -59,10 +64,14 @@ impl ApiState {
|
||||
config: ApiConfig,
|
||||
snapshot: Arc<RwLock<FrontendStateSnapshot>>,
|
||||
command_tx: mpsc::Sender<CommandEnvelope>,
|
||||
auth_tokens: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Self {
|
||||
if let Ok(mut guard) = auth_tokens.write() {
|
||||
*guard = config.auth_tokens.into_iter().collect();
|
||||
}
|
||||
Self {
|
||||
start_at: Instant::now(),
|
||||
auth_token: config.auth_token,
|
||||
auth_tokens,
|
||||
app_version: config.app_version,
|
||||
app_build: config.app_build,
|
||||
dedupe_capacity: config.dedupe_capacity,
|
||||
@@ -188,11 +197,18 @@ fn require_auth(headers: &HeaderMap, state: &ApiState) -> Result<(), ApiError> {
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.ok_or(ApiError::Unauthorized)?;
|
||||
let expected = format!("Bearer {}", state.auth_token);
|
||||
if raw == expected {
|
||||
return Ok(());
|
||||
let Some(token) = raw.strip_prefix("Bearer ") else {
|
||||
return Err(ApiError::Unauthorized);
|
||||
};
|
||||
let guard = state
|
||||
.auth_tokens
|
||||
.read()
|
||||
.map_err(|_| ApiError::Internal("auth token lock poisoned".to_string()))?;
|
||||
if guard.contains(token) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ApiError::Unauthorized)
|
||||
}
|
||||
Err(ApiError::Unauthorized)
|
||||
}
|
||||
|
||||
enum ApiError {
|
||||
@@ -240,6 +256,7 @@ mod tests {
|
||||
use sprimo_protocol::v1::{
|
||||
CapabilityFlags, CommandEnvelope, FrontendCommand, FrontendStateSnapshot,
|
||||
};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tokio::sync::mpsc;
|
||||
use tower::ServiceExt;
|
||||
@@ -255,6 +272,7 @@ mod tests {
|
||||
ApiConfig::default_with_token("token".to_string()),
|
||||
snapshot,
|
||||
tx,
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
)),
|
||||
rx,
|
||||
)
|
||||
@@ -335,6 +353,48 @@ mod tests {
|
||||
assert_eq!(received.id, command.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn command_accepts_with_any_configured_token() {
|
||||
let snapshot =
|
||||
FrontendStateSnapshot::idle(CapabilityFlags::default());
|
||||
let snapshot = Arc::new(RwLock::new(snapshot));
|
||||
let (tx, mut rx) = mpsc::channel(8);
|
||||
let state = Arc::new(ApiState::new(
|
||||
ApiConfig::default_with_tokens(vec![
|
||||
"token-a".to_string(),
|
||||
"token-b".to_string(),
|
||||
]),
|
||||
snapshot,
|
||||
tx,
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
));
|
||||
let app = app_router(state);
|
||||
let command = CommandEnvelope {
|
||||
id: Uuid::new_v4(),
|
||||
ts_ms: 1,
|
||||
command: FrontendCommand::Toast {
|
||||
text: "hi".to_string(),
|
||||
ttl_ms: None,
|
||||
},
|
||||
};
|
||||
let body = serde_json::to_vec(&command).expect("json");
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/command")
|
||||
.header("content-type", "application/json")
|
||||
.header("authorization", "Bearer token-b")
|
||||
.body(Body::from(body))
|
||||
.expect("request"),
|
||||
)
|
||||
.await
|
||||
.expect("response");
|
||||
assert_eq!(response.status(), StatusCode::ACCEPTED);
|
||||
let received = rx.recv().await.expect("forwarded command");
|
||||
assert_eq!(received.id, command.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_json_returns_bad_request() {
|
||||
let (state, _) = build_state();
|
||||
|
||||
Reference in New Issue
Block a user