214 lines
5.5 KiB
Rust
214 lines
5.5 KiB
Rust
//! Application configuration
|
|
use std::io::ErrorKind;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use serde::Deserialize;
|
|
use tracing::debug;
|
|
|
|
/// Error returned by [`load_config`]
|
|
#[derive(Debug)]
|
|
pub enum ConfigError {
|
|
Io(std::io::Error),
|
|
Toml(toml::de::Error),
|
|
}
|
|
|
|
impl From<std::io::Error> for ConfigError {
|
|
fn from(e: std::io::Error) -> Self {
|
|
Self::Io(e)
|
|
}
|
|
}
|
|
|
|
impl From<toml::de::Error> for ConfigError {
|
|
fn from(e: toml::de::Error) -> Self {
|
|
Self::Toml(e)
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for ConfigError {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
match self {
|
|
Self::Io(e) => write!(f, "Could not read config file: {}", e),
|
|
Self::Toml(e) => write!(f, "Could not parse config: {}", e),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::error::Error for ConfigError {
|
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
|
match self {
|
|
Self::Io(e) => Some(e),
|
|
Self::Toml(e) => Some(e),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Host CA Configuration
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct HostCaConfig {
|
|
/// Path to the Host CA private key file
|
|
#[serde(default = "default_host_ca_key")]
|
|
pub private_key_file: PathBuf,
|
|
pub private_key_passphrase_file: Option<PathBuf>,
|
|
|
|
/// Duration of issued host certificates
|
|
#[serde(default = "default_host_cert_duration")]
|
|
pub cert_duration: u64,
|
|
}
|
|
|
|
impl Default for HostCaConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
private_key_file: default_host_ca_key(),
|
|
private_key_passphrase_file: None,
|
|
cert_duration: default_host_cert_duration(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// CA configuration
|
|
#[derive(Debug, Default, Deserialize)]
|
|
pub struct CaConfig {
|
|
/// Host CA configuration
|
|
pub host: HostCaConfig,
|
|
}
|
|
|
|
/// Defines a connection to a libvirt VM host
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct LibvirtConfig {
|
|
/// libvirt Connection URI
|
|
pub uri: String,
|
|
}
|
|
|
|
/// Top-level configuration structure
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct Configuration {
|
|
/// List of libvirt connection options
|
|
#[serde(default)]
|
|
pub libvirt: Vec<LibvirtConfig>,
|
|
/// Path to the machine ID map JSON file
|
|
#[serde(default = "default_machine_ids")]
|
|
pub machine_ids: PathBuf,
|
|
/// CA configuration
|
|
#[serde(default)]
|
|
pub ca: CaConfig,
|
|
}
|
|
|
|
impl Default for Configuration {
|
|
fn default() -> Self {
|
|
Self {
|
|
libvirt: vec![],
|
|
machine_ids: default_machine_ids(),
|
|
ca: Default::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn default_config_path(basename: &str) -> PathBuf {
|
|
dirs::config_dir().map_or(PathBuf::from(basename), |mut p| {
|
|
p.push(env!("CARGO_PKG_NAME"));
|
|
p.push(basename);
|
|
p
|
|
})
|
|
}
|
|
|
|
fn default_machine_ids() -> PathBuf {
|
|
default_config_path("machine_ids.json")
|
|
}
|
|
|
|
fn default_host_ca_key() -> PathBuf {
|
|
default_config_path("host-ca.key")
|
|
}
|
|
|
|
/// Default duration of issued host certificates: `2_592_000`, which is
/// 30 days when interpreted as seconds.
fn default_host_cert_duration() -> u64 {
    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
    const VALIDITY_DAYS: u64 = 30;
    VALIDITY_DAYS * SECONDS_PER_DAY
}
|
|
|
|
/// Load configuration from a TOML file
|
|
///
|
|
/// If `path` is provided, the configuration will be loaded from the
|
|
/// TOML file at that location. If `path` is `None`, the path will be
|
|
/// inferred from the XDG Configuration directory (i.e.
|
|
/// `${XDG_CONFIG_HOME}/sshca/config.toml`).
|
|
///
|
|
/// If the configuration file does not exist, the default values will be
|
|
/// used. If any error is encountered while reading or parsing the
|
|
/// file, a [`ConfigError`] will be returned.
|
|
pub fn load_config<P>(path: Option<P>) -> Result<Configuration, ConfigError>
|
|
where
|
|
P: AsRef<Path>,
|
|
{
|
|
let path = match path {
|
|
Some(p) => PathBuf::from(p.as_ref()),
|
|
None => default_config_path("config.toml"),
|
|
};
|
|
debug!("Loading configuration from {}", path.display());
|
|
match std::fs::read_to_string(path) {
|
|
Ok(s) => Ok(toml::from_str(&s)?),
|
|
Err(ref e) if e.kind() == ErrorKind::NotFound => {
|
|
Ok(Default::default())
|
|
}
|
|
Err(e) => Err(e.into()),
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use serial_test::serial;

    use super::*;

    // Concurrency note: any test that builds default paths goes through
    // `dirs::config_dir()`, which reads `XDG_CONFIG_HOME` / `HOME`. This
    // includes deserializing a `Configuration` with omitted fields. Such
    // tests are therefore `#[serial]`, so none of them reads the environment
    // while another test is changing it with `set_var` / `remove_var`.
    //
    // Platform note: the expected paths follow the XDG layout that
    // `dirs::config_dir()` uses on Linux. Other platforms (e.g. macOS,
    // Windows) resolve the user configuration directory elsewhere, so the
    // exact-path assertions only run on Linux.

    #[cfg(target_os = "linux")]
    #[test]
    #[serial]
    fn test_default_config() {
        std::env::remove_var("XDG_CONFIG_HOME");
        std::env::set_var("HOME", "/home/user");
        let config = Configuration::default();
        assert_eq!(
            config.machine_ids,
            PathBuf::from("/home/user/.config/sshca/machine_ids.json"),
        );
        assert_eq!(
            config.ca.host.private_key_file,
            PathBuf::from("/home/user/.config/sshca/host-ca.key"),
        );
        assert_eq!(config.libvirt.len(), 0);
    }

    #[cfg(target_os = "linux")]
    #[test]
    #[serial]
    fn test_default_config_path() {
        std::env::remove_var("XDG_CONFIG_HOME");
        std::env::set_var("HOME", "/home/user");
        let path = default_config_path("config.toml");
        assert_eq!(
            path,
            PathBuf::from("/home/user/.config/sshca/config.toml"),
        );
    }

    #[cfg(target_os = "linux")]
    #[test]
    #[serial]
    fn test_default_config_path_env() {
        std::env::set_var("XDG_CONFIG_HOME", "/etc");
        let path = default_config_path("config.toml");
        // Undo the override before asserting so it cannot leak into later tests.
        std::env::remove_var("XDG_CONFIG_HOME");
        assert_eq!(path, PathBuf::from("/etc/sshca/config.toml"));
    }

    #[test]
    #[serial]
    fn test_config_toml() {
        let config_toml = r#"
            [[libvirt]]
            uri = "qemu+ssh://vmhost0.example.org/system"

            [[libvirt]]
            uri = "qemu+ssh://vmhost1.example.org/system"
        "#;
        let config: Configuration = toml::from_str(config_toml).unwrap();
        assert_eq!(config.libvirt.len(), 2);
        assert_eq!(
            config.libvirt[0].uri,
            "qemu+ssh://vmhost0.example.org/system"
        );
        assert_eq!(
            config.libvirt[1].uri,
            "qemu+ssh://vmhost1.example.org/system"
        );
    }

    #[test]
    #[serial]
    fn test_host_ca_config_toml() {
        let config_toml = r#"
            [ca.host]
            private_key_file = "/srv/sshca/host-ca.key"
            cert_duration = 3600
        "#;
        let config: Configuration = toml::from_str(config_toml).unwrap();
        assert_eq!(
            config.ca.host.private_key_file,
            PathBuf::from("/srv/sshca/host-ca.key"),
        );
        assert!(config.ca.host.private_key_passphrase_file.is_none());
        assert_eq!(config.ca.host.cert_duration, 3600);
        assert!(config.libvirt.is_empty());
    }

    #[test]
    #[serial]
    fn test_default_host_ca_config() {
        let host = HostCaConfig::default();
        assert!(host.private_key_passphrase_file.is_none());
        assert_eq!(host.cert_duration, 86400 * 30);
    }

    #[test]
    #[serial]
    fn test_load_config_missing_file() {
        // A path inside a directory that does not exist: `read_to_string`
        // reports `NotFound`, which `load_config` maps to the defaults.
        let path = std::env::temp_dir()
            .join("sshca-test-nonexistent-dir")
            .join("config.toml");
        let config = load_config(Some(&path)).unwrap();
        assert!(config.libvirt.is_empty());
    }
}
|