server: Move auth logic to host module

In preparation for adding a separate authorization strategy for client
requests, I have moved the implementation of the authorization strategy
for host requests in to the `server::host` module.
master
Dustin 2023-11-16 19:58:26 -06:00
parent 818cfc94c2
commit 839d756a28
3 changed files with 93 additions and 86 deletions

View File

@ -19,7 +19,7 @@ use uuid::Uuid;
/// JWT Token Claims /// JWT Token Claims
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct Claims { pub struct HostClaims {
/// Token subject (machine hostname) /// Token subject (machine hostname)
pub sub: String, pub sub: String,
} }
@ -34,7 +34,7 @@ pub fn get_token_subject(token: &str) -> Result<String> {
v.insecure_disable_signature_validation(); v.insecure_disable_signature_validation();
v.set_required_spec_claims(&["sub"]); v.set_required_spec_claims(&["sub"]);
let k = DecodingKey::from_secret(b""); let k = DecodingKey::from_secret(b"");
let data: TokenData<Claims> = decode(token, &k, &v)?; let data: TokenData<HostClaims> = decode(token, &k, &v)?;
Ok(data.claims.sub) Ok(data.claims.sub)
} }
@ -46,12 +46,12 @@ pub fn get_token_subject(token: &str) -> Result<String> {
/// `service` argument, and is within its validity period (not before/expires). /// `service` argument, and is within its validity period (not before/expires).
/// The token must be signed with HMAC-SHA256 using the host's machine ID as /// The token must be signed with HMAC-SHA256 using the host's machine ID as
/// the secret key. /// the secret key.
pub fn validate_token( pub fn validate_host_token(
token: &str, token: &str,
hostname: &str, hostname: &str,
machine_id: &Uuid, machine_id: &Uuid,
service: &str, service: &str,
) -> Result<Claims> { ) -> Result<HostClaims> {
let mut v = Validation::new(Algorithm::HS256); let mut v = Validation::new(Algorithm::HS256);
v.validate_nbf = true; v.validate_nbf = true;
v.set_issuer(&[hostname]); v.set_issuer(&[hostname]);
@ -66,7 +66,7 @@ pub fn validate_token(
OsRng.fill_bytes(&mut secret); OsRng.fill_bytes(&mut secret);
} }
let k = DecodingKey::from_secret(&secret); let k = DecodingKey::from_secret(&secret);
let data: TokenData<Claims> = decode(token, &k, &v)?; let data: TokenData<HostClaims> = decode(token, &k, &v)?;
Ok(data.claims) Ok(data.claims)
} }
@ -130,7 +130,12 @@ pub(crate) mod test {
let machine_id = uuid!("9afd42e5-4ac3-4530-90c4-191869063ef9"); let machine_id = uuid!("9afd42e5-4ac3-4530-90c4-191869063ef9");
let token = make_token(hostname, machine_id); let token = make_token(hostname, machine_id);
validate_token(&token, hostname, &machine_id, "sshca.example.org") validate_host_token(
.unwrap(); &token,
hostname,
&machine_id,
"sshca.example.org",
)
.unwrap();
} }
} }

View File

