marionette: Handle concurrent communication

The Marionette protocol is designed to facilitate concurrent,
asynchronous messages.  Each request message includes a message
ID, and the corresponding response includes the same message ID.  This
allows several requests to be in flight at once.  In order for this to
be useful, the client needs to maintain a record of each request it has
sent so that it knows how to handle responses, even if they arrive out
of order.

To implement this functionality in *mqttmarionette*, the
`MarionetteConnection` structure spawns a Tokio task to handle all
incoming messages from the server.  When a message arrives, its ID is
looked up in a registry that maps message IDs to Tokio "oneshot"
channels.  If a channel is found in the map, the response is sent back
to the caller through the channel.

In order to handle incoming messages in a separate task, the TCP stream
has to be split into its read and write parts.  The receiver task cannot
be spawned, though, until after the first unsolicited message is read
from the socket, since a) there is no caller to send the message back to
and b) it does not follow the same encoding scheme as the rest of the
Marionette messages.  As such, I've refactored the
`MarionetteConnection` structure to handle the initial message in the
`connect` function and dropped the `handshake` method.  A new
`start_session` method is responsible for initiating the Marionette
session.
dev/ci
Dustin 2022-12-30 09:51:14 -06:00
parent f3815e2b12
commit c8386f9dee
3 changed files with 118 additions and 56 deletions

View File

