134 lines
3.8 KiB
Rust
134 lines
3.8 KiB
Rust
mod host;
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use axum::async_trait;
|
|
use axum::extract::FromRequestParts;
|
|
use axum::headers::authorization::Bearer;
|
|
use axum::headers::{Authorization, Host};
|
|
use axum::http::request::Parts;
|
|
use axum::http::StatusCode;
|
|
use axum::response::{IntoResponse, Response};
|
|
use axum::routing::{get, post};
|
|
use axum::{RequestPartsExt, Router, TypedHeader};
|
|
use tokio::sync::RwLock;
|
|
use tracing::debug;
|
|
use uuid::Uuid;
|
|
|
|
use crate::auth::{self, Claims};
|
|
use crate::config::Configuration;
|
|
use crate::machine_id;
|
|
|
|
struct Context {
|
|
config: Arc<Configuration>,
|
|
cache: RwLock<HashMap<String, (Instant, Uuid)>>,
|
|
}
|
|
|
|
type State = Arc<Context>;
|
|
|
|
/// Rejection returned by the [`Claims`] extractor when a request cannot be
/// authenticated.
///
/// It deliberately carries no detail: the extractor logs the concrete reason at
/// debug level, and the client only ever sees `401 Unauthorized`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AuthError;
|
|
|
|
impl IntoResponse for AuthError {
|
|
fn into_response(self) -> Response {
|
|
(StatusCode::UNAUTHORIZED, "Unauthorized").into_response()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl FromRequestParts<Arc<Context>> for Claims {
|
|
type Rejection = AuthError;
|
|
|
|
async fn from_request_parts(
|
|
parts: &mut Parts,
|
|
ctx: &State,
|
|
) -> Result<Self, Self::Rejection> {
|
|
let TypedHeader(Authorization(bearer)) = parts
|
|
.extract::<TypedHeader<Authorization<Bearer>>>()
|
|
.await
|
|
.map_err(|e| {
|
|
debug!("Failed to extract token from HTTP request: {}", e);
|
|
AuthError
|
|
})?;
|
|
let host = parts.extract::<TypedHeader<Host>>().await.map_or_else(
|
|
|_| "localhost".to_owned(),
|
|
|v| v.hostname().to_owned(),
|
|
);
|
|
|
|
let hostname =
|
|
auth::get_token_subject(bearer.token()).map_err(|e| {
|
|
debug!("Could not get token subject: {}", e);
|
|
AuthError
|
|
})?;
|
|
let machine_id =
|
|
get_machine_id(&hostname, ctx).await.ok_or_else(|| {
|
|
debug!("No machine ID found for host {}", hostname);
|
|
AuthError
|
|
})?;
|
|
let claims = auth::validate_token(
|
|
bearer.token(),
|
|
&hostname,
|
|
&machine_id,
|
|
&host,
|
|
)
|
|
.map_err(|e| {
|
|
debug!("Invalid auth token: {}", e);
|
|
AuthError
|
|
})?;
|
|
debug!("Successfully authenticated request from host {}", hostname);
|
|
Ok(claims)
|
|
}
|
|
}
|
|
|
|
pub fn make_app(config: Configuration) -> Router {
|
|
let ctx = Arc::new(Context {
|
|
config: config.into(),
|
|
cache: RwLock::new(Default::default()),
|
|
});
|
|
Router::new()
|
|
.route("/", get(|| async { "UP" }))
|
|
.route("/host/sign", post(host::sign_host_cert))
|
|
.with_state(ctx)
|
|
}
|
|
|
|
async fn get_machine_id(hostname: &str, ctx: &State) -> Option<Uuid> {
|
|
let cache = ctx.cache.read().await;
|
|
if let Some((ts, m)) = cache.get(hostname) {
|
|
if ts.elapsed() < Duration::from_secs(60) {
|
|
debug!("Found cached machine ID for {}", hostname);
|
|
return Some(*m);
|
|
} else {
|
|
debug!("Cached machine ID for {} has expired", hostname);
|
|
}
|
|
}
|
|
drop(cache);
|
|
let machine_id =
|
|
machine_id::get_machine_id(hostname, ctx.config.clone()).await?;
|
|
let mut cache = ctx.cache.write().await;
|
|
debug!("Caching machine ID for {}", hostname);
|
|
cache.insert(hostname.into(), (Instant::now(), machine_id));
|
|
Some(machine_id)
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    use super::*;

    /// `GET /` on a default-configured app answers `200 OK` with the body `UP`.
    #[tokio::test]
    async fn test_up() {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let reply = make_app(Configuration::default())
            .oneshot(request)
            .await
            .unwrap();
        assert_eq!(reply.status(), StatusCode::OK);

        let payload = hyper::body::to_bytes(reply.into_body()).await.unwrap();
        assert_eq!(&payload[..], b"UP");
    }
}
|