Add: multi token management

This commit is contained in:
DaZuo0122
2026-02-15 12:20:00 +08:00
parent f20ed1fd9d
commit 832fbda04d
8 changed files with 713 additions and 16 deletions

View File

@@ -7,7 +7,7 @@ use axum::Router;
use sprimo_protocol::v1::{
CommandEnvelope, ErrorResponse, FrontendStateSnapshot, HealthResponse,
};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
@@ -19,7 +19,7 @@ use uuid::Uuid;
#[derive(Debug, Clone)]
pub struct ApiConfig {
pub bind_addr: SocketAddr,
pub auth_token: String,
pub auth_tokens: Vec<String>,
pub app_version: String,
pub app_build: String,
pub dedupe_capacity: usize,
@@ -29,9 +29,14 @@ pub struct ApiConfig {
impl ApiConfig {
#[must_use]
pub fn default_with_token(auth_token: String) -> Self {
Self::default_with_tokens(vec![auth_token])
}
#[must_use]
pub fn default_with_tokens(auth_tokens: Vec<String>) -> Self {
Self {
bind_addr: SocketAddr::from(([127, 0, 0, 1], 32_145)),
auth_token,
auth_tokens,
app_version: env!("CARGO_PKG_VERSION").to_string(),
app_build: "dev".to_string(),
dedupe_capacity: 5_000,
@@ -43,7 +48,7 @@ impl ApiConfig {
#[derive(Debug)]
pub struct ApiState {
start_at: Instant,
auth_token: String,
auth_tokens: Arc<RwLock<HashSet<String>>>,
app_version: String,
app_build: String,
dedupe_capacity: usize,
@@ -59,10 +64,14 @@ impl ApiState {
config: ApiConfig,
snapshot: Arc<RwLock<FrontendStateSnapshot>>,
command_tx: mpsc::Sender<CommandEnvelope>,
auth_tokens: Arc<RwLock<HashSet<String>>>,
) -> Self {
if let Ok(mut guard) = auth_tokens.write() {
*guard = config.auth_tokens.into_iter().collect();
}
Self {
start_at: Instant::now(),
auth_token: config.auth_token,
auth_tokens,
app_version: config.app_version,
app_build: config.app_build,
dedupe_capacity: config.dedupe_capacity,
@@ -188,11 +197,18 @@ fn require_auth(headers: &HeaderMap, state: &ApiState) -> Result<(), ApiError> {
.get(header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.ok_or(ApiError::Unauthorized)?;
let expected = format!("Bearer {}", state.auth_token);
if raw == expected {
return Ok(());
let Some(token) = raw.strip_prefix("Bearer ") else {
return Err(ApiError::Unauthorized);
};
let guard = state
.auth_tokens
.read()
.map_err(|_| ApiError::Internal("auth token lock poisoned".to_string()))?;
if guard.contains(token) {
Ok(())
} else {
Err(ApiError::Unauthorized)
}
Err(ApiError::Unauthorized)
}
enum ApiError {
@@ -240,6 +256,7 @@ mod tests {
use sprimo_protocol::v1::{
CapabilityFlags, CommandEnvelope, FrontendCommand, FrontendStateSnapshot,
};
use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc;
use tower::ServiceExt;
@@ -255,6 +272,7 @@ mod tests {
ApiConfig::default_with_token("token".to_string()),
snapshot,
tx,
Arc::new(RwLock::new(HashSet::new())),
)),
rx,
)
@@ -335,6 +353,48 @@ mod tests {
assert_eq!(received.id, command.id);
}
#[tokio::test]
async fn command_accepts_with_any_configured_token() {
let snapshot =
FrontendStateSnapshot::idle(CapabilityFlags::default());
let snapshot = Arc::new(RwLock::new(snapshot));
let (tx, mut rx) = mpsc::channel(8);
let state = Arc::new(ApiState::new(
ApiConfig::default_with_tokens(vec![
"token-a".to_string(),
"token-b".to_string(),
]),
snapshot,
tx,
Arc::new(RwLock::new(HashSet::new())),
));
let app = app_router(state);
let command = CommandEnvelope {
id: Uuid::new_v4(),
ts_ms: 1,
command: FrontendCommand::Toast {
text: "hi".to_string(),
ttl_ms: None,
},
};
let body = serde_json::to_vec(&command).expect("json");
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/command")
.header("content-type", "application/json")
.header("authorization", "Bearer token-b")
.body(Body::from(body))
.expect("request"),
)
.await
.expect("response");
assert_eq!(response.status(), StatusCode::ACCEPTED);
let received = rx.recv().await.expect("forwarded command");
assert_eq!(received.id, command.id);
}
#[tokio::test]
async fn malformed_json_returns_bad_request() {
let (state, _) = build_state();