From 4674fb3a3994351f1f037a85539257a0fb639304 Mon Sep 17 00:00:00 2001 From: Mauro D Date: Thu, 4 Aug 2022 12:20:40 +0000 Subject: [PATCH] Add support for trusted redirects. --- Cargo.toml | 3 +- src/blob/download.rs | 1 + src/blob/upload.rs | 1 + src/client.rs | 70 +++++++++++++++++++++++++++++++++++++- src/event_source/stream.rs | 1 + 5 files changed, 74 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 51d3b12..0570a60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,8 @@ readme = "README.md" serde = { version = "1.0", features = ["derive"]} serde_json = "1.0" chrono = { version = "0.4", features = ["serde"]} -reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]} +#reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]} +reqwest = { git = "https://github.com/stalwartlabs/reqwest.git", default-features = false, features = ["stream", "rustls-tls"]} futures-util = "0.3" async-stream = "0.3" base64 = "0.13" diff --git a/src/blob/download.rs b/src/blob/download.rs index c58cff3..8be148e 100644 --- a/src/blob/download.rs +++ b/src/blob/download.rs @@ -39,6 +39,7 @@ impl Client { Client::handle_error( reqwest::Client::builder() .timeout(Duration::from_millis(self.timeout())) + .redirect(self.redirect_policy()) .default_headers(headers) .build()? .get(download_url) diff --git a/src/blob/upload.rs b/src/blob/upload.rs index 84ba009..2daa0b8 100644 --- a/src/blob/upload.rs +++ b/src/blob/upload.rs @@ -48,6 +48,7 @@ impl Client { &Client::handle_error( reqwest::Client::builder() .timeout(Duration::from_millis(self.timeout())) + .redirect(self.redirect_policy()) .default_headers(self.headers().clone()) .build()? .post(upload_url) diff --git a/src/client.rs b/src/client.rs index 3ce719c..3b90d4f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,9 +6,10 @@ use std::{ time::Duration, }; +use ahash::AHashSet; use reqwest::{ header::{self}, - Response, + redirect, Response, }; use serde::de::DeserializeOwned; @@ -36,6 +37,7 @@ pub struct Client { session_url: String, api_url: String, session_updated: AtomicBool, + trusted_hosts: Arc>, upload_url: Vec>, download_url: Vec>, @@ -53,6 +55,22 @@ pub struct Client { impl Client { pub async fn connect(url: &str, credentials: impl Into) -> crate::Result { + Self::connect_(url, credentials, None::>).await + } + + pub async fn connect_with_trusted( + url: &str, + credentials: impl Into, + trusted_hosts: impl IntoIterator>, + ) -> crate::Result { + Self::connect_(url, credentials, trusted_hosts.into()).await + } + + async fn connect_( + url: &str, + credentials: impl Into, + trusted_hosts: Option>>, + ) -> crate::Result { let authorization = match credentials.into() { Credentials::Basic(s) => format!("Basic {}", s), Credentials::Bearer(s) => format!("Bearer {}", s), @@ -67,10 +85,31 @@ impl Client { header::HeaderValue::from_str(&authorization).unwrap(), ); + let trusted_hosts = Arc::new( + trusted_hosts + .map(|hosts| hosts.into_iter().map(|h| h.into()).collect::>()) + .unwrap_or_default(), + ); + + let trusted_hosts_ = trusted_hosts.clone(); let session: Session = serde_json::from_slice( &Client::handle_error( reqwest::Client::builder() .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS)) + .redirect(redirect::Policy::custom(move |attempt| { + if attempt.previous().len() > 5 { + attempt.error("Too many redirects.") + } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) ) + { + attempt.follow_trusted() + } else { + let message = format!( + "Aborting redirect request to unknown host '{}'.", + attempt.url().host_str().unwrap_or("") + ); + attempt.error(message) + } + })) .default_headers(headers.clone()) .build()? .get(url) @@ -101,6 +140,7 @@ impl Client { session: parking_lot::Mutex::new(Arc::new(session)), session_url: url.to_string(), session_updated: true.into(), + trusted_hosts, #[cfg(feature = "websockets")] authorization, timeout: DEFAULT_TIMEOUT_MS, @@ -116,6 +156,14 @@ impl Client { self } + pub fn set_trusted_hosts( + &mut self, + trusted_hosts: impl IntoIterator>, + ) -> &mut Self { + self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect()); + self + } + pub fn timeout(&self) -> u64 { self.timeout } @@ -132,6 +180,24 @@ impl Client { &self.headers } + pub(crate) fn redirect_policy(&self) -> redirect::Policy { + let trusted_hosts = self.trusted_hosts.clone(); + redirect::Policy::custom(move |attempt| { + if attempt.previous().len() > 5 { + attempt.error("Too many redirects.") + } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) ) + { + attempt.follow_trusted() + } else { + let message = format!( + "Aborting redirect request to unknown host '{}'.", + attempt.url().host_str().unwrap_or("") + ); + attempt.error(message) + } + }) + } + pub async fn send( &self, request: &request::Request<'_>, @@ -142,6 +208,7 @@ impl Client { let response: response::Response = serde_json::from_slice( &Client::handle_error( reqwest::Client::builder() + .redirect(self.redirect_policy()) .timeout(Duration::from_millis(self.timeout)) .default_headers(self.headers.clone()) .build()? @@ -167,6 +234,7 @@ impl Client { &Client::handle_error( reqwest::Client::builder() .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS)) + .redirect(self.redirect_policy()) .default_headers(self.headers.clone()) .build()? .get(&self.session_url) diff --git a/src/event_source/stream.rs b/src/event_source/stream.rs index d3dddaf..5272f96 100644 --- a/src/event_source/stream.rs +++ b/src/event_source/stream.rs @@ -63,6 +63,7 @@ impl Client { let mut stream = Client::handle_error( reqwest::Client::builder() .connect_timeout(Duration::from_millis(self.timeout())) + .redirect(self.redirect_policy()) .default_headers(headers) .build()? .get(event_source_url)