sshca/server/src/config.rs

214 lines
5.5 KiB
Rust

//! Application configuration
use std::io::ErrorKind;
use std::path::{Path, PathBuf};
use serde::Deserialize;
use tracing::debug;
/// Error returned by [`load_config`]
///
/// Each variant wraps the underlying error. That error is embedded in the
/// `Display` message and also returned from `Error::source`.
#[derive(Debug)]
pub enum ConfigError {
    /// Reading the configuration file failed. `load_config` does not report
    /// a missing file through this variant; it falls back to defaults instead.
    Io(std::io::Error),
    /// The file contents could not be parsed/deserialized into a
    /// [`Configuration`].
    Toml(toml::de::Error),
}
impl From<std::io::Error> for ConfigError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<toml::de::Error> for ConfigError {
fn from(e: toml::de::Error) -> Self {
Self::Toml(e)
}
}
impl std::fmt::Display for ConfigError {
    /// Renders a short context prefix followed by the wrapped error's message.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let (context, inner): (&str, &dyn std::fmt::Display) = match self {
            Self::Io(err) => ("Could not read config file", err),
            Self::Toml(err) => ("Could not parse config", err),
        };
        write!(f, "{}: {}", context, inner)
    }
}
impl std::error::Error for ConfigError {
    /// Exposes the wrapped I/O or TOML error as the cause of this error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Io(err) => err,
            Self::Toml(err) => err,
        };
        Some(cause)
    }
}
/// Host CA Configuration
///
/// Settings for the CA that signs host certificates, i.e. the `[ca.host]`
/// table of the configuration file. Every field may be omitted from the TOML.
/// The two required-looking fields have serde defaults, and the `Option`
/// field deserializes to `None` when absent.
#[derive(Debug, Deserialize)]
pub struct HostCaConfig {
    /// Path to the Host CA private key file
    ///
    /// Defaults to `host-ca.key` in the application's configuration directory.
    #[serde(default = "default_host_ca_key")]
    pub private_key_file: PathBuf,
    /// Optional path to a file holding the passphrase for `private_key_file`.
    /// This is `None` when omitted.
    /// NOTE(review): presumably used to decrypt an encrypted key; confirm
    /// with the key-loading code.
    pub private_key_passphrase_file: Option<PathBuf>,
    /// Duration of issued host certificates
    ///
    /// Defaults to `86400 * 30`, which is 30 days if the value is in seconds.
    /// NOTE(review): the unit is not stated here; confirm with the issuer.
    #[serde(default = "default_host_cert_duration")]
    pub cert_duration: u64,
}
impl Default for HostCaConfig {
fn default() -> Self {
Self {
private_key_file: default_host_ca_key(),
private_key_passphrase_file: None,
cert_duration: default_host_cert_duration(),
}
}
}
/// CA configuration
#[derive(Debug, Default, Deserialize)]
pub struct CaConfig {
/// Host CA configuration
pub host: HostCaConfig,
}
/// Defines a connection to a libvirt VM host
///
/// Each entry of the `[[libvirt]]` array of tables deserializes into one of these.
#[derive(Debug, Deserialize)]
pub struct LibvirtConfig {
    /// libvirt Connection URI, e.g. `qemu+ssh://vmhost0.example.org/system`.
    /// This field has no default, so every `[[libvirt]]` entry must set it.
    pub uri: String,
}
/// Top-level configuration structure
///
/// Deserialized from the TOML file read by [`load_config`]. Every field has a
/// serde default, so an empty file yields the same values as
/// `Configuration::default()`.
#[derive(Debug, Deserialize)]
pub struct Configuration {
    /// List of libvirt connection options (`[[libvirt]]` entries).
    /// This is empty when none are given.
    #[serde(default)]
    pub libvirt: Vec<LibvirtConfig>,
    /// Path to the machine ID map JSON file.
    /// Defaults to `machine_ids.json` in the application's configuration directory.
    #[serde(default = "default_machine_ids")]
    pub machine_ids: PathBuf,
    /// CA configuration (the `[ca]` table).
    #[serde(default)]
    pub ca: CaConfig,
}
impl Default for Configuration {
fn default() -> Self {
Self {
libvirt: vec![],
machine_ids: default_machine_ids(),
ca: Default::default(),
}
}
}
/// Builds the default location of `basename` inside this application's
/// configuration directory.
///
/// The directory is `dirs::config_dir()` joined with the crate name
/// (`CARGO_PKG_NAME`). When no configuration directory can be determined,
/// the bare `basename` is returned, which is relative to the current working
/// directory.
fn default_config_path(basename: &str) -> PathBuf {
    match dirs::config_dir() {
        Some(config_dir) => config_dir.join(env!("CARGO_PKG_NAME")).join(basename),
        None => PathBuf::from(basename),
    }
}
/// Default location of the machine ID map: `machine_ids.json` in the
/// application's configuration directory.
fn default_machine_ids() -> PathBuf {
    let file_name = "machine_ids.json";
    default_config_path(file_name)
}
/// Default location of the Host CA private key: `host-ca.key` in the
/// application's configuration directory.
fn default_host_ca_key() -> PathBuf {
    let file_name = "host-ca.key";
    default_config_path(file_name)
}
/// Default lifetime of issued host certificates.
///
/// This is 30 days, assuming the duration is expressed in seconds
/// (86 400 per day). NOTE(review): confirm the unit with the certificate issuer.
fn default_host_cert_duration() -> u64 {
    const SECONDS_PER_DAY: u64 = 86_400;
    const DAYS: u64 = 30;
    DAYS * SECONDS_PER_DAY
}
/// Load configuration from a TOML file
///
/// If `path` is provided, the configuration will be loaded from the
/// TOML file at that location. If `path` is `None`, the path will be
/// inferred from the XDG Configuration directory (i.e.
/// `${XDG_CONFIG_HOME}/sshca/config.toml`).
///
/// If the configuration file does not exist, the default values will be
/// used. If any error is encountered while reading or parsing the
/// file, a [`ConfigError`] will be returned.
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
where
P: AsRef<Path>,
{
let path = match path {
Some(p) => PathBuf::from(p.as_ref()),
None => default_config_path("config.toml"),
};
debug!("Loading configuration from {}", path.display());
match std::fs::read_to_string(path) {
Ok(s) => Ok(toml::from_str(&s)?),
Err(ref e) if e.kind() == ErrorKind::NotFound => {
Ok(Default::default())
}
Err(e) => Err(e.into()),
}
}
#[cfg(test)]
mod test {
    use super::*;
    use serial_test::serial;

