mod host; use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use axum::async_trait; use axum::extract::FromRequestParts; use axum::headers::authorization::Bearer; use axum::headers::{Authorization, Host}; use axum::http::request::Parts; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{RequestPartsExt, Router, TypedHeader}; use tokio::sync::RwLock; use tracing::debug; use uuid::Uuid; use crate::auth::{self, Claims}; use crate::config::Configuration; use crate::machine_id; struct Context { config: Arc, cache: RwLock>, } type State = Arc; pub struct AuthError; impl IntoResponse for AuthError { fn into_response(self) -> Response { (StatusCode::UNAUTHORIZED, "Unauthorized").into_response() } } #[async_trait] impl FromRequestParts> for Claims { type Rejection = AuthError; async fn from_request_parts( parts: &mut Parts, ctx: &State, ) -> Result { let TypedHeader(Authorization(bearer)) = parts .extract::>>() .await .map_err(|e| { debug!("Failed to extract token from HTTP request: {}", e); AuthError })?; let host = parts.extract::>().await.map_or_else( |_| "localhost".to_owned(), |v| v.hostname().to_owned(), ); let hostname = auth::get_token_subject(bearer.token()).map_err(|e| { debug!("Could not get token subject: {}", e); AuthError })?; let machine_id = get_machine_id(&hostname, ctx).await.ok_or_else(|| { debug!("No machine ID found for host {}", hostname); AuthError })?; let claims = auth::validate_token( bearer.token(), &hostname, &machine_id, &host, ) .map_err(|e| { debug!("Invalid auth token: {}", e); AuthError })?; debug!("Successfully authenticated request from host {}", hostname); Ok(claims) } } pub fn make_app(config: Configuration) -> Router { let ctx = Arc::new(Context { config: config.into(), cache: RwLock::new(Default::default()), }); Router::new() .route("/", get(|| async { "UP" })) .route("/host/sign", post(host::sign_host_cert)) .with_state(ctx) } async fn get_machine_id(hostname: &str, ctx: &State) -> Option { let cache = ctx.cache.read().await; if let Some((ts, m)) = cache.get(hostname) { if ts.elapsed() < Duration::from_secs(60) { debug!("Found cached machine ID for {}", hostname); return Some(*m); } else { debug!("Cached machine ID for {} has expired", hostname); } } drop(cache); let machine_id = machine_id::get_machine_id(hostname, ctx.config.clone()).await?; let mut cache = ctx.cache.write().await; debug!("Caching machine ID for {}", hostname); cache.insert(hostname.into(), (Instant::now(), machine_id)); Some(machine_id) } #[cfg(test)] mod test { use axum::body::Body; use axum::http::Request; use tower::ServiceExt; use super::*; #[tokio::test] async fn test_up() { let app = make_app(Configuration::default()); let response = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); assert_eq!(&body[..], b"UP"); } }