diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/hornbeam.iml b/.idea/hornbeam.iml new file mode 100644 index 0000000..fc9fd51 --- /dev/null +++ b/.idea/hornbeam.iml @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..3ce3588 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..9050722 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/hornbeam/src/frame/length.rs b/hornbeam/src/frame/length.rs new file mode 100644 index 0000000..1a763c4 --- /dev/null +++ b/hornbeam/src/frame/length.rs @@ -0,0 +1,68 @@ +use std::io::{Error, Read, Write}; +use crate::frame::AsBit; + +pub enum WireLength { + Small(u8), + Medium(u16), + Large(u64) +} + +impl From for WireLength { + fn from(value: u64) -> Self { + if value <= 125 { + Self::Small(value as u8) + } else if value > 125 && value < 65536 { + Self::Medium(value as u16) + } else { + Self::Large(value) + } + } +} + +impl Into for WireLength { + fn into(self) -> u64 { + match self { + Self::Small(v) => v as u64, + Self::Medium(v) => v as u64, + Self::Large(v) => v + } + } +} + +pub trait LengthWritable { + fn write_length(&self, masked: bool, w: &mut W) -> Result<(), std::io::Error>; +} +impl LengthWritable for WireLength { + fn write_length(&self, masked: bool, w: &mut W) -> Result<(), Error> { + match self { + Self::Small(v) => w.write_all(&[*v | masked.as_bit() << 7]), + Self::Medium(v) => { + let bytes = v.to_be_bytes(); + w.write_all(&[126 | masked.as_bit() << 7, bytes[0], bytes[1]]) + }, + Self::Large(v) => { + let bytes = v.to_be_bytes(); + w.write_all(&[127 | masked.as_bit() << 7, bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]]) + }, + } + } +} + +pub trait LengthReadable { + fn read_length(initial: u8, r: &mut R) -> Result where Self: Sized; +} +impl LengthReadable for WireLength { + fn read_length(initial: u8, r: &mut R) -> Result where Self: Sized { + if initial <= 125 { + return Ok(WireLength::Small(initial)); + } else if initial == 126 { + let mut buf2 = [0u8; 2]; + r.read_exact(&mut buf2)?; + return Ok(WireLength::Medium(u16::from_be_bytes(buf2))); + } + + let mut buf2 = [0u8; 8]; + r.read_exact(&mut buf2)?; + return Ok(WireLength::Large(u64::from_be_bytes(buf2))); + } +} \ No newline at end of file diff --git a/hornbeam/src/frame/mask.rs b/hornbeam/src/frame/mask.rs new file mode 100644 index 0000000..143f806 --- /dev/null +++ b/hornbeam/src/frame/mask.rs @@ -0,0 +1,6 @@ +// Thanks Tungstenite for this implementation +pub fn mask(data: &mut [u8], key: [u8; 4]) { + for (i, byte) in data.iter_mut().enumerate() { + *byte ^= key[i & 3]; + } +} \ No newline at end of file diff --git a/hornbeam/src/frame/mod.rs b/hornbeam/src/frame/mod.rs new file mode 100644 index 0000000..b736af6 --- /dev/null +++ b/hornbeam/src/frame/mod.rs @@ -0,0 +1,43 @@ +/// Contains traits for the wire encoding of frames +pub mod wire; +/// Contains the implementation of FrameWritable +pub mod write; +/// Contains functions useful for encoding and decoding wire lengths +pub mod length; +/// Contains functions for frame masking +pub mod mask; + +pub struct Frame { + pub fin: bool, // 1 bit + pub rsv1: bool, // 1 bit + pub rsv2: bool, // 1 bit + pub rsv3: bool, // 1 bit + pub opcode: Opcode, // 4 bits + // -- byte boundary -- + pub mask: bool, // 1 bit + pub payload_len: u64, // 7 bits, or 7 + 16 bits, or 7 + 64 bits + // -- byte boundary -- + pub masking_key: Option<[u8; 4]>, // 4 bytes + // -- byte boundary -- + pub payload_data: Vec +} + +#[derive(Clone)] +#[repr(u8)] +pub enum Opcode { + Continuation = 0x0, + Text = 0x1, + Binary = 0x2, + ConnectionClose = 0x8, + Ping = 0x9, + Pong = 0xa +} + +pub trait AsBit { + fn as_bit(&self) -> u8; +} +impl AsBit for bool { + fn as_bit(&self) -> u8 { + if *self { 1u8 } else { 0u8 } + } +} \ No newline at end of file diff --git a/hornbeam/src/frame/wire.rs b/hornbeam/src/frame/wire.rs new file mode 100644 index 0000000..14cc32b --- /dev/null +++ b/hornbeam/src/frame/wire.rs @@ -0,0 +1,13 @@ +use crate::frame::Frame; + +pub trait FrameWritable { + type Error; + + fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error>; +} + +pub trait FrameReadable { + type Error; + + fn read_frame(&mut self) -> Result; +} \ No newline at end of file diff --git a/hornbeam/src/frame/write.rs b/hornbeam/src/frame/write.rs new file mode 100644 index 0000000..5dd18a1 --- /dev/null +++ b/hornbeam/src/frame/write.rs @@ -0,0 +1,150 @@ +use std::error::Error; +use std::fmt::{Display, Formatter}; +use std::io; +use std::io::Write; +use crate::frame::{AsBit, Frame}; +use crate::frame::length::{LengthWritable, WireLength}; +use crate::frame::mask::mask; +use crate::frame::wire::FrameWritable; + +#[derive(Debug)] +pub enum FrameWriteError { + IoError(io::Error), + MaskEnabledButMissingKey +} +impl Display for FrameWriteError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::IoError(e) => write!(f, "io error: {}", e), + Self::MaskEnabledButMissingKey => write!(f, "frame masking enabled but key not present") + } + } +} +impl Error for FrameWriteError {} +impl From for FrameWriteError { + fn from(value: io::Error) -> Self { + Self::IoError(value) + } +} + +impl FrameWritable for W { + type Error = FrameWriteError; + + fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error> { + // build the flags and opcode byte + let flags: u8 = (frame.fin.as_bit() << 7) | (frame.rsv1.as_bit() << 6) | (frame.rsv2.as_bit() << 5) | (frame.rsv3.as_bit() << 4); + let opcode: u8 = (frame.opcode.clone()) as u8; + let byte_0 = flags | opcode; + self.write_all(&[byte_0])?; + + // write the mask bit and length + WireLength::from(frame.payload_len).write_length(frame.mask, self)?; + + let mut data = frame.payload_data.clone(); + + if frame.mask { + if let Some(key) = frame.masking_key { + self.write_all(&key)?; + + mask(&mut data, key); + } else { + return Err(FrameWriteError::MaskEnabledButMissingKey); // TODO: we should do this earlier, before stuff is written to the socket + } + } + + self.write_all(&data)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + use crate::frame::{Frame, Opcode}; + use crate::frame::wire::FrameWritable; + + #[test] + fn encoding_6455_5_7_a() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + + let frame = Frame { + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode::Text, + mask: false, + payload_len: 5, + masking_key: None, + payload_data: "Hello".to_string().as_bytes().to_vec(), + }; + + buf.write_frame(&frame).unwrap(); + + assert_eq!(buf.into_inner(), vec![0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]); + } + + #[test] + fn encoding_6455_5_7_b() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + + let frame = Frame { + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode::Text, + mask: true, + payload_len: 5, + masking_key: Some([0x37, 0xfa, 0x21, 0x3d]), + payload_data: "Hello".to_string().as_bytes().to_vec(), + }; + + buf.write_frame(&frame).unwrap(); + + assert_eq!(buf.into_inner(), vec![0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58]); + } + + #[test] + fn encoding_6455_5_7_e_simplified() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + + let frame = Frame { + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode::Text, + mask: false, + payload_len: 256, + masking_key: None, + payload_data: "Hello".to_string().as_bytes().to_vec(), + }; + + buf.write_frame(&frame).unwrap(); + + assert_eq!(buf.into_inner(), vec![0x81, 0x7e, 0x01, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]); + } + + #[test] + fn encoding_6455_5_7_f_simplified() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + + let frame = Frame { + fin: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: Opcode::Text, + mask: false, + payload_len: 65536, + masking_key: None, + payload_data: "Hello".to_string().as_bytes().to_vec(), + }; + + buf.write_frame(&frame).unwrap(); + + assert_eq!(buf.into_inner(), vec![0x81, 0x7f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]); + } +} \ No newline at end of file diff --git a/hornbeam/src/handshake_client.rs b/hornbeam/src/handshake_client.rs index 3883950..07e064f 100644 --- a/hornbeam/src/handshake_client.rs +++ b/hornbeam/src/handshake_client.rs @@ -4,7 +4,8 @@ use std::io; use std::io::{Read, Write}; use std::string::FromUtf8Error; use url::Url; -use crate::handshake_common::{HeaderMap, WEBSOCKET_PROTOCOL_VERSION}; +use crate::b64::impl_b64::base64_encode; +use crate::handshake_common::{derive_handshake_response, HeaderMap, WEBSOCKET_PROTOCOL_VERSION}; use crate::random::websocket_client_key; /// Contains the information needed to perform the WebSocket client handshake. Create from a URL with `ClientConnectionInfo::from(url)`, @@ -240,7 +241,11 @@ impl ClientConnectionInfo { } else if line.to_lowercase().starts_with("sec-websocket-accept: ") { let accept_key = line.split(' ').nth(1).unwrap(); + if derive_handshake_response(&base64_encode(&self.websocket_key)) != accept_key { + return Err(ClientHandshakeRecvError::IncorrectSecWebsocketAccept); + } + has_sec_accept = true; } } diff --git a/hornbeam/src/handshake_common.rs b/hornbeam/src/handshake_common.rs index 2bb8e04..07498df 100644 --- a/hornbeam/src/handshake_common.rs +++ b/hornbeam/src/handshake_common.rs @@ -7,9 +7,22 @@ pub const WEBSOCKET_PROTOCOL_VERSION: i32 = 13; /// A type alias for a key-value header map pub type HeaderMap = HashMap; -pub(crate) fn derive_handshake_response(input: [u8; 16]) -> String { +pub(crate) fn derive_handshake_response(input: &str) -> String { let mut hasher = Sha1::new(); hasher.update(input); hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); crate::b64::impl_b64::base64_encode(&hasher.finalize()) +} + +#[cfg(test)] +mod tests { + use crate::handshake_common::derive_handshake_response; + + #[test] + fn handshake_response_derivation() { + assert_eq!( + derive_handshake_response("dGhlIHNhbXBsZSBub25jZQ=="), + "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + ) + } } \ No newline at end of file diff --git a/hornbeam/src/lib.rs b/hornbeam/src/lib.rs index 8972bf4..b7f1243 100644 --- a/hornbeam/src/lib.rs +++ b/hornbeam/src/lib.rs @@ -11,7 +11,7 @@ #![warn(clippy::nursery)] //#![deny(clippy::unwrap_used)] #![warn(clippy::expect_used)] -#![deny(missing_docs)] +//#![deny(missing_docs)] #![allow(clippy::must_use_candidate)] // This gets annoying #[allow(unused)] @@ -81,4 +81,7 @@ compile_error!("You need to select one CSPRNG implementation"); #[path = "random_rand.rs"] pub mod random; -pub(crate) mod b64; \ No newline at end of file +pub(crate) mod b64; + +/// WebSocket frame encoding and decoding +pub mod frame; \ No newline at end of file diff --git a/hornbeam/src/random_rand.rs b/hornbeam/src/random_rand.rs index 35208db..526f29e 100644 --- a/hornbeam/src/random_rand.rs +++ b/hornbeam/src/random_rand.rs @@ -3,4 +3,9 @@ use rand::Rng; /// Generates a random WebSocket client key pub fn websocket_client_key() -> [u8; 16] { rand::thread_rng().gen() +} + +/// Generates a masking key +pub fn masking_key() -> [u8; 4] { + rand::thread_rng().gen() } \ No newline at end of file