    // These tests set process-wide environment variables (`HOME`,
    // `XDG_CONFIG_HOME`) that determine the default configuration paths.
    // Every test whose result depends on those paths, including ones that
    // merely deserialize a `Configuration` with defaulted fields, is marked
    // `#[serial]` so it never runs concurrently with the env mutations.

    #[test]
    #[serial]
    fn test_default_config() {
        std::env::remove_var("XDG_CONFIG_HOME");
        std::env::set_var("HOME", "/home/user");
        let config = Configuration::default();
        assert_eq!(
            config.machine_ids,
            PathBuf::from("/home/user/.config/sshca/machine_ids.json"),
        );
        assert_eq!(config.libvirt.len(), 0);
        assert_eq!(
            config.ca.host.private_key_file,
            PathBuf::from("/home/user/.config/sshca/host-ca.key"),
        );
        assert!(config.ca.host.private_key_passphrase_file.is_none());
        assert_eq!(config.ca.host.cert_duration, 86400 * 30);
    }

    #[test]
    #[serial]
    fn test_default_config_path() {
        std::env::remove_var("XDG_CONFIG_HOME");
        std::env::set_var("HOME", "/home/user");
        let path = default_config_path("config.toml");
        assert_eq!(
            path,
            PathBuf::from("/home/user/.config/sshca/config.toml"),
        );
    }

    #[test]
    #[serial]
    fn test_default_config_path_env() {
        std::env::set_var("XDG_CONFIG_HOME", "/etc");
        let path = default_config_path("config.toml");
        assert_eq!(path, PathBuf::from("/etc/sshca/config.toml"));
    }

    #[test]
    #[serial]
    fn test_config_toml() {
        let config_toml = r#"
            [[libvirt]]
            uri = "qemu+ssh://vmhost0.example.org/system"
            [[libvirt]]
            uri = "qemu+ssh://vmhost1.example.org/system"
        "#;
        let config: Configuration = toml::from_str(config_toml).unwrap();
        assert_eq!(config.libvirt.len(), 2);
        assert_eq!(
            config.libvirt[0].uri,
            "qemu+ssh://vmhost0.example.org/system"
        );
        assert_eq!(
            config.libvirt[1].uri,
            "qemu+ssh://vmhost1.example.org/system"
        );
    }
}