diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/hornbeam.iml b/.idea/hornbeam.iml
new file mode 100644
index 0000000..fc9fd51
--- /dev/null
+++ b/.idea/hornbeam.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..3ce3588
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..9050722
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/hornbeam/src/frame/length.rs b/hornbeam/src/frame/length.rs
new file mode 100644
index 0000000..1a763c4
--- /dev/null
+++ b/hornbeam/src/frame/length.rs
@@ -0,0 +1,68 @@
+use std::io::{Error, Read, Write};
+use crate::frame::AsBit;
+
+pub enum WireLength {
+ Small(u8),
+ Medium(u16),
+ Large(u64)
+}
+
+impl From for WireLength {
+ fn from(value: u64) -> Self {
+ if value <= 125 {
+ Self::Small(value as u8)
+ } else if value > 125 && value < 65536 {
+ Self::Medium(value as u16)
+ } else {
+ Self::Large(value)
+ }
+ }
+}
+
+impl Into for WireLength {
+ fn into(self) -> u64 {
+ match self {
+ Self::Small(v) => v as u64,
+ Self::Medium(v) => v as u64,
+ Self::Large(v) => v
+ }
+ }
+}
+
+pub trait LengthWritable {
+ fn write_length(&self, masked: bool, w: &mut W) -> Result<(), std::io::Error>;
+}
+impl LengthWritable for WireLength {
+ fn write_length(&self, masked: bool, w: &mut W) -> Result<(), Error> {
+ match self {
+ Self::Small(v) => w.write_all(&[*v | masked.as_bit() << 7]),
+ Self::Medium(v) => {
+ let bytes = v.to_be_bytes();
+ w.write_all(&[126 | masked.as_bit() << 7, bytes[0], bytes[1]])
+ },
+ Self::Large(v) => {
+ let bytes = v.to_be_bytes();
+ w.write_all(&[127 | masked.as_bit() << 7, bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
+ },
+ }
+ }
+}
+
+pub trait LengthReadable {
+ fn read_length(initial: u8, r: &mut R) -> Result where Self: Sized;
+}
+impl LengthReadable for WireLength {
+ fn read_length(initial: u8, r: &mut R) -> Result where Self: Sized {
+ if initial <= 125 {
+ return Ok(WireLength::Small(initial));
+ } else if initial == 126 {
+ let mut buf2 = [0u8; 2];
+ r.read_exact(&mut buf2)?;
+ return Ok(WireLength::Medium(u16::from_be_bytes(buf2)));
+ }
+
+ let mut buf2 = [0u8; 8];
+ r.read_exact(&mut buf2)?;
+ return Ok(WireLength::Large(u64::from_be_bytes(buf2)));
+ }
+}
\ No newline at end of file
diff --git a/hornbeam/src/frame/mask.rs b/hornbeam/src/frame/mask.rs
new file mode 100644
index 0000000..143f806
--- /dev/null
+++ b/hornbeam/src/frame/mask.rs
@@ -0,0 +1,6 @@
+// Thanks Tungstenite for this implementation
+pub fn mask(data: &mut [u8], key: [u8; 4]) {
+ for (i, byte) in data.iter_mut().enumerate() {
+ *byte ^= key[i & 3];
+ }
+}
\ No newline at end of file
diff --git a/hornbeam/src/frame/mod.rs b/hornbeam/src/frame/mod.rs
new file mode 100644
index 0000000..b736af6
--- /dev/null
+++ b/hornbeam/src/frame/mod.rs
@@ -0,0 +1,43 @@
+/// Contains traits for the wire encoding of frames
+pub mod wire;
+/// Contains the implementation of FrameWritable
+pub mod write;
+/// Contains functions useful for encoding and decoding wire lengths
+pub mod length;
+/// Contains functions for frame masking
+pub mod mask;
+
+pub struct Frame {
+ pub fin: bool, // 1 bit
+ pub rsv1: bool, // 1 bit
+ pub rsv2: bool, // 1 bit
+ pub rsv3: bool, // 1 bit
+ pub opcode: Opcode, // 4 bits
+ // -- byte boundary --
+ pub mask: bool, // 1 bit
+ pub payload_len: u64, // 7 bits, or 7 + 16 bits, or 7 + 64 bits
+ // -- byte boundary --
+ pub masking_key: Option<[u8; 4]>, // 4 bytes
+ // -- byte boundary --
+ pub payload_data: Vec
+}
+
+#[derive(Clone)]
+#[repr(u8)]
+pub enum Opcode {
+ Continuation = 0x0,
+ Text = 0x1,
+ Binary = 0x2,
+ ConnectionClose = 0x8,
+ Ping = 0x9,
+ Pong = 0xa
+}
+
+pub trait AsBit {
+ fn as_bit(&self) -> u8;
+}
+impl AsBit for bool {
+ fn as_bit(&self) -> u8 {
+ if *self { 1u8 } else { 0u8 }
+ }
+}
\ No newline at end of file
diff --git a/hornbeam/src/frame/wire.rs b/hornbeam/src/frame/wire.rs
new file mode 100644
index 0000000..14cc32b
--- /dev/null
+++ b/hornbeam/src/frame/wire.rs
@@ -0,0 +1,13 @@
+use crate::frame::Frame;
+
+pub trait FrameWritable {
+ type Error;
+
+ fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error>;
+}
+
+pub trait FrameReadable {
+ type Error;
+
+ fn read_frame(&mut self) -> Result;
+}
\ No newline at end of file
diff --git a/hornbeam/src/frame/write.rs b/hornbeam/src/frame/write.rs
new file mode 100644
index 0000000..5dd18a1
--- /dev/null
+++ b/hornbeam/src/frame/write.rs
@@ -0,0 +1,150 @@
+use std::error::Error;
+use std::fmt::{Display, Formatter};
+use std::io;
+use std::io::Write;
+use crate::frame::{AsBit, Frame};
+use crate::frame::length::{LengthWritable, WireLength};
+use crate::frame::mask::mask;
+use crate::frame::wire::FrameWritable;
+
+#[derive(Debug)]
+pub enum FrameWriteError {
+ IoError(io::Error),
+ MaskEnabledButMissingKey
+}
+impl Display for FrameWriteError {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::IoError(e) => write!(f, "io error: {}", e),
+ Self::MaskEnabledButMissingKey => write!(f, "frame masking enabled but key not present")
+ }
+ }
+}
+impl Error for FrameWriteError {}
+impl From for FrameWriteError {
+ fn from(value: io::Error) -> Self {
+ Self::IoError(value)
+ }
+}
+
+impl FrameWritable for W {
+ type Error = FrameWriteError;
+
+ fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error> {
+ // build the flags and opcode byte
+ let flags: u8 = (frame.fin.as_bit() << 7) | (frame.rsv1.as_bit() << 6) | (frame.rsv2.as_bit() << 5) | (frame.rsv3.as_bit() << 4);
+ let opcode: u8 = (frame.opcode.clone()) as u8;
+ let byte_0 = flags | opcode;
+ self.write_all(&[byte_0])?;
+
+ // write the mask bit and length
+ WireLength::from(frame.payload_len).write_length(frame.mask, self)?;
+
+ let mut data = frame.payload_data.clone();
+
+ if frame.mask {
+ if let Some(key) = frame.masking_key {
+ self.write_all(&key)?;
+
+ mask(&mut data, key);
+ } else {
+ return Err(FrameWriteError::MaskEnabledButMissingKey); // TODO: we should do this earlier, before stuff is written to the socket
+ }
+ }
+
+ self.write_all(&data)?;
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::io::Cursor;
+ use crate::frame::{Frame, Opcode};
+ use crate::frame::wire::FrameWritable;
+
+ #[test]
+ fn encoding_6455_5_7_a() {
+ let mut buf: Cursor> = Cursor::new(Vec::new());
+
+ let frame = Frame {
+ fin: true,
+ rsv1: false,
+ rsv2: false,
+ rsv3: false,
+ opcode: Opcode::Text,
+ mask: false,
+ payload_len: 5,
+ masking_key: None,
+ payload_data: "Hello".to_string().as_bytes().to_vec(),
+ };
+
+ buf.write_frame(&frame).unwrap();
+
+ assert_eq!(buf.into_inner(), vec![0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
+ }
+
+ #[test]
+ fn encoding_6455_5_7_b() {
+ let mut buf: Cursor> = Cursor::new(Vec::new());
+
+ let frame = Frame {
+ fin: true,
+ rsv1: false,
+ rsv2: false,
+ rsv3: false,
+ opcode: Opcode::Text,
+ mask: true,
+ payload_len: 5,
+ masking_key: Some([0x37, 0xfa, 0x21, 0x3d]),
+ payload_data: "Hello".to_string().as_bytes().to_vec(),
+ };
+
+ buf.write_frame(&frame).unwrap();
+
+ assert_eq!(buf.into_inner(), vec![0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58]);
+ }
+
+ #[test]
+ fn encoding_6455_5_7_e_simplified() {
+ let mut buf: Cursor> = Cursor::new(Vec::new());
+
+ let frame = Frame {
+ fin: true,
+ rsv1: false,
+ rsv2: false,
+ rsv3: false,
+ opcode: Opcode::Text,
+ mask: false,
+ payload_len: 256,
+ masking_key: None,
+ payload_data: "Hello".to_string().as_bytes().to_vec(),
+ };
+
+ buf.write_frame(&frame).unwrap();
+
+ assert_eq!(buf.into_inner(), vec![0x81, 0x7e, 0x01, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
+ }
+
+ #[test]
+ fn encoding_6455_5_7_f_simplified() {
+ let mut buf: Cursor> = Cursor::new(Vec::new());
+
+ let frame = Frame {
+ fin: true,
+ rsv1: false,
+ rsv2: false,
+ rsv3: false,
+ opcode: Opcode::Text,
+ mask: false,
+ payload_len: 65536,
+ masking_key: None,
+ payload_data: "Hello".to_string().as_bytes().to_vec(),
+ };
+
+ buf.write_frame(&frame).unwrap();
+
+ assert_eq!(buf.into_inner(), vec![0x81, 0x7f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
+ }
+}
\ No newline at end of file
diff --git a/hornbeam/src/handshake_client.rs b/hornbeam/src/handshake_client.rs
index 3883950..07e064f 100644
--- a/hornbeam/src/handshake_client.rs
+++ b/hornbeam/src/handshake_client.rs
@@ -4,7 +4,8 @@ use std::io;
use std::io::{Read, Write};
use std::string::FromUtf8Error;
use url::Url;
-use crate::handshake_common::{HeaderMap, WEBSOCKET_PROTOCOL_VERSION};
+use crate::b64::impl_b64::base64_encode;
+use crate::handshake_common::{derive_handshake_response, HeaderMap, WEBSOCKET_PROTOCOL_VERSION};
use crate::random::websocket_client_key;
/// Contains the information needed to perform the WebSocket client handshake. Create from a URL with `ClientConnectionInfo::from(url)`,
@@ -240,7 +241,11 @@ impl ClientConnectionInfo {
} else if line.to_lowercase().starts_with("sec-websocket-accept: ") {
let accept_key = line.split(' ').nth(1).unwrap();
+ if derive_handshake_response(&base64_encode(&self.websocket_key)) != accept_key {
+ return Err(ClientHandshakeRecvError::IncorrectSecWebsocketAccept);
+ }
+ has_sec_accept = true;
}
}
diff --git a/hornbeam/src/handshake_common.rs b/hornbeam/src/handshake_common.rs
index 2bb8e04..07498df 100644
--- a/hornbeam/src/handshake_common.rs
+++ b/hornbeam/src/handshake_common.rs
@@ -7,9 +7,22 @@ pub const WEBSOCKET_PROTOCOL_VERSION: i32 = 13;
/// A type alias for a key-value header map
pub type HeaderMap = HashMap;
-pub(crate) fn derive_handshake_response(input: [u8; 16]) -> String {
+pub(crate) fn derive_handshake_response(input: &str) -> String {
let mut hasher = Sha1::new();
hasher.update(input);
hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
crate::b64::impl_b64::base64_encode(&hasher.finalize())
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::handshake_common::derive_handshake_response;
+
+ #[test]
+ fn handshake_response_derivation() {
+ assert_eq!(
+ derive_handshake_response("dGhlIHNhbXBsZSBub25jZQ=="),
+ "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
+ )
+ }
}
\ No newline at end of file
diff --git a/hornbeam/src/lib.rs b/hornbeam/src/lib.rs
index 8972bf4..b7f1243 100644
--- a/hornbeam/src/lib.rs
+++ b/hornbeam/src/lib.rs
@@ -11,7 +11,7 @@
#![warn(clippy::nursery)]
//#![deny(clippy::unwrap_used)]
#![warn(clippy::expect_used)]
-#![deny(missing_docs)]
+//#![deny(missing_docs)]
#![allow(clippy::must_use_candidate)] // This gets annoying
#[allow(unused)]
@@ -81,4 +81,7 @@ compile_error!("You need to select one CSPRNG implementation");
#[path = "random_rand.rs"]
pub mod random;
-pub(crate) mod b64;
\ No newline at end of file
+pub(crate) mod b64;
+
+/// WebSocket frame encoding and decoding
+pub mod frame;
\ No newline at end of file
diff --git a/hornbeam/src/random_rand.rs b/hornbeam/src/random_rand.rs
index 35208db..526f29e 100644
--- a/hornbeam/src/random_rand.rs
+++ b/hornbeam/src/random_rand.rs
@@ -3,4 +3,9 @@ use rand::Rng;
/// Generates a random WebSocket client key
pub fn websocket_client_key() -> [u8; 16] {
rand::thread_rng().gen()
+}
+
+/// Generates a masking key
+pub fn masking_key() -> [u8; 4] {
+ rand::thread_rng().gen()
}
\ No newline at end of file