From 2265470710c244dc36c8c61473ca9053b298a829 Mon Sep 17 00:00:00 2001 From: "Dustin C. Hatch" Date: Thu, 1 May 2025 21:11:50 -0500 Subject: [PATCH] Buffer message batches to send to HTTP server Instead of "relaying" messages from the MQTT subscriber to the HTTP request via a second channel, we now collect each batch of messages and serialize them into a buffer. This makes it possible to retry the HTTP request if it fails, without losing any data. Using the `Bytes` data structure is the most effecient way to do this, as it implements `Clone` without copying, so each iteration of the retry loop uses the same data in memory. Being able to retry failed HTTP requests eliminates the need for the "preflight" request entirely. --- Cargo.lock | 1 + Cargo.toml | 1 + src/main.rs | 138 ++++++++++++++++++++++++-------------------- src/relay.rs | 66 --------------------- src/streambuffer.rs | 98 +++++++++++++++++++++++++++++++ 5 files changed, 176 insertions(+), 128 deletions(-) delete mode 100644 src/relay.rs create mode 100644 src/streambuffer.rs diff --git a/Cargo.lock b/Cargo.lock index 28f84f8..c313410 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -813,6 +813,7 @@ dependencies = [ name = "mqtt2vl" version = "0.1.0" dependencies = [ + "bytes", "chrono", "futures", "metrics", diff --git a/Cargo.toml b/Cargo.toml index ce6a18a..2f1c746 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +bytes = "1.10.1" chrono = { version = "0.4.40", default-features = false, features = ["std", "now", "serde"] } futures = "0.3.31" metrics = "0.24.2" diff --git a/src/main.rs b/src/main.rs index 8dbb6f9..64aee37 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,15 @@ mod backoff; mod config; mod mqtt; -mod relay; +mod streambuffer; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; +use bytes::buf::BufMut; +use bytes::{Bytes, BytesMut}; use chrono::{DateTime, FixedOffset, Utc}; use futures::stream::StreamExt; use metrics::{counter, gauge}; @@ -16,11 +18,13 @@ use tokio::signal::unix::SignalKind; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use tokio::sync::Notify; use tokio::task::JoinHandle; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, info, trace, warn}; use backoff::Backoff; use config::Configuration; use mqtt::MqttClient; +use streambuffer::StreamBuffer; static USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -47,6 +51,43 @@ impl TryFrom for LogRecord { } } +struct RecordBuffer { + records: Vec, +} + +impl From> for RecordBuffer { + fn from(records: Vec) -> Self { + Self { records } + } +} + +impl RecordBuffer { + fn as_bytes(&self) -> Bytes { + let mut buf = BytesMut::new(); + for msg in self.records.iter() { + match serde_json::to_string(msg) { + Ok(s) => { + buf.put(s.as_bytes()); + buf.put_u8(b'\n'); + } + Err(e) => { + error!("Failed to serialize message: {}", e); + continue; + } + }; + } + buf.freeze() + } + + fn into_bytes(self) -> Bytes { + self.as_bytes() + } + + fn len(&self) -> usize { + self.records.len() + } +} + async fn run_sender( client: reqwest::Client, url: &str, @@ -54,65 +95,44 @@ async fn run_sender( notify: &Notify, ) { let mut backoff = Backoff::default(); - let relay = relay::Relay::from(chan); - 'outer: loop { - if relay.closed().await { - break; - } - let Some(Some((stream, handle))) = (tokio::select! { - s = relay.new_stream(1) => Some(s), - _ = notify.notified() => None, - }) else { - break; + let mut stream = + StreamBuffer::from(UnboundedReceiverStream::from(chan).map(|v| { + trace!("{:?}", v); + v + })); + let timeout = Duration::from_millis(100); + 'main: loop { + let buf = if let Some(b) = stream.buffer(timeout).await { + RecordBuffer::from(b) + } else { + break 'main; }; - 'inner: loop { + let size = buf.len(); + let bytes = buf.into_bytes(); + 'retry: loop { + let body = reqwest::Body::from(bytes.clone()); let req = client .post(url) - .header(reqwest::header::ACCEPT, "application/json") - .header(reqwest::header::CONTENT_LENGTH, "0"); - debug!("Checking HTTP connection"); - tokio::select! { - _ = notify.notified() => break 'outer, - r = req.send() => { - if let Err(e) = r { - counter!("sender_http_check_errors_count").increment(1); - error!("Error in HTTP request: {}", e); - tokio::select! { - _ = notify.notified() => break 'outer, - _ = backoff.sleep() => (), - } - continue; - } - backoff.reset(); - debug!("HTTP connection successful"); - break 'inner; + .header( + reqwest::header::CONTENT_TYPE, + "application/stream+json", + ) + .body(body); + info!("Starting HTTP sender stream"); + if let Err(e) = req.send().await { + counter!("sender_http_errors_count").increment(1); + error!("HTTP request error: {}", e); + tokio::select! { + _ = backoff.sleep() => (), + _ = notify.notified() => break 'main, } + } else { + debug!("Finished HTTP POST request"); + backoff.reset(); + gauge!("message_queue_depth").decrement(size as f64); + break 'retry; } } - let stream = stream - .map(|v| { - trace!("{:?}", v); - gauge!("message_queue_depth").decrement(1); - v - }) - .map(|v| serde_json::to_string(&v).map(|v| format!("{}\n", v))); - let body = reqwest::Body::wrap_stream(stream); - let req = client - .post(url) - .header(reqwest::header::CONTENT_TYPE, "application/stream+json") - .body(body); - info!("Starting HTTP sender stream"); - if let Err(e) = req.send().await { - counter!("sender_http_stream_errors_count").increment(1); - error!("HTTP request error: {}", e); - if let Err(e) = handle.await { - error!("Error in sender: {}", e); - } - backoff.sleep().await; - } else { - debug!("Finished HTTP POST request"); - backoff.reset(); - } } info!("Stopping HTTP sender"); } @@ -227,18 +247,12 @@ fn setup_metrics( "Total number of non-UTF8 messages ignored" ); - metrics::counter!("sender_http_stream_errors_count").absolute(0); + metrics::counter!("sender_http_errors_count").absolute(0); metrics::describe_counter!( - "sender_http_stream_errors_count", + "sender_http_errors_count", "Total number of HTTP errors encountered while streaming messages", ); - metrics::counter!("sender_http_check_errors_count").absolute(0); - metrics::describe_counter!( - "sender_http_check_errors_count", - "Total number of HTTP errors encountered during preflight checks", - ); - Ok(()) } diff --git a/src/relay.rs b/src/relay.rs deleted file mode 100644 index 4540cfa..0000000 --- a/src/relay.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::sync::Arc; -use std::time::Duration; - -use tokio::sync::mpsc::{self, UnboundedReceiver}; -use tokio::sync::Mutex; -use tokio::task::JoinHandle; -use tokio_stream::wrappers::ReceiverStream; -use tracing::{debug, warn}; - -pub struct Relay { - channel: Arc>>, -} - -impl From> for Relay { - fn from(channel: UnboundedReceiver) -> Self { - let channel = Arc::new(Mutex::new(channel)); - Self { channel } - } -} - -impl Relay { - pub async fn closed(&self) -> bool { - let chan = self.channel.lock().await; - chan.is_closed() - } - - pub async fn new_stream( - &self, - buffer: usize, - ) -> Option<(ReceiverStream, JoinHandle<()>)> { - let chan = self.channel.clone(); - let mut chan = chan.lock().await; - if let Some(it) = chan.recv().await { - let (tx, rx) = mpsc::channel(buffer); - let h = tokio::spawn({ - let chan = self.channel.clone(); - async move { - let mut chan = chan.lock().await; - if tx.send(it).await.is_err() { - warn!("Downstream channel closed unexpectedly"); - return; - } - let dur = Duration::from_millis(100); - loop { - tokio::select! { - it = chan.recv() => { - let Some(it) = it else { - break; - }; - if tx.send(it).await.is_err() { - debug!("Downstream channel closed"); - break; - } - } - _ = tokio::time::sleep(dur) => break - } - } - } - }); - - Some((ReceiverStream::new(rx), h)) - } else { - None - } - } -} diff --git a/src/streambuffer.rs b/src/streambuffer.rs new file mode 100644 index 0000000..e023e4c --- /dev/null +++ b/src/streambuffer.rs @@ -0,0 +1,98 @@ +use std::time::Duration; + +use futures::stream::{Stream, StreamExt}; + +pub struct StreamBuffer +where + S: Stream + Unpin, +{ + stream: S, +} + +impl StreamBuffer +where + S: Stream + Unpin, +{ + pub async fn buffer(&mut self, timeout: Duration) -> Option> { + let mut buf: Vec<_>; + if let Some(it) = self.stream.next().await { + buf = vec![it]; + } else { + return None; + } + while let Ok(Some(it)) = + tokio::time::timeout(timeout, self.stream.next()).await + { + buf.push(it); + } + Some(buf) + } +} + +impl From for StreamBuffer +where + S: Stream + Unpin, +{ + fn from(stream: S) -> Self { + Self { stream } + } +} + +#[cfg(test)] +mod test { + use super::*; + + use tokio::sync::mpsc; + use tokio_stream::wrappers::ReceiverStream; + + #[tokio::test] + async fn test_buffer_channel_collects_until_timeout() { + let (tx, rx) = mpsc::channel(10); + let stream = ReceiverStream::new(rx); + + // Spawn a task to send values with a delay shorter than the timeout + tokio::spawn(async move { + tx.send(1).await.unwrap(); + tokio::time::sleep(Duration::from_millis(10)).await; + tx.send(2).await.unwrap(); + tokio::time::sleep(Duration::from_millis(10)).await; + tx.send(3).await.unwrap(); + }); + + let mut buf = StreamBuffer::from(stream); + // The timeout is longer than the delays between sends + let result = buf.buffer(Duration::from_millis(50)).await; + + assert_eq!(result, Some(vec![1, 2, 3])); + } + + #[tokio::test] + async fn test_buffer_channel_stops_on_timeout() { + let (tx, rx) = mpsc::channel(10); + let stream = ReceiverStream::new(rx); + + tokio::spawn(async move { + tx.send(1).await.unwrap(); + // No more sends + }); + + let mut buf = StreamBuffer::from(stream); + // The timeout is short and no second item will arrive + let result = buf.buffer(Duration::from_millis(20)).await; + + assert_eq!(result, Some(vec![1])); + } + + #[tokio::test] + async fn test_buffer_channel_none_when_stream_empty() { + let (_tx, rx) = mpsc::channel::(1); + let stream = ReceiverStream::new(rx); + + // The receiver is empty and closed + drop(_tx); + + let mut buf = StreamBuffer::from(stream); + let result = buf.buffer(Duration::from_millis(20)).await; + assert_eq!(result, None); + } +}