From bdeb371b1a8dc5536fd1d67d74f649916a534613 Mon Sep 17 00:00:00 2001
From: core <core@coredoes.dev>
Date: Thu, 10 Aug 2023 03:24:09 -0400
Subject: [PATCH] frame decoding

---
 hornbeam/src/frame/mod.rs  |   5 +-
 hornbeam/src/frame/read.rs | 136 +++++++++++++++++++++++++++++++++++++
 2 files changed, 140 insertions(+), 1 deletion(-)
 create mode 100644 hornbeam/src/frame/read.rs

diff --git a/hornbeam/src/frame/mod.rs b/hornbeam/src/frame/mod.rs
index b736af6..3964de0 100644
--- a/hornbeam/src/frame/mod.rs
+++ b/hornbeam/src/frame/mod.rs
@@ -2,11 +2,14 @@
 pub mod wire;
 /// Contains the implementation of FrameWritable
 pub mod write;
+/// Contains the implementation of FrameReadable
+pub mod read;
 /// Contains functions useful for encoding and decoding wire lengths
 pub mod length;
 /// Contains functions for frame masking
 pub mod mask;
 
+#[derive(PartialEq, Eq, Clone, Debug)]
 pub struct Frame {
     pub fin: bool, // 1 bit
     pub rsv1: bool, // 1 bit
@@ -22,7 +25,7 @@ pub struct Frame {
     pub payload_data: Vec<u8>
 }
 
-#[derive(Clone)]
+#[derive(Clone, PartialEq, Eq, Debug)]
 #[repr(u8)]
 pub enum Opcode {
     Continuation = 0x0,
diff --git a/hornbeam/src/frame/read.rs b/hornbeam/src/frame/read.rs
new file mode 100644
index 0000000..b7a7c68
--- /dev/null
+++ b/hornbeam/src/frame/read.rs
@@ -0,0 +1,136 @@
+use std::error::Error;
+use std::fmt::{Display, Formatter};
+use std::io;
+use std::io::Read;
+use crate::frame::{Frame, Opcode};
+use crate::frame::length::{LengthReadable, WireLength};
+use crate::frame::mask::mask;
+use crate::frame::wire::FrameReadable;
+
+#[derive(Debug)]
+pub enum FrameReadError {
+    IoError(io::Error),
+    InvalidOpcode
+}
+impl Display for FrameReadError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::IoError(e) => write!(f, "io error: {}", e),
+            Self::InvalidOpcode => write!(f, "invalid opcode")
+        }
+    }
+}
+impl Error for FrameReadError {}
+impl From<io::Error> for FrameReadError {
+    fn from(value: io::Error) -> Self {
+        Self::IoError(value)
+    }
+}
+
+impl<R: Read> FrameReadable for R {
+    type Error = FrameReadError;
+
+    fn read_frame(&mut self) -> Result<Frame, Self::Error> {
+        // read the first two bytes
+        let mut first_two = [0u8; 2];
+        self.read_exact(&mut first_two)?;
+
+        let fin = (first_two[0] & 0b1000_0000) >> 7;
+        let rsv1 = (first_two[0] & 0b0100_0000) >> 6;
+        let rsv2 = (first_two[0] & 0b0010_0000) >> 5;
+        let rsv3 = (first_two[0] & 0b0001_0000) >> 4;
+        let opcode = first_two[0] & 0b0000_1111;
+
+        let do_mask = (first_two[1] & 0b1000_0000) >> 7;
+        let length_without_mask = first_two[1] & 0b0111_1111;
+
+
+        let length: u64 = WireLength::read_length(length_without_mask, self)?.into();
+
+
+        let mut mask_key = None;
+
+        if do_mask == 1 {
+            let mut buf = [0u8; 4];
+
+            self.read_exact(&mut buf)?;
+
+            mask_key = Some(buf);
+        }
+
+        let mut payload = vec![0u8; length as usize];
+
+        self.read_exact(&mut payload)?;
+
+        if let Some(key) = mask_key {
+            mask(&mut payload, key);
+        }
+
+        let opcode = match opcode {
+            0 => Opcode::Continuation,
+            1 => Opcode::Text,
+            2 => Opcode::Binary,
+            8 => Opcode::ConnectionClose,
+            9 => Opcode::Ping,
+            10 => Opcode::Pong,
+            _ => return Err(FrameReadError::InvalidOpcode)
+        };
+
+        Ok(Frame {
+            fin: fin == 1,
+            rsv1: rsv1 == 1,
+            rsv2: rsv2 == 1,
+            rsv3: rsv3 == 1,
+            opcode,
+            mask: do_mask == 1,
+            payload_len: length,
+            masking_key: mask_key,
+            payload_data: payload,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::io::{Cursor, Write};
+    use crate::frame::{Frame, Opcode};
+    use crate::frame::wire::{FrameReadable, FrameWritable};
+
+    #[test]
+    fn decoding_6455_5_7_a() {
+        let mut buf: Cursor<Vec<u8>> = Cursor::new(vec![0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
+
+        let frame = buf.read_frame().unwrap();
+
+        assert_eq!(frame, Frame {
+            fin: true,
+            rsv1: false,
+            rsv2: false,
+            rsv3: false,
+            opcode: Opcode::Text,
+            mask: false,
+            payload_len: 5,
+            masking_key: None,
+            payload_data: "Hello".to_string().as_bytes().to_vec(),
+        });
+    }
+
+    #[test]
+    fn decoding_6455_5_7_b() {
+        let mut buf: Cursor<Vec<u8>> = Cursor::new(vec![0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58]);
+
+        assert_eq!(buf.read_frame().unwrap(), Frame {
+            fin: true,
+            rsv1: false,
+            rsv2: false,
+            rsv3: false,
+            opcode: Opcode::Text,
+            mask: true,
+            payload_len: 5,
+            masking_key: Some([0x37, 0xfa, 0x21, 0x3d]),
+            payload_data: "Hello".to_string().as_bytes().to_vec(),
+        });
+    }
+
+    // 5.7(e) and 5.7(f) are skipped because they are impractical to write
+}
\ No newline at end of file