//! Application configuration use std::io::ErrorKind; use std::path::{Path, PathBuf}; use serde::Deserialize; use tracing::debug; /// Error returned by [`load_config`] #[derive(Debug)] pub enum ConfigError { Io(std::io::Error), Toml(toml::de::Error), } impl From for ConfigError { fn from(e: std::io::Error) -> Self { Self::Io(e) } } impl From for ConfigError { fn from(e: toml::de::Error) -> Self { Self::Toml(e) } } impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::Io(e) => write!(f, "Could not read config file: {}", e), Self::Toml(e) => write!(f, "Could not parse config: {}", e), } } } impl std::error::Error for ConfigError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Self::Io(e) => Some(e), Self::Toml(e) => Some(e), } } } /// Host CA Configuration #[derive(Debug, Deserialize)] pub struct HostCaConfig { /// Path to the Host CA private key file #[serde(default = "default_host_ca_key")] pub private_key_file: PathBuf, pub private_key_passphrase_file: Option, /// Duration of issued host certificates #[serde(default = "default_host_cert_duration")] pub cert_duration: u64, } impl Default for HostCaConfig { fn default() -> Self { Self { private_key_file: default_host_ca_key(), private_key_passphrase_file: None, cert_duration: default_host_cert_duration(), } } } /// CA configuration #[derive(Debug, Default, Deserialize)] pub struct CaConfig { /// Host CA configuration pub host: HostCaConfig, } /// Defines a connection to a libvirt VM host #[derive(Debug, Deserialize)] pub struct LibvirtConfig { /// libvirt Connection URI pub uri: String, } /// Top-level configuration structure #[derive(Debug, Deserialize)] pub struct Configuration { /// List of libvirt connection options #[serde(default)] pub libvirt: Vec, /// Path to the machine ID map JSON file #[serde(default = "default_machine_ids")] pub machine_ids: PathBuf, /// CA configuration #[serde(default)] pub ca: CaConfig, } impl Default for Configuration { fn default() -> Self { Self { libvirt: vec![], machine_ids: default_machine_ids(), ca: Default::default(), } } } fn default_config_path(basename: &str) -> PathBuf { dirs::config_dir().map_or(PathBuf::from(basename), |mut p| { p.push(env!("CARGO_PKG_NAME")); p.push(basename); p }) } fn default_machine_ids() -> PathBuf { default_config_path("machine_ids.json") } fn default_host_ca_key() -> PathBuf { default_config_path("host-ca.key") } fn default_host_cert_duration() -> u64 { 86400 * 30 } /// Load configuration from a TOML file /// /// If `path` is provided, the configuration will be loaded from the /// TOML file at that location. If `path` is `None`, the path will be /// inferred from the XDG Configuration directory (i.e. /// `${XDG_CONFIG_HOME}/sshca/config.toml`). /// /// If the configuration file does not exist, the default values will be /// used. If any error is encountered while reading or parsing the /// file, a [`ConfigError`] will be returned. pub fn load_config

(path: Option

) -> Result where P: AsRef, { let path = match path { Some(p) => PathBuf::from(p.as_ref()), None => default_config_path("config.toml"), }; debug!("Loading configuration from {}", path.display()); match std::fs::read_to_string(path) { Ok(s) => Ok(toml::from_str(&s)?), Err(ref e) if e.kind() == ErrorKind::NotFound => { Ok(Default::default()) } Err(e) => Err(e.into()), } } #[cfg(test)] mod test { use serial_test::serial; use super::*; #[test] #[serial] fn test_default_config() { std::env::remove_var("XDG_CONFIG_HOME"); std::env::set_var("HOME", "/home/user"); let config = Configuration::default(); assert_eq!( config.machine_ids, PathBuf::from("/home/user/.config/sshca/machine_ids.json"), ); assert_eq!(config.libvirt.len(), 0); } #[test] #[serial] fn test_default_config_path() { std::env::remove_var("XDG_CONFIG_HOME"); std::env::set_var("HOME", "/home/user"); let path = default_config_path("config.toml"); assert_eq!( path, PathBuf::from("/home/user/.config/sshca/config.toml"), ); } #[test] #[serial] fn test_default_config_path_env() { std::env::set_var("XDG_CONFIG_HOME", "/etc"); let path = default_config_path("config.toml"); assert_eq!(path, PathBuf::from("/etc/sshca/config.toml")); } #[test] fn test_config_toml() { let config_toml = r#" [[libvirt]] uri = "qemu+ssh://vmhost0.example.org/system" [[libvirt]] uri = "qemu+ssh://vmhost1.example.org/system" "#; let config: Configuration = toml::from_str(config_toml).unwrap(); assert_eq!(config.libvirt.len(), 2); assert_eq!( config.libvirt[0].uri, "qemu+ssh://vmhost0.example.org/system" ); assert_eq!( config.libvirt[1].uri, "qemu+ssh://vmhost1.example.org/system" ); } }