@ -1,16 +1,26 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::headers::authorization::Bearer;
use axum::headers::{Authorization, Host};
use axum::http::request::Parts;
use axum::extract::multipart::{Multipart, MultipartError}; use axum::extract::multipart::{Multipart, MultipartError};
use axum::extract::State; use axum::extract::State;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::{RequestPartsExt, TypedHeader};
use serde::Serialize; use serde::Serialize;
use ssh_key::Algorithm; use ssh_key::Algorithm;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::auth::Claims; use crate::auth::{self, HostClaims};
use crate::ca; use crate::ca;
use crate::machine_id;
use super::{AuthError, Context};
#[derive(Serialize)] #[derive(Serialize)]
pub struct SignKeyResponse { pub struct SignKeyResponse {
@ -80,6 +90,71 @@ impl IntoResponse for SignKeyError {
} }
} }
#[async_trait]
impl FromRequestParts<Arc<Context>> for HostClaims {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
ctx: &super::State,
) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| {
debug!("Failed to extract token from HTTP request: {}", e);
AuthError
})?;
let host = parts.extract::<TypedHeader<Host>>().await.map_or_else(
|_| "localhost".to_owned(),
|v| v.hostname().to_owned(),
);
let hostname =
auth::get_token_subject(bearer.token()).map_err(|e| {
debug!("Could not get token subject: {}", e);
AuthError
})?;
let machine_id =
get_machine_id(&hostname, ctx).await.ok_or_else(|| {
debug!("No machine ID found for host {}", hostname);
AuthError
})?;
let claims = auth::validate_host_token(
bearer.token(),
&hostname,
&machine_id,
&host,
)
.map_err(|e| {
debug!("Invalid auth token: {}", e);
AuthError
})?;
debug!("Successfully authenticated request from host {}", hostname);
Ok(claims)
}
}
async fn get_machine_id(hostname: &str, ctx: &super::State) -> Option<Uuid> {
let cache = ctx.cache.read().await;
if let Some((ts, m)) = cache.get(hostname) {
if ts.elapsed() < Duration::from_secs(60) {
debug!("Found cached machine ID for {}", hostname);
return Some(*m);
} else {
debug!("Cached machine ID for {} has expired", hostname);
}
}
drop(cache);
let machine_id =
machine_id::get_machine_id(hostname, ctx.config.clone()).await?;
let mut cache = ctx.cache.write().await;
debug!("Caching machine ID for {}", hostname);
cache.insert(hostname.into(), (Instant::now(), machine_id));
Some(machine_id)
}
#[derive(Default)] #[derive(Default)]
struct SignKeyRequest { struct SignKeyRequest {
hostname: String, hostname: String,
@ -87,7 +162,7 @@ struct SignKeyRequest {
} }
pub(super) async fn sign_host_cert( pub(super) async fn sign_host_cert(
claims: Claims, claims: HostClaims,
State(ctx): State<super::State>, State(ctx): State<super::State>,
mut form: Multipart, mut form: Multipart,
) -> Result<String, SignKeyError> { ) -> Result<String, SignKeyError> {
@ -136,8 +211,7 @@ pub(super) async fn sign_host_cert(
pubkey.algorithm().as_str(), pubkey.algorithm().as_str(),
hostname hostname
); );
let cert = let cert = ca::sign_cert(&hostname, &pubkey, duration, &privkey, &[])?;
ca::sign_cert(&hostname, &pubkey, duration, &privkey, &[])?;
info!( info!(
"Signed {} key for {}", "Signed {} key for {}",
pubkey.algorithm().as_str(), pubkey.algorithm().as_str(),

View File

@ -2,24 +2,16 @@ mod host;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::Instant;
use axum::async_trait;
use axum::extract::FromRequestParts;
use axum::headers::authorization::Bearer;
use axum::headers::{Authorization, Host};
use axum::http::request::Parts;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{RequestPartsExt, Router, TypedHeader}; use axum::Router;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::debug;
use uuid::Uuid; use uuid::Uuid;
use crate::auth::{self, Claims};
use crate::config::Configuration; use crate::config::Configuration;
use crate::machine_id;
struct Context { struct Context {
config: Arc<Configuration>, config: Arc<Configuration>,
@ -36,51 +28,6 @@ impl IntoResponse for AuthError {
} }
} }
#[async_trait]
impl FromRequestParts<Arc<Context>> for Claims {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
ctx: &State,
) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| {
debug!("Failed to extract token from HTTP request: {}", e);
AuthError
})?;
let host = parts.extract::<TypedHeader<Host>>().await.map_or_else(
|_| "localhost".to_owned(),
|v| v.hostname().to_owned(),
);
let hostname =
auth::get_token_subject(bearer.token()).map_err(|e| {
debug!("Could not get token subject: {}", e);
AuthError
})?;
let machine_id =
get_machine_id(&hostname, ctx).await.ok_or_else(|| {
debug!("No machine ID found for host {}", hostname);
AuthError
})?;
let claims = auth::validate_token(
bearer.token(),
&hostname,
&machine_id,
&host,
)
.map_err(|e| {
debug!("Invalid auth token: {}", e);
AuthError
})?;
debug!("Successfully authenticated request from host {}", hostname);
Ok(claims)
}
}
pub fn make_app(config: Configuration) -> Router { pub fn make_app(config: Configuration) -> Router {
let ctx = Arc::new(Context { let ctx = Arc::new(Context {
config: config.into(), config: config.into(),
@ -92,25 +39,6 @@ pub fn make_app(config: Configuration) -> Router {
.with_state(ctx) .with_state(ctx)
} }
async fn get_machine_id(hostname: &str, ctx: &State) -> Option<Uuid> {
let cache = ctx.cache.read().await;
if let Some((ts, m)) = cache.get(hostname) {
if ts.elapsed() < Duration::from_secs(60) {
debug!("Found cached machine ID for {}", hostname);
return Some(*m);
} else {
debug!("Cached machine ID for {} has expired", hostname);
}
}
drop(cache);
let machine_id =
machine_id::get_machine_id(hostname, ctx.config.clone()).await?;
let mut cache = ctx.cache.write().await;
debug!("Caching machine ID for {}", hostname);
cache.insert(hostname.into(), (Instant::now(), machine_id));
Some(machine_id)
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use axum::body::Body; use axum::body::Body;