diff --git a/src/marionette/error.rs b/src/marionette/error.rs index 17e019c..f1fccc4 100644 --- a/src/marionette/error.rs +++ b/src/marionette/error.rs @@ -8,6 +8,7 @@ pub enum MessageError { Io(std::io::Error), Parse(ParseIntError), Utf8(Utf8Error), + Disconnected, } impl From for MessageError { @@ -34,6 +35,7 @@ impl std::fmt::Display for MessageError { Self::Io(e) => write!(f, "I/O error: {}", e), Self::Parse(e) => write!(f, "Error parsing message: {}", e), Self::Utf8(e) => write!(f, "Error parsing message: {}", e), + Self::Disconnected => write!(f, "Disconnected"), } } } @@ -44,6 +46,7 @@ impl std::error::Error for MessageError { Self::Io(e) => Some(e), Self::Parse(e) => Some(e), Self::Utf8(e) => Some(e), + Self::Disconnected => None, } } } diff --git a/src/marionette/mod.rs b/src/marionette/mod.rs index 0ccbcf7..f044c83 100644 --- a/src/marionette/mod.rs +++ b/src/marionette/mod.rs @@ -150,6 +150,9 @@ impl MarionetteConnection { { let mut buf = vec![]; stream.read_until(b':', &mut buf).await?; + if buf.is_empty() { + return Err(MessageError::Disconnected); + } let length: usize = std::str::from_utf8(&buf[..buf.len() - 1])?.parse()?; trace!("Message length: {:?}", length);