@ -1,5 +1,5 @@
use std::str::Utf8Error;
use std::num::ParseIntError;
use std::str::Utf8Error;
#[derive(Debug)]
pub enum MessageError {
@ -27,32 +27,26 @@ impl From<Utf8Error> for MessageError {
}
#[derive(Debug)]
pub enum HandshakeError {
pub enum ConnectionError {
Message(MessageError),
Io(std::io::Error),
Parse(ParseIntError),
Utf8(Utf8Error),
Json(serde_json::Error),
}
impl From<MessageError> for HandshakeError {
impl From<MessageError> for ConnectionError {
fn from(e: MessageError) -> Self {
match e {
MessageError::Io(e) => Self::Io(e),
MessageError::Parse(e) => Self::Parse(e),
MessageError::Utf8(e) => Self::Utf8(e),
}
Self::Message(e)
}
}
impl From<std::io::Error> for HandshakeError {
impl From<std::io::Error> for ConnectionError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<serde_json::Error> for HandshakeError {
impl From<serde_json::Error> for ConnectionError {
fn from(e: serde_json::Error) -> Self {
Self::Json(e)
}
}

View File

@ -1,78 +1,146 @@
pub mod error;
pub mod message;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufStream};
use tokio::io::{
AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader,
BufWriter,
};
use tokio::net::tcp::OwnedWriteHalf;
use tokio::net::{TcpStream, ToSocketAddrs};
use tracing::{debug, trace};
use tokio::sync::oneshot;
use tracing::{debug, error, trace, warn};
pub use error::{HandshakeError, MessageError};
pub use error::{ConnectionError, MessageError};
use message::{Command, Hello, NewSessionParams, NewSessionResponse};
#[derive(Debug, Deserialize, Serialize)]
struct Message(u8, u32, Option<String>, serde_json::Value);
struct Message(u8, u32, Option<String>, Option<serde_json::Value>);
type SenderMap = HashMap<u32, oneshot::Sender<Option<serde_json::Value>>>;
pub struct Marionette {
ts: Instant,
stream: BufStream<TcpStream>,
stream: BufWriter<OwnedWriteHalf>,
sender: Arc<Mutex<SenderMap>>,
}
impl Marionette {
pub async fn connect<A>(addr: A) -> Result<Self, std::io::Error>
pub async fn connect<A>(addr: A) -> Result<Self, ConnectionError>
where
A: ToSocketAddrs,
{
let conn = TcpStream::connect(addr).await?;
let stream = BufStream::new(conn);
let (read, write) = conn.into_split();
let stream = BufWriter::new(write);
let mut rstream = BufReader::new(read);
let ts = Instant::now();
Ok(Self { ts, stream })
}
pub async fn handshake(&mut self) -> Result<(), HandshakeError> {
let buf = self.next_message().await?;
let sender = Arc::new(Mutex::new(HashMap::new()));
let buf = Self::next_message(&mut rstream).await?;
let hello: Hello = serde_json::from_slice(&buf)?;
debug!("Received hello: {:?}", hello);
self.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?;
let buf = self.next_message().await?;
let msg: Message = serde_json::from_slice(&buf)?;
let res: NewSessionResponse = serde_json::from_value(msg.3)?;
Self::start_recv_loop(rstream, sender.clone());
Ok(Self { ts, stream, sender })
}
pub async fn new_session(
&mut self,
) -> Result<NewSessionResponse, std::io::Error> {
let res = self
.send_message(Command::NewSession(NewSessionParams {
strict_file_interactability: true,
}))
.await?
.unwrap();
debug!("Received message: {:?}", res);
Ok(())
Ok(res)
}
async fn next_message(&mut self) -> Result<Vec<u8>, MessageError> {
let mut buf = vec![];
self.stream.read_until(b':', &mut buf).await?;
let length: usize =
std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?;
trace!("Message length: {:?}", length);
let mut buf = vec![0; length];
self.stream.read_exact(&mut buf[..]).await?;
trace!("Received message: {:?}", buf);
Ok(buf)
}
async fn send_message(
pub async fn send_message<T>(
&mut self,
command: Command,
) -> Result<(), std::io::Error> {
) -> Result<Option<T>, std::io::Error>
where
T: DeserializeOwned,
{
let value = serde_json::to_value(command)?;
let (command, params) = (
value.get("command").unwrap().as_str().unwrap().into(),
value.get("params").unwrap().clone(),
value.get("params").cloned(),
);
let msgid = (self.ts.elapsed().as_millis() % u32::MAX as u128) as u32;
let message = Message(0, msgid, Some(command), params);
let message = serde_json::to_string(&message)?;
let message = format!("{}:{}", message.len(), message);
trace!("Sending message: {}", message);
let (tx, rx) = oneshot::channel();
{
let mut sender = self.sender.lock().unwrap();
sender.insert(msgid, tx);
}
self.stream.write_all(message.as_bytes()).await?;
self.stream.flush().await?;
Ok(())
let Some(r) = rx.await.unwrap() else {
return Ok(None)
};
Ok(serde_json::from_value(r)?)
}
fn start_recv_loop<T>(
mut stream: BufReader<T>,
sender: Arc<Mutex<SenderMap>>,
) where
T: AsyncRead + Send + Unpin + 'static,
{
tokio::spawn(async move {
loop {
let buf = match Self::next_message(&mut stream).await {
Ok(b) => b,
Err(e) => {
error!("Error receiving message: {:?}", e);
break;
}
};
let msg: Message = match serde_json::from_slice(&buf[..]) {
Ok(m) => m,
Err(e) => {
warn!("Error parsing message: {}", e);
continue;
}
};
let msgid = msg.1;
let value = msg.3;
let mut sender = sender.lock().unwrap();
if let Some(s) = sender.remove(&msgid) {
if s.send(value).is_err() {
warn!("Failed to send result to caller");
}
} else {
warn!("Got unsolicited message {} ({:?})", msgid, value);
}
}
});
}
async fn next_message<T>(
stream: &mut BufReader<T>,
) -> Result<Vec<u8>, MessageError>
where
T: AsyncRead + Unpin,
{
let mut buf = vec![];
stream.read_until(b':', &mut buf).await?;
let length: usize =
std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?;
trace!("Message length: {:?}", length);
let mut buf = vec![0; length];
stream.read_exact(&mut buf[..]).await?;
trace!("Received message: {:?}", buf);
Ok(buf)
}
}

View File

@ -1,14 +1,14 @@
use tracing::{debug, info};
use crate::browser::{Browser, BrowserError};
use crate::marionette::error::ConnectionError;
use crate::marionette::Marionette;
use crate::marionette::error::HandshakeError;
#[derive(Debug)]
pub enum SessionError {
Browser(BrowserError),
Io(std::io::Error),
Handshake(HandshakeError),
Connection(ConnectionError),
InvalidState(String),
}
@ -24,13 +24,12 @@ impl From<std::io::Error> for SessionError {
}
}
impl From<HandshakeError> for SessionError {
fn from(e: HandshakeError) -> Self {
Self::Handshake(e)
impl From<ConnectionError> for SessionError {
fn from(e: ConnectionError) -> Self {
Self::Connection(e)
}
}
pub struct Session {
browser: Browser,
marionette: Marionette,
@ -48,7 +47,8 @@ impl Session {
debug!("Connecting to Firefox Marionette on port {}", port);
let mut marionette = Marionette::connect(("127.0.0.1", port)).await?;
info!("Successfully connected to Firefox Marionette");
marionette.handshake().await?;
let ses = marionette.new_session().await?;
debug!("Started Marionette session {}", ses.session_id);
Ok(Self {
browser,
marionette,