From 3cba63c2246536c9b7997d7aea0124272af84ec8 Mon Sep 17 00:00:00 2001 From: Mauro D Date: Wed, 24 May 2023 17:39:54 +0000 Subject: [PATCH] WebSocket support for self-signed certificates --- Cargo.toml | 5 ++-- src/client_ws.rs | 59 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 50b2b97..3e99909 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,11 @@ resolver = "2" [dependencies] reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"]} -tokio-tungstenite = { version = "0.18", features = ["rustls-tls-webpki-roots"], optional = true} +tokio-tungstenite = { version = "0.19", features = ["rustls-tls-webpki-roots"], optional = true} tokio = { version = "1.16", default-features = false, features = ["io-util"], optional = true } futures-util = { version = "0.3", optional = true} async-stream = { version = "0.3", optional = true} +rustls = { version = "0.21.0", features = ["dangerous_configuration"], optional = true } serde = { version = "1.0", features = ["derive"]} serde_json = "1.0" chrono = { version = "0.4", features = ["serde"]} @@ -29,7 +30,7 @@ maybe-async = "0.2" [features] default = ["async"] async = ["futures-util", "async-stream", "reqwest/stream"] -websockets = ["tokio", "tokio-tungstenite"] +websockets = ["tokio", "tokio-tungstenite", "rustls"] blocking = ["reqwest/blocking", "maybe-async/is_sync"] debug = [] diff --git a/src/client_ws.rs b/src/client_ws.rs index 06fd3d3..5c3b8e9 100644 --- a/src/client_ws.rs +++ b/src/client_ws.rs @@ -9,15 +9,19 @@ * except according to those terms. */ -use std::pin::Pin; +use std::{pin::Pin, sync::Arc}; use ahash::AHashMap; use futures_util::{stream::SplitSink, SinkExt, Stream, StreamExt}; +use rustls::{ + client::{ServerCertVerified, ServerCertVerifier}, + Certificate, ClientConfig, ServerName, +}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_tungstenite::{ tungstenite::{client::IntoClientRequest, Message}, - MaybeTlsStream, WebSocketStream, + Connector, MaybeTlsStream, WebSocketStream, }; use crate::{ @@ -123,9 +127,9 @@ pub struct WebSocketStateChange { } #[derive(Debug, Deserialize)] -pub struct WebSocketProblem { +pub struct WebSocketError { #[serde(rename = "@type")] - pub type_: WebSocketProblemType, + pub type_: WebSocketErrorType, #[serde(rename = "requestId")] pub request_id: Option, @@ -139,8 +143,8 @@ pub struct WebSocketProblem { } #[derive(Serialize, Deserialize, Debug)] -pub enum WebSocketProblemType { - Problem, +pub enum WebSocketErrorType { + RequestError, } #[derive(Debug, Deserialize)] @@ -148,7 +152,7 @@ pub enum WebSocketProblemType { enum WebSocketMessage_ { Response(WebSocketResponse), StateChange(WebSocketStateChange), - Error(WebSocketProblem), + Error(WebSocketError), } #[derive(Debug)] @@ -162,6 +166,23 @@ pub struct WsStream { req_id: usize, } +#[doc(hidden)] +struct DummyVerifier; + +impl ServerCertVerifier for DummyVerifier { + fn verify_server_cert( + &self, + _e: &Certificate, + _i: &[Certificate], + _sn: &ServerName, + _sc: &mut dyn Iterator, + _o: &[u8], + _n: std::time::SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +} + impl Client { pub async fn connect_ws( &self, @@ -178,7 +199,23 @@ impl Client { .headers_mut() .insert("Authorization", self.authorization.parse().unwrap()); - let (stream, _) = tokio_tungstenite::connect_async(request).await?; + let (stream, _) = if self.accept_invalid_certs & capabilities.url().starts_with("wss") { + tokio_tungstenite::connect_async_tls_with_config( + request, + None, + false, + Connector::Rustls(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(DummyVerifier {})) + .with_no_client_auth(), + )) + .into(), + ) + .await? + } else { + tokio_tungstenite::connect_async(request).await? + }; let (tx, mut rx) = stream.split(); *self.ws.lock().await = WsStream { tx, req_id: 0 }.into(); @@ -221,7 +258,7 @@ impl Client { .as_mut() .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?; - // Assing request id + // Assign request id let request_id = ws.req_id.to_string(); ws.req_id += 1; @@ -294,8 +331,8 @@ impl Client { } } -impl From for ProblemDetails { - fn from(problem: WebSocketProblem) -> Self { +impl From for ProblemDetails { + fn from(problem: WebSocketError) -> Self { ProblemDetails::new( problem.p_type, problem.status,