frame encoding
This commit is contained in:
parent
379ae045d0
commit
4f3a570ab4
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
|
@ -0,0 +1,12 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="CPP_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/hornbeam-client/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/hornbeam/src" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/target" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="MarkdownSettingsMigration">
|
||||
<option name="stateVersion" value="1" />
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/hornbeam.iml" filepath="$PROJECT_DIR$/.idea/hornbeam.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,68 @@
|
|||
use std::io::{Error, Read, Write};
|
||||
use crate::frame::AsBit;
|
||||
|
||||
pub enum WireLength {
|
||||
Small(u8),
|
||||
Medium(u16),
|
||||
Large(u64)
|
||||
}
|
||||
|
||||
impl From<u64> for WireLength {
|
||||
fn from(value: u64) -> Self {
|
||||
if value <= 125 {
|
||||
Self::Small(value as u8)
|
||||
} else if value > 125 && value < 65536 {
|
||||
Self::Medium(value as u16)
|
||||
} else {
|
||||
Self::Large(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<u64> for WireLength {
|
||||
fn into(self) -> u64 {
|
||||
match self {
|
||||
Self::Small(v) => v as u64,
|
||||
Self::Medium(v) => v as u64,
|
||||
Self::Large(v) => v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LengthWritable {
|
||||
fn write_length<W: Write>(&self, masked: bool, w: &mut W) -> Result<(), std::io::Error>;
|
||||
}
|
||||
impl LengthWritable for WireLength {
|
||||
fn write_length<W: Write>(&self, masked: bool, w: &mut W) -> Result<(), Error> {
|
||||
match self {
|
||||
Self::Small(v) => w.write_all(&[*v | masked.as_bit() << 7]),
|
||||
Self::Medium(v) => {
|
||||
let bytes = v.to_be_bytes();
|
||||
w.write_all(&[126 | masked.as_bit() << 7, bytes[0], bytes[1]])
|
||||
},
|
||||
Self::Large(v) => {
|
||||
let bytes = v.to_be_bytes();
|
||||
w.write_all(&[127 | masked.as_bit() << 7, bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LengthReadable {
|
||||
fn read_length<R: Read>(initial: u8, r: &mut R) -> Result<Self, std::io::Error> where Self: Sized;
|
||||
}
|
||||
impl LengthReadable for WireLength {
|
||||
fn read_length<R: Read>(initial: u8, r: &mut R) -> Result<Self, Error> where Self: Sized {
|
||||
if initial <= 125 {
|
||||
return Ok(WireLength::Small(initial));
|
||||
} else if initial == 126 {
|
||||
let mut buf2 = [0u8; 2];
|
||||
r.read_exact(&mut buf2)?;
|
||||
return Ok(WireLength::Medium(u16::from_be_bytes(buf2)));
|
||||
}
|
||||
|
||||
let mut buf2 = [0u8; 8];
|
||||
r.read_exact(&mut buf2)?;
|
||||
return Ok(WireLength::Large(u64::from_be_bytes(buf2)));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// Thanks Tungstenite for this implementation
|
||||
pub fn mask(data: &mut [u8], key: [u8; 4]) {
|
||||
for (i, byte) in data.iter_mut().enumerate() {
|
||||
*byte ^= key[i & 3];
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/// Contains traits for the wire encoding of frames
|
||||
pub mod wire;
|
||||
/// Contains the implementation of FrameWritable
|
||||
pub mod write;
|
||||
/// Contains functions useful for encoding and decoding wire lengths
|
||||
pub mod length;
|
||||
/// Contains functions for frame masking
|
||||
pub mod mask;
|
||||
|
||||
pub struct Frame {
|
||||
pub fin: bool, // 1 bit
|
||||
pub rsv1: bool, // 1 bit
|
||||
pub rsv2: bool, // 1 bit
|
||||
pub rsv3: bool, // 1 bit
|
||||
pub opcode: Opcode, // 4 bits
|
||||
// -- byte boundary --
|
||||
pub mask: bool, // 1 bit
|
||||
pub payload_len: u64, // 7 bits, or 7 + 16 bits, or 7 + 64 bits
|
||||
// -- byte boundary --
|
||||
pub masking_key: Option<[u8; 4]>, // 4 bytes
|
||||
// -- byte boundary --
|
||||
pub payload_data: Vec<u8>
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[repr(u8)]
|
||||
pub enum Opcode {
|
||||
Continuation = 0x0,
|
||||
Text = 0x1,
|
||||
Binary = 0x2,
|
||||
ConnectionClose = 0x8,
|
||||
Ping = 0x9,
|
||||
Pong = 0xa
|
||||
}
|
||||
|
||||
pub trait AsBit {
|
||||
fn as_bit(&self) -> u8;
|
||||
}
|
||||
impl AsBit for bool {
|
||||
fn as_bit(&self) -> u8 {
|
||||
if *self { 1u8 } else { 0u8 }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
use crate::frame::Frame;
|
||||
|
||||
pub trait FrameWritable {
|
||||
type Error;
|
||||
|
||||
fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub trait FrameReadable {
|
||||
type Error;
|
||||
|
||||
fn read_frame(&mut self) -> Result<Frame, Self::Error>;
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
use std::error::Error;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use crate::frame::{AsBit, Frame};
|
||||
use crate::frame::length::{LengthWritable, WireLength};
|
||||
use crate::frame::mask::mask;
|
||||
use crate::frame::wire::FrameWritable;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FrameWriteError {
|
||||
IoError(io::Error),
|
||||
MaskEnabledButMissingKey
|
||||
}
|
||||
impl Display for FrameWriteError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::IoError(e) => write!(f, "io error: {}", e),
|
||||
Self::MaskEnabledButMissingKey => write!(f, "frame masking enabled but key not present")
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Error for FrameWriteError {}
|
||||
impl From<io::Error> for FrameWriteError {
|
||||
fn from(value: io::Error) -> Self {
|
||||
Self::IoError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: Write> FrameWritable for W {
|
||||
type Error = FrameWriteError;
|
||||
|
||||
fn write_frame(&mut self, frame: &Frame) -> Result<(), Self::Error> {
|
||||
// build the flags and opcode byte
|
||||
let flags: u8 = (frame.fin.as_bit() << 7) | (frame.rsv1.as_bit() << 6) | (frame.rsv2.as_bit() << 5) | (frame.rsv3.as_bit() << 4);
|
||||
let opcode: u8 = (frame.opcode.clone()) as u8;
|
||||
let byte_0 = flags | opcode;
|
||||
self.write_all(&[byte_0])?;
|
||||
|
||||
// write the mask bit and length
|
||||
WireLength::from(frame.payload_len).write_length(frame.mask, self)?;
|
||||
|
||||
let mut data = frame.payload_data.clone();
|
||||
|
||||
if frame.mask {
|
||||
if let Some(key) = frame.masking_key {
|
||||
self.write_all(&key)?;
|
||||
|
||||
mask(&mut data, key);
|
||||
} else {
|
||||
return Err(FrameWriteError::MaskEnabledButMissingKey); // TODO: we should do this earlier, before stuff is written to the socket
|
||||
}
|
||||
}
|
||||
|
||||
self.write_all(&data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::io::Cursor;
|
||||
use crate::frame::{Frame, Opcode};
|
||||
use crate::frame::wire::FrameWritable;
|
||||
|
||||
#[test]
|
||||
fn encoding_6455_5_7_a() {
|
||||
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
|
||||
|
||||
let frame = Frame {
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode::Text,
|
||||
mask: false,
|
||||
payload_len: 5,
|
||||
masking_key: None,
|
||||
payload_data: "Hello".to_string().as_bytes().to_vec(),
|
||||
};
|
||||
|
||||
buf.write_frame(&frame).unwrap();
|
||||
|
||||
assert_eq!(buf.into_inner(), vec![0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoding_6455_5_7_b() {
|
||||
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
|
||||
|
||||
let frame = Frame {
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode::Text,
|
||||
mask: true,
|
||||
payload_len: 5,
|
||||
masking_key: Some([0x37, 0xfa, 0x21, 0x3d]),
|
||||
payload_data: "Hello".to_string().as_bytes().to_vec(),
|
||||
};
|
||||
|
||||
buf.write_frame(&frame).unwrap();
|
||||
|
||||
assert_eq!(buf.into_inner(), vec![0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoding_6455_5_7_e_simplified() {
|
||||
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
|
||||
|
||||
let frame = Frame {
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode::Text,
|
||||
mask: false,
|
||||
payload_len: 256,
|
||||
masking_key: None,
|
||||
payload_data: "Hello".to_string().as_bytes().to_vec(),
|
||||
};
|
||||
|
||||
buf.write_frame(&frame).unwrap();
|
||||
|
||||
assert_eq!(buf.into_inner(), vec![0x81, 0x7e, 0x01, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encoding_6455_5_7_f_simplified() {
|
||||
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
|
||||
|
||||
let frame = Frame {
|
||||
fin: true,
|
||||
rsv1: false,
|
||||
rsv2: false,
|
||||
rsv3: false,
|
||||
opcode: Opcode::Text,
|
||||
mask: false,
|
||||
payload_len: 65536,
|
||||
masking_key: None,
|
||||
payload_data: "Hello".to_string().as_bytes().to_vec(),
|
||||
};
|
||||
|
||||
buf.write_frame(&frame).unwrap();
|
||||
|
||||
assert_eq!(buf.into_inner(), vec![0x81, 0x7f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);
|
||||
}
|
||||
}
|
|
@ -4,7 +4,8 @@ use std::io;
|
|||
use std::io::{Read, Write};
|
||||
use std::string::FromUtf8Error;
|
||||
use url::Url;
|
||||
use crate::handshake_common::{HeaderMap, WEBSOCKET_PROTOCOL_VERSION};
|
||||
use crate::b64::impl_b64::base64_encode;
|
||||
use crate::handshake_common::{derive_handshake_response, HeaderMap, WEBSOCKET_PROTOCOL_VERSION};
|
||||
use crate::random::websocket_client_key;
|
||||
|
||||
/// Contains the information needed to perform the WebSocket client handshake. Create from a URL with `ClientConnectionInfo::from(url)`,
|
||||
|
@ -240,7 +241,11 @@ impl ClientConnectionInfo {
|
|||
} else if line.to_lowercase().starts_with("sec-websocket-accept: ") {
|
||||
let accept_key = line.split(' ').nth(1).unwrap();
|
||||
|
||||
if derive_handshake_response(&base64_encode(&self.websocket_key)) != accept_key {
|
||||
return Err(ClientHandshakeRecvError::IncorrectSecWebsocketAccept);
|
||||
}
|
||||
|
||||
has_sec_accept = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,9 +7,22 @@ pub const WEBSOCKET_PROTOCOL_VERSION: i32 = 13;
|
|||
/// A type alias for a key-value header map
|
||||
pub type HeaderMap = HashMap<String, String>;
|
||||
|
||||
pub(crate) fn derive_handshake_response(input: [u8; 16]) -> String {
|
||||
pub(crate) fn derive_handshake_response(input: &str) -> String {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(input);
|
||||
hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
||||
crate::b64::impl_b64::base64_encode(&hasher.finalize())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::handshake_common::derive_handshake_response;
|
||||
|
||||
#[test]
|
||||
fn handshake_response_derivation() {
|
||||
assert_eq!(
|
||||
derive_handshake_response("dGhlIHNhbXBsZSBub25jZQ=="),
|
||||
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
)
|
||||
}
|
||||
}
|
|
@ -11,7 +11,7 @@
|
|||
#![warn(clippy::nursery)]
|
||||
//#![deny(clippy::unwrap_used)]
|
||||
#![warn(clippy::expect_used)]
|
||||
#![deny(missing_docs)]
|
||||
//#![deny(missing_docs)]
|
||||
#![allow(clippy::must_use_candidate)] // This gets annoying
|
||||
|
||||
#[allow(unused)]
|
||||
|
@ -82,3 +82,6 @@ compile_error!("You need to select one CSPRNG implementation");
|
|||
pub mod random;
|
||||
|
||||
pub(crate) mod b64;
|
||||
|
||||
/// WebSocket frame encoding and decoding
|
||||
pub mod frame;
|
|
@ -4,3 +4,8 @@ use rand::Rng;
|
|||
pub fn websocket_client_key() -> [u8; 16] {
|
||||
rand::thread_rng().gen()
|
||||
}
|
||||
|
||||
/// Generates a masking key
|
||||
pub fn masking_key() -> [u8; 4] {
|
||||
rand::thread_rng().gen()
|
||||
}
|
Loading…
Reference in New Issue