Add support for trusted redirects.
parent
45f0aa3d81
commit
4674fb3a39
|
@ -15,7 +15,8 @@ readme = "README.md"
|
|||
serde = { version = "1.0", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
chrono = { version = "0.4", features = ["serde"]}
|
||||
reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]}
|
||||
#reqwest = { version = "0.11", default-features = false, features = ["stream", "rustls-tls"]}
|
||||
reqwest = { git = "https://github.com/stalwartlabs/reqwest.git", default-features = false, features = ["stream", "rustls-tls"]}
|
||||
futures-util = "0.3"
|
||||
async-stream = "0.3"
|
||||
base64 = "0.13"
|
||||
|
|
|
@ -39,6 +39,7 @@ impl Client {
|
|||
Client::handle_error(
|
||||
reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(self.timeout()))
|
||||
.redirect(self.redirect_policy())
|
||||
.default_headers(headers)
|
||||
.build()?
|
||||
.get(download_url)
|
||||
|
|
|
@ -48,6 +48,7 @@ impl Client {
|
|||
&Client::handle_error(
|
||||
reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(self.timeout()))
|
||||
.redirect(self.redirect_policy())
|
||||
.default_headers(self.headers().clone())
|
||||
.build()?
|
||||
.post(upload_url)
|
||||
|
|
|
@ -6,9 +6,10 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
|
||||
use ahash::AHashSet;
|
||||
use reqwest::{
|
||||
header::{self},
|
||||
Response,
|
||||
redirect, Response,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
|
@ -36,6 +37,7 @@ pub struct Client {
|
|||
session_url: String,
|
||||
api_url: String,
|
||||
session_updated: AtomicBool,
|
||||
trusted_hosts: Arc<AHashSet<String>>,
|
||||
|
||||
upload_url: Vec<URLPart<blob::URLParameter>>,
|
||||
download_url: Vec<URLPart<blob::URLParameter>>,
|
||||
|
@ -53,6 +55,22 @@ pub struct Client {
|
|||
|
||||
impl Client {
|
||||
pub async fn connect(url: &str, credentials: impl Into<Credentials>) -> crate::Result<Self> {
|
||||
Self::connect_(url, credentials, None::<Vec<String>>).await
|
||||
}
|
||||
|
||||
pub async fn connect_with_trusted(
|
||||
url: &str,
|
||||
credentials: impl Into<Credentials>,
|
||||
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
|
||||
) -> crate::Result<Self> {
|
||||
Self::connect_(url, credentials, trusted_hosts.into()).await
|
||||
}
|
||||
|
||||
async fn connect_(
|
||||
url: &str,
|
||||
credentials: impl Into<Credentials>,
|
||||
trusted_hosts: Option<impl IntoIterator<Item = impl Into<String>>>,
|
||||
) -> crate::Result<Self> {
|
||||
let authorization = match credentials.into() {
|
||||
Credentials::Basic(s) => format!("Basic {}", s),
|
||||
Credentials::Bearer(s) => format!("Bearer {}", s),
|
||||
|
@ -67,10 +85,31 @@ impl Client {
|
|||
header::HeaderValue::from_str(&authorization).unwrap(),
|
||||
);
|
||||
|
||||
let trusted_hosts = Arc::new(
|
||||
trusted_hosts
|
||||
.map(|hosts| hosts.into_iter().map(|h| h.into()).collect::<AHashSet<_>>())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
let trusted_hosts_ = trusted_hosts.clone();
|
||||
let session: Session = serde_json::from_slice(
|
||||
&Client::handle_error(
|
||||
reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
|
||||
.redirect(redirect::Policy::custom(move |attempt| {
|
||||
if attempt.previous().len() > 5 {
|
||||
attempt.error("Too many redirects.")
|
||||
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
|
||||
{
|
||||
attempt.follow_trusted()
|
||||
} else {
|
||||
let message = format!(
|
||||
"Aborting redirect request to unknown host '{}'.",
|
||||
attempt.url().host_str().unwrap_or("")
|
||||
);
|
||||
attempt.error(message)
|
||||
}
|
||||
}))
|
||||
.default_headers(headers.clone())
|
||||
.build()?
|
||||
.get(url)
|
||||
|
@ -101,6 +140,7 @@ impl Client {
|
|||
session: parking_lot::Mutex::new(Arc::new(session)),
|
||||
session_url: url.to_string(),
|
||||
session_updated: true.into(),
|
||||
trusted_hosts,
|
||||
#[cfg(feature = "websockets")]
|
||||
authorization,
|
||||
timeout: DEFAULT_TIMEOUT_MS,
|
||||
|
@ -116,6 +156,14 @@ impl Client {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn set_trusted_hosts(
|
||||
&mut self,
|
||||
trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
|
||||
) -> &mut Self {
|
||||
self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn timeout(&self) -> u64 {
|
||||
self.timeout
|
||||
}
|
||||
|
@ -132,6 +180,24 @@ impl Client {
|
|||
&self.headers
|
||||
}
|
||||
|
||||
pub(crate) fn redirect_policy(&self) -> redirect::Policy {
|
||||
let trusted_hosts = self.trusted_hosts.clone();
|
||||
redirect::Policy::custom(move |attempt| {
|
||||
if attempt.previous().len() > 5 {
|
||||
attempt.error("Too many redirects.")
|
||||
} else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
|
||||
{
|
||||
attempt.follow_trusted()
|
||||
} else {
|
||||
let message = format!(
|
||||
"Aborting redirect request to unknown host '{}'.",
|
||||
attempt.url().host_str().unwrap_or("")
|
||||
);
|
||||
attempt.error(message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn send<R>(
|
||||
&self,
|
||||
request: &request::Request<'_>,
|
||||
|
@ -142,6 +208,7 @@ impl Client {
|
|||
let response: response::Response<R> = serde_json::from_slice(
|
||||
&Client::handle_error(
|
||||
reqwest::Client::builder()
|
||||
.redirect(self.redirect_policy())
|
||||
.timeout(Duration::from_millis(self.timeout))
|
||||
.default_headers(self.headers.clone())
|
||||
.build()?
|
||||
|
@ -167,6 +234,7 @@ impl Client {
|
|||
&Client::handle_error(
|
||||
reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
|
||||
.redirect(self.redirect_policy())
|
||||
.default_headers(self.headers.clone())
|
||||
.build()?
|
||||
.get(&self.session_url)
|
||||
|
|
|
@ -63,6 +63,7 @@ impl Client {
|
|||
let mut stream = Client::handle_error(
|
||||
reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_millis(self.timeout()))
|
||||
.redirect(self.redirect_policy())
|
||||
.default_headers(headers)
|
||||
.build()?
|
||||
.get(event_source_url)
|
||||
|
|
Loading…
Reference in New Issue