Add support for trusted redirects.

main
Mauro D 2022-08-04 12:20:40 +00:00
parent 45f0aa3d81
commit 4674fb3a39
5 changed files with 74 additions and 2 deletions

View File

@ -15,7 +15,8 @@ readme = "README.md"
serde = { version = "1.0", features = ["derive"]}
serde_json = "1.0"
chrono = { version = "0.4", features = ["serde"]}
reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]}
#reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]}
reqwest = { git = "https://github.com/stalwartlabs/reqwest.git", default-features = false, features = ["stream", "rustls-tls"]}
futures-util = "0.3"
async-stream = "0.3"
base64 = "0.13"

View File

@ -39,6 +39,7 @@ impl Client {
Client::handle_error(
reqwest::Client::builder()
.timeout(Duration::from_millis(self.timeout()))
.redirect(self.redirect_policy())
.default_headers(headers)
.build()?
.get(download_url)

View File

@ -48,6 +48,7 @@ impl Client {
&Client::handle_error(
reqwest::Client::builder()
.timeout(Duration::from_millis(self.timeout()))
.redirect(self.redirect_policy())
.default_headers(self.headers().clone())
.build()?
.post(upload_url)

View File

@ -6,9 +6,10 @@ use std::{
time::Duration,
};
use ahash::AHashSet;
use reqwest::{
header::{self},
Response,
redirect, Response,
};
use serde::de::DeserializeOwned;
@ -36,6 +37,7 @@ pub struct Client {
session_url: String,
api_url: String,
session_updated: AtomicBool,
trusted_hosts: Arc<AHashSet<String>>,
upload_url: Vec<URLPart<blob::URLParameter>>,
download_url: Vec<URLPart<blob::URLParameter>>,
@ -53,6 +55,22 @@ pub struct Client {
impl Client {
pub async fn connect(url: &str, credentials: impl Into<Credentials>) -> crate::Result<Self> {
Self::connect_(url, credentials, None::<Vec<String>>).await
}
pub async fn connect_with_trusted(
url: &str,
credentials: impl Into<Credentials>,
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> crate::Result<Self> {
Self::connect_(url, credentials, trusted_hosts.into()).await
}
async fn connect_(
url: &str,
credentials: impl Into<Credentials>,
trusted_hosts: Option<impl IntoIterator<Item = impl Into<String>>>,
) -> crate::Result<Self> {
let authorization = match credentials.into() {
Credentials::Basic(s) => format!("Basic {}", s),
Credentials::Bearer(s) => format!("Bearer {}", s),
@ -67,10 +85,31 @@ impl Client {
header::HeaderValue::from_str(&authorization).unwrap(),
);
let trusted_hosts = Arc::new(
trusted_hosts
.map(|hosts| hosts.into_iter().map(|h| h.into()).collect::<AHashSet<_>>())
.unwrap_or_default(),
);
let trusted_hosts_ = trusted_hosts.clone();
let session: Session = serde_json::from_slice(
&Client::handle_error(
reqwest::Client::builder()
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
.redirect(redirect::Policy::custom(move |attempt| {
if attempt.previous().len() > 5 {
attempt.error("Too many redirects.")
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
{
attempt.follow_trusted()
} else {
let message = format!(
"Aborting redirect request to unknown host '{}'.",
attempt.url().host_str().unwrap_or("")
);
attempt.error(message)
}
}))
.default_headers(headers.clone())
.build()?
.get(url)
@ -101,6 +140,7 @@ impl Client {
session: parking_lot::Mutex::new(Arc::new(session)),
session_url: url.to_string(),
session_updated: true.into(),
trusted_hosts,
#[cfg(feature = "websockets")]
authorization,
timeout: DEFAULT_TIMEOUT_MS,
@ -116,6 +156,14 @@ impl Client {
self
}
pub fn set_trusted_hosts(
&mut self,
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> &mut Self {
self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
self
}
pub fn timeout(&self) -> u64 {
self.timeout
}
@ -132,6 +180,24 @@ impl Client {
&self.headers
}
pub(crate) fn redirect_policy(&self) -> redirect::Policy {
let trusted_hosts = self.trusted_hosts.clone();
redirect::Policy::custom(move |attempt| {
if attempt.previous().len() > 5 {
attempt.error("Too many redirects.")
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
{
attempt.follow_trusted()
} else {
let message = format!(
"Aborting redirect request to unknown host '{}'.",
attempt.url().host_str().unwrap_or("")
);
attempt.error(message)
}
})
}
pub async fn send<R>(
&self,
request: &request::Request<'_>,
@ -142,6 +208,7 @@ impl Client {
let response: response::Response<R> = serde_json::from_slice(
&Client::handle_error(
reqwest::Client::builder()
.redirect(self.redirect_policy())
.timeout(Duration::from_millis(self.timeout))
.default_headers(self.headers.clone())
.build()?
@ -167,6 +234,7 @@ impl Client {
&Client::handle_error(
reqwest::Client::builder()
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
.redirect(self.redirect_policy())
.default_headers(self.headers.clone())
.build()?
.get(&self.session_url)

View File

@ -63,6 +63,7 @@ impl Client {
let mut stream = Client::handle_error(
reqwest::Client::builder()
.connect_timeout(Duration::from_millis(self.timeout()))
.redirect(self.redirect_policy())
.default_headers(headers)
.build()?
.get(event_source_url)