frame encoding

This commit is contained in:
core 2023-08-10 02:58:11 -04:00
parent 379ae045d0
commit 4f3a570ab4
Signed by: core
GPG Key ID: FDBF740DADDCEECF
14 changed files with 350 additions and 4 deletions

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

12
.idea/hornbeam.iml Normal file
View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="CPP_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/hornbeam-client/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/hornbeam/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

6
.idea/misc.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/hornbeam.iml" filepath="$PROJECT_DIR$/.idea/hornbeam.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@ -0,0 +1,68 @@
use std::io::{Error, Read, Write};
use crate::frame::AsBit;
pub enum WireLength {
Small(u8),
Medium(u16),
Large(u64)
}
impl From<u64> for WireLength {
fn from(value: u64) -> Self {
if value <= 125 {
Self::Small(value as u8)
} else if value > 125 && value < 65536 {
Self::Medium(value as u16)
} else {
Self::Large(value)
}
}
}
impl Into<u64> for WireLength {
fn into(self) -> u64 {
match self {
Self::Small(v) => v as u64,
Self::Medium(v) => v as u64,
Self::Large(v) => v
}
}
}
pub trait LengthWritable {
fn write_length<W: Write>(&self, masked: bool, w: &mut W) -> Result<(), std::io::Error>;
}
impl LengthWritable for WireLength {
fn write_length<W: Write>(&self, masked: bool, w: &mut W) -> Result<(), Error> {
match self {
Self::Small(v) => w.write_all(&[*v | masked.as_bit() << 7]),
Self::Medium(v) => {
let bytes = v.to_be_bytes();
w.write_all(&[126 | masked.as_bit() << 7, bytes[0], bytes[1]])
},
Self::Large(v) => {
let bytes = v.to_be_bytes();
w.write_all(&[127 | masked.as_bit() << 7, bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
},
}
}
}
pub trait LengthReadable {
fn read_length<R: Read>(initial: u8, r: &mut R) -> Result<Self, std::io::Error> where Self: Sized;
}
impl LengthReadable for WireLength {
fn read_length<R: Read>(initial: u8, r: &mut R) -> Result<Self, Error> where Self: Sized {
if initial <= 125 {
return Ok(WireLength::Small(initial));
} else if initial == 126 {
let mut buf2 = [0u8; 2];
r.read_exact(&mut buf2)?;
return Ok(WireLength::Medium(u16::from_be_bytes(buf2)));
}
let mut buf2 = [0u8; 8];
r.read_exact(&mut buf2)?;
return Ok(WireLength::Large(u64::from_be_bytes(buf2)));
}
}

View File

@ -0,0 +1,6 @@
// Thanks Tungstenite for this implementation
pub fn mask(data: &mut [u8], key: [u8; 4]) {
for (i, byte) in data.iter_mut().enumerate() {
*byte ^= key[i & 3];
}
}

43
hornbeam/src/frame/mod.rs Normal file
View File

@ -0,0 +1,43 @@
/// Contains traits for the wire encoding of frames
pub mod wire;
/// Contains the implementation of FrameWritable
pub mod write;
/// Contains functions useful for encoding and decoding wire lengths
pub mod length;
/// Contains functions for frame masking
pub mod mask;
pub struct Frame {
pub fin: bool, // 1 bit
pub rsv1: bool, // 1 bit
pub rsv2: bool, // 1 bit
pub rsv3: bool, // 1 bit
pub opcode: Opcode, // 4 bits
// -- byte boundary --
pub mask: bool, // 1 bit
pub payload_len: u64, // 7 bits, or 7 + 16 bits, or 7 + 64 bits
// -- byte boundary --
pub masking_key: Option<[u8; 4]>, // 4 bytes
// -- byte boundary --
pub payload_data: Vec<u8>
}
#[derive(Clone)]
#[repr(u8)]
pub enum Opcode {
Continuation = 0x0,
Text = 0x1,
Binary = 0x2,
ConnectionClose = 0x8,
Ping = 0x9,
Pong = 0xa
}
pub trait AsBit {
fn as_bit(&self) -> u8;
}
impl AsBit for bool {
fn as_bit(&self) -> u8 {
if *self { 1u8 } else { 0u8 }
}
}

View File

@ -0,0 +1,13 @@
use crate::frame::Frame;
pub trait FrameWritable {
type Error;
fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error>;
}
pub trait FrameReadable {
type Error;
fn read_frame(&mut self) -> Result<Frame, Self::Error>;
}

150
hornbeam/src/frame/write.rs Normal file
View File

@ -0,0 +1,150 @@
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::io;
use std::io::Write;
use crate::frame::{AsBit, Frame};
use crate::frame::length::{LengthWritable, WireLength};
use crate::frame::mask::mask;
use crate::frame::wire::FrameWritable;
#[derive(Debug)]
pub enum FrameWriteError {
IoError(io::Error),
MaskEnabledButMissingKey
}
impl Display for FrameWriteError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::IoError(e) => write!(f, "io error: {}", e),
Self::MaskEnabledButMissingKey => write!(f, "frame masking enabled but key not present")
}
}
}
impl Error for FrameWriteError {}
impl From<io::Error> for FrameWriteError {
fn from(value: io::Error) -> Self {
Self::IoError(value)
}
}
impl<W: Write> FrameWritable for W {
type Error = FrameWriteError;
fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error> {
// build the flags and opcode byte
let flags: u8 = (frame.fin.as_bit() << 7) | (frame.rsv1.as_bit() << 6) | (frame.rsv2.as_bit() << 5) | (frame.rsv3.as_bit() << 4);
let opcode: u8 = (frame.opcode.clone()) as u8;
let byte_0 = flags | opcode;
self.write_all(&[byte_0])?;
// write the mask bit and length
WireLength::from(frame.payload_len).write_length(frame.mask, self)?;
let mut data = frame.payload_data.clone();
if frame.mask {
if let Some(key) = frame.masking_key {
self.write_all(&key)?;
mask(&mut data, key);
} else {
return Err(FrameWriteError::MaskEnabledButMissingKey); // TODO: we should do this earlier, before stuff is written to the socket
}
}
self.write_all(&data)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use crate::frame::{Frame, Opcode};
use crate::frame::wire::FrameWritable;
#[test]
fn encoding_6455_5_7_a() {
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
let frame = Frame {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode::Text,
mask: false,
payload_len: 5,
masking_key: None,
payload_data: "Hello".to_string().as_bytes().to_vec(),
};
buf.write_frame(&frame).unwrap();
assert_eq!(buf.into_inner(), vec![0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
}
#[test]
fn encoding_6455_5_7_b() {
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
let frame = Frame {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode::Text,
mask: true,
payload_len: 5,
masking_key: Some([0x37, 0xfa, 0x21, 0x3d]),
payload_data: "Hello".to_string().as_bytes().to_vec(),
};
buf.write_frame(&frame).unwrap();
assert_eq!(buf.into_inner(), vec![0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58]);
}
#[test]
fn encoding_6455_5_7_e_simplified() {
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
let frame = Frame {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode::Text,
mask: false,
payload_len: 256,
masking_key: None,
payload_data: "Hello".to_string().as_bytes().to_vec(),
};
buf.write_frame(&frame).unwrap();
assert_eq!(buf.into_inner(), vec![0x81, 0x7e, 0x01, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
}
#[test]
fn encoding_6455_5_7_f_simplified() {
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
let frame = Frame {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode::Text,
mask: false,
payload_len: 65536,
masking_key: None,
payload_data: "Hello".to_string().as_bytes().to_vec(),
};
buf.write_frame(&frame).unwrap();
assert_eq!(buf.into_inner(), vec![0x81, 0x7f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
}
}

View File

@ -4,7 +4,8 @@ use std::io;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use url::Url; use url::Url;
use crate::handshake_common::{HeaderMap, WEBSOCKET_PROTOCOL_VERSION}; use crate::b64::impl_b64::base64_encode;
use crate::handshake_common::{derive_handshake_response, HeaderMap, WEBSOCKET_PROTOCOL_VERSION};
use crate::random::websocket_client_key; use crate::random::websocket_client_key;
/// Contains the information needed to perform the WebSocket client handshake. Create from a URL with `ClientConnectionInfo::from(url)`, /// Contains the information needed to perform the WebSocket client handshake. Create from a URL with `ClientConnectionInfo::from(url)`,
@ -240,7 +241,11 @@ impl ClientConnectionInfo {
} else if line.to_lowercase().starts_with("sec-websocket-accept: ") { } else if line.to_lowercase().starts_with("sec-websocket-accept: ") {
let accept_key = line.split(' ').nth(1).unwrap(); let accept_key = line.split(' ').nth(1).unwrap();
if derive_handshake_response(&base64_encode(&self.websocket_key)) != accept_key {
return Err(ClientHandshakeRecvError::IncorrectSecWebsocketAccept);
}
has_sec_accept = true;
} }
} }

View File

@ -7,9 +7,22 @@ pub const WEBSOCKET_PROTOCOL_VERSION: i32 = 13;
/// A type alias for a key-value header map /// A type alias for a key-value header map
pub type HeaderMap = HashMap<String, String>; pub type HeaderMap = HashMap<String, String>;
pub(crate) fn derive_handshake_response(input: [u8; 16]) -> String { pub(crate) fn derive_handshake_response(input: &str) -> String {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(input); hasher.update(input);
hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
crate::b64::impl_b64::base64_encode(&hasher.finalize()) crate::b64::impl_b64::base64_encode(&hasher.finalize())
}
#[cfg(test)]
mod tests {
use crate::handshake_common::derive_handshake_response;
#[test]
fn handshake_response_derivation() {
assert_eq!(
derive_handshake_response("dGhlIHNhbXBsZSBub25jZQ=="),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
)
}
} }

View File

@ -11,7 +11,7 @@
#![warn(clippy::nursery)] #![warn(clippy::nursery)]
//#![deny(clippy::unwrap_used)] //#![deny(clippy::unwrap_used)]
#![warn(clippy::expect_used)] #![warn(clippy::expect_used)]
#![deny(missing_docs)] //#![deny(missing_docs)]
#![allow(clippy::must_use_candidate)] // This gets annoying #![allow(clippy::must_use_candidate)] // This gets annoying
#[allow(unused)] #[allow(unused)]
@ -81,4 +81,7 @@ compile_error!("You need to select one CSPRNG implementation");
#[path = "random_rand.rs"] #[path = "random_rand.rs"]
pub mod random; pub mod random;
pub(crate) mod b64; pub(crate) mod b64;
/// WebSocket frame encoding and decoding
pub mod frame;

View File

@ -3,4 +3,9 @@ use rand::Rng;
/// Generates a random WebSocket client key /// Generates a random WebSocket client key
pub fn websocket_client_key() -> [u8; 16] { pub fn websocket_client_key() -> [u8; 16] {
rand::thread_rng().gen() rand::thread_rng().gen()
}
/// Generates a masking key
pub fn masking_key() -> [u8; 4] {
rand::thread_rng().gen()
} }