diff --git a/.gitignore b/.gitignore index ea8c4bf..fabfb87 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +/config.toml diff --git a/Cargo.lock b/Cargo.lock index 6de7c88..097d7ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,6 +371,7 @@ dependencies = [ "serde_json", "tokio", "tokio-stream", + "toml", "tracing", "tracing-subscriber", ] @@ -739,6 +740,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1333c76748e868a4d9d1017b5ab53171dfd095f70c712fdb4653a406547f598f" +dependencies = [ + "serde", +] + [[package]] name = "tracing" version = "0.1.37" diff --git a/Cargo.toml b/Cargo.toml index cb4ac25..d6e7262 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,5 +12,6 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.23.0", features = ["io-util", "macros", "net", "rt", "signal", "sync", "time"] } tokio-stream = "0.1.11" +toml = "0.5.10" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..563270f --- /dev/null +++ b/src/config.rs @@ -0,0 +1,99 @@ +use std::io::ErrorKind; +use std::path::{Path, PathBuf}; + +use serde::Deserialize; + +#[derive(Debug)] +pub enum ConfigError { + Io(std::io::Error), + Toml(toml::de::Error), +} + +impl From for ConfigError { + fn from(e: std::io::Error) -> Self { + Self::Io(e) + } +} + +impl From for ConfigError { + fn from(e: toml::de::Error) -> Self { + Self::Toml(e) + } +} + +impl std::fmt::Display for ConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Io(e) => write!(f, "Could not read config file: {}", e), + Self::Toml(e) => write!(f, "Could not parse config: {}", e), + } + } +} + +impl std::error::Error for ConfigError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Io(e) => Some(e), + Self::Toml(e) => Some(e), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct MqttConfig { + #[serde(default = "default_mqtt_host")] + pub host: String, + #[serde(default = "default_mqtt_port")] + pub port: u16, + #[serde(default)] + pub tls: bool, + #[serde(default)] + pub ca_file: PathBuf, + #[serde(default)] + pub username: Option, + #[serde(default)] + pub password: Option, +} + +impl Default for MqttConfig { + fn default() -> Self { + Self { + host: default_mqtt_host(), + port: default_mqtt_port(), + tls: Default::default(), + ca_file: Default::default(), + username: Default::default(), + password: Default::default(), + } + } +} + +#[derive(Debug, Default, Deserialize)] +pub struct Configuration { + pub mqtt: MqttConfig, +} + +fn default_mqtt_host() -> String { + "localhost".into() +} + +const fn default_mqtt_port() -> u16 { + 1883 +} + +pub fn load_config

(path: Option

) -> Result +where + P: AsRef, +{ + let path = match path { + Some(p) => PathBuf::from(p.as_ref()), + None => PathBuf::from("config.toml"), + }; + match std::fs::read_to_string(path) { + Ok(s) => Ok(toml::from_str(&s)?), + Err(ref e) if e.kind() == ErrorKind::NotFound => { + Ok(Default::default()) + } + Err(e) => Err(e.into()), + } +} diff --git a/src/main.rs b/src/main.rs index 634b78d..2f2b86d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,7 @@ mod browser; +mod config; mod marionette; +mod mqtt; mod session; use tokio::signal::unix::{self, SignalKind}; @@ -15,13 +17,22 @@ async fn main() { .with_writer(std::io::stderr) .init(); + let config = + config::load_config(std::env::var("MQTTMARIONETTE_CONFIG").ok()) + .unwrap(); + let mut sig_term = unix::signal(SignalKind::terminate()).unwrap(); let mut sig_int = unix::signal(SignalKind::interrupt()).unwrap(); - let session = Session::begin().await.unwrap(); + let task = tokio::spawn(async move { + let session = Session::begin(config).await.unwrap(); + session.run().await; + }); tokio::select! { _ = sig_term.recv() => info!("Received SIGTERM"), _ = sig_int.recv() => info!("Received SIGINT"), }; + + task.abort(); } diff --git a/src/mqtt.rs b/src/mqtt.rs new file mode 100644 index 0000000..a22fa3c --- /dev/null +++ b/src/mqtt.rs @@ -0,0 +1,63 @@ +use std::time::Duration; + +pub use paho_mqtt::Error; +use paho_mqtt::{ + AsyncClient, AsyncReceiver, ConnectOptions, ConnectOptionsBuilder, + CreateOptionsBuilder, Message, SslOptionsBuilder, +}; +use tokio_stream::StreamExt; +use tracing::{info, trace}; + +use crate::config::Configuration; + +pub struct MqttClient { + client: AsyncClient, + stream: AsyncReceiver>, +} + +impl MqttClient { + pub async fn new(config: &Configuration) -> Result { + let uri = format!( + "{}://{}:{}", + if config.mqtt.tls { "ssl" } else { "tcp" }, + config.mqtt.host, + config.mqtt.port + ); + info!("Connecting to MQTT server {}", uri); + let client_opts = + CreateOptionsBuilder::new().server_uri(uri).finalize(); + let mut client = AsyncClient::new(client_opts)?; + let stream = client.get_stream(10); + client.connect(Self::conn_opts(config)?).await?; + info!("Successfully connected to MQTT broker"); + + Ok(Self { client, stream }) + } + + pub async fn run(mut self) { + while let Some(msg) = self.stream.next().await { + let Some(msg) = msg else {continue}; + trace!("Received message: {:?}", msg); + } + } + + fn conn_opts(config: &Configuration) -> Result { + let mut conn_opts = ConnectOptionsBuilder::new(); + conn_opts.automatic_reconnect( + Duration::from_millis(500), + Duration::from_secs(30), + ); + if config.mqtt.tls { + let ssl_opts = SslOptionsBuilder::new() + .trust_store(&config.mqtt.ca_file)? + .finalize(); + conn_opts.ssl_options(ssl_opts); + } + if let [Some(username), Some(password)] = + [&config.mqtt.username, &config.mqtt.password] + { + conn_opts.user_name(username).password(password); + } + Ok(conn_opts.finalize()) + } +} diff --git a/src/session.rs b/src/session.rs index a081a66..2883126 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,6 +1,10 @@ -use tracing::{debug, info}; +use std::time::Duration; +use tracing::{debug, info, warn}; + +use crate::mqtt::MqttClient; use crate::browser::{Browser, BrowserError}; +use crate::config::Configuration; use crate::marionette::error::ConnectionError; use crate::marionette::Marionette; @@ -53,12 +57,13 @@ impl std::error::Error for SessionError { } pub struct Session { + config: Configuration, browser: Browser, marionette: Marionette, } impl Session { - pub async fn begin() -> Result { + pub async fn begin(config: Configuration) -> Result { debug!("Launching Firefox"); let browser = Browser::launch()?; browser.wait_ready().await?; @@ -72,8 +77,24 @@ impl Session { let ses = marionette.new_session().await?; debug!("Started Marionette session {}", ses.session_id); Ok(Self { + config, browser, marionette, }) } + + pub async fn run(&self) { + let client; + loop { + match MqttClient::new(&self.config).await { + Ok(c) => { + client = c; + break; + } + Err(e) => warn!("Failed to connect to MQTT server: {}", e), + } + tokio::time::sleep(Duration::from_secs(1)).await; + } + client.run().await; + } }