e3pf/libepf/src/handshake_stream.rs

551 lines
18 KiB
Rust
Raw Normal View History

use std::error::Error;
use std::io;
use async_trait::async_trait;
use chacha20poly1305::{AeadCore, Key, KeyInit, XChaCha20Poly1305, XNonce};
use chacha20poly1305::aead::{Aead, Payload};
use ed25519_dalek::{SigningKey};
use rand::Rng;
use rand::rngs::OsRng;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use x25519_dalek::x25519;
use crate::ca_pool::{load_ca_pool};
use crate::error::EpfHandshakeError;
use crate::pki::{EPFCertificate, EpfPkiCertificateOps, EpfPrivateKey, EpfPublicKey};
use crate::protocol::{encode_packet, EpfApplicationData, EpfClientHello, EpfClientState, EpfFinished, EpfMessage, EpfServerHello, EpfServerState, PACKET_APPLICATION_DATA, PACKET_CLIENT_HELLO, PACKET_FINISHED, PACKET_SERVER_HELLO, PROTOCOL_VERSION, recv_packet};
///// CLIENT /////
/// Client-side handshake state wrapped around a raw transport `T`.
///
/// Created by [`EpfClientUpgradable::upgrade`]; run [`EpfClientHandshaker::handshake`] on it
/// and then convert it with [`EpfClientHandshaker::upgrade`].
#[derive(Clone)]
pub struct EpfClientUpgraded<T>
where
    T: AsyncWriteExt + AsyncReadExt,
{
    /// Transport the handshake is carried over.
    inner: T,
    /// Current position in the client handshake.
    state: EpfClientState,
    /// Random value sent in the ClientHello; the handshake also uses it as the Finished nonce.
    client_random: [u8; 24],
    /// Random value received in the ServerHello; used as AEAD associated data.
    server_random: [u8; 16],
    /// Certificate presented to the server, if any.
    client_cert: Option<EPFCertificate>,
    /// Packets of unexpected types that arrived while waiting for handshake messages.
    packet_queue: Vec<EpfMessage>,
    /// Server certificate taken from the ServerHello.
    server_cert: Option<EPFCertificate>,
    /// Session cipher, populated by the handshake.
    cipher: Option<XChaCha20Poly1305>,
    /// Keypair bytes: `[..32]` is passed to `x25519`, `[32..]` is the public half.
    private_key: EpfPrivateKey,
    /// Public key sent in the ClientHello.
    public_key: EpfPublicKey,
}
/// How a client identifies itself during the EPF handshake.
pub enum ClientAuthentication {
/// Present this certificate and use the given private key (keypair bytes; bytes `32..`
/// are taken as the client's public key when upgrading).
Cert(Box<EPFCertificate>, EpfPrivateKey),
/// Generate a fresh Ed25519 key from `OsRng` for this connection and send no certificate.
Ephemeral
}
/// Turns a raw async byte stream into an [`EpfClientUpgraded`] ready for the client handshake.
///
/// This file provides a blanket implementation for every `AsyncWriteExt + AsyncReadExt + Send` type.
#[async_trait]
pub trait EpfClientUpgradable {
/// Wraps `self`, using `auth` to choose between a certificate-backed and an ephemeral identity.
async fn upgrade(self, auth: ClientAuthentication) -> EpfClientUpgraded<Self> where Self: Sized + AsyncWriteExt + AsyncReadExt + Send;
}
#[async_trait]
impl<T> EpfClientUpgradable for T where T: AsyncWriteExt + AsyncReadExt + Send {
async fn upgrade(self, auth: ClientAuthentication) -> EpfClientUpgraded<Self> where Self: Sized + AsyncWriteExt + AsyncReadExt + Send {
let private_key;
let public_key: [u8; 32];
let cert;
match auth {
ClientAuthentication::Cert(cert_d, key) => {
cert = Some(cert_d);
private_key = key;
public_key = key[32..].try_into().unwrap();
},
ClientAuthentication::Ephemeral => {
cert = None;
let private_key_l: [u8; 32] = OsRng.gen();
let private_key_real = SigningKey::from(private_key_l);
public_key = *private_key_real.verifying_key().as_bytes();
private_key = private_key_real.to_keypair_bytes();
}
}
EpfClientUpgraded {
inner: self,
state: EpfClientState::NotStarted,
client_random: OsRng.gen(),
server_random: [0u8; 16],
client_cert: cert.map(|u| *u),
server_cert: None,
packet_queue: vec![],
cipher: None,
private_key,
public_key,
}
}
}
/// Drives the client side of the EPF handshake and converts the result into an encrypted stream.
///
/// `S` is the raw transport type being wrapped.
#[async_trait]
pub trait EpfClientHandshaker<S: AsyncWriteExt + AsyncReadExt + Unpin> {
/// Runs the client handshake over the wrapped transport.
async fn handshake(&mut self) -> Result<(), Box<dyn Error>>;
/// Consumes the handshaker and returns an [`EpfClientStream`] over the same transport.
/// The implementation in this file expects `handshake` to have completed first.
async fn upgrade(self) -> EpfClientStream<Self, S> where Self: Sized;
}
#[async_trait]
impl<T: AsyncWriteExt + AsyncReadExt + Send + Unpin + Clone> EpfClientHandshaker<T> for EpfClientUpgraded<T> {
    /// Performs the client side of the EPF handshake on the wrapped stream.
    ///
    /// 1. Send a ClientHello carrying our random, optional certificate and public key.
    /// 2. Wait for the ServerHello and check its protocol version.
    /// 3. Verify the server certificate against the pool returned by `load_ca_pool`.
    /// 4. Derive an XChaCha20-Poly1305 key with `x25519`, send a Finished containing `0x42`
    ///    encrypted under it (nonce = client random, AAD = server random).
    /// 5. Require the server's Finished to decrypt to `0x42` under the same key.
    ///
    /// Packets of any other type received while waiting are appended to `packet_queue`.
    /// The server certificate is recorded only after it verifies, and the cipher only after the
    /// server's key proof has been checked, so `upgrade` cannot be used after a failed handshake.
    ///
    /// # Errors
    /// - `AlreadyTunnelled` if the handshake has already been started.
    /// - `UnsupportedProtocolVersion` if the server uses a different protocol version.
    /// - `InvalidCertificate` / `UntrustedCertificate` if certificate verification errors or fails.
    /// - `EncryptionError` if encrypting or decrypting a Finished payload fails.
    /// - `MissingKeyProof` if the server's Finished does not decrypt to `0x42`.
    /// - Any error from loading the CA pool, I/O, or packet (de)serialization.
    async fn handshake(&mut self) -> Result<(), Box<dyn Error>> {
        // Plaintext both sides encrypt in their Finished message as proof of key possession.
        const KEY_PROOF: [u8; 1] = [0x42];

        if !matches!(self.state, EpfClientState::NotStarted) {
            return Err(EpfHandshakeError::AlreadyTunnelled.into());
        }
        // Step 0: Load Trusted Cert Store
        let cert_pool = load_ca_pool()?;

        // Step 1: Send Client Hello
        self.inner.write_all(&encode_packet(PACKET_CLIENT_HELLO, &EpfClientHello {
            protocol_version: PROTOCOL_VERSION,
            client_random: self.client_random,
            client_certificate: self.client_cert.clone(),
            client_public_key: self.public_key,
        })?).await?;
        self.state = EpfClientState::WaitingForServerHello;

        // Step 2: Wait for Server Hello, queueing anything else that arrives first
        let server_hello: EpfServerHello = loop {
            let packet = recv_packet(&mut self.inner).await?;
            if packet.packet_id == PACKET_SERVER_HELLO {
                break rmp_serde::from_slice(&packet.packet_data)?;
            }
            self.packet_queue.push(packet);
        };
        self.server_random = server_hello.server_random;
        if server_hello.protocol_version != PROTOCOL_VERSION {
            return Err(EpfHandshakeError::UnsupportedProtocolVersion(server_hello.protocol_version as usize).into());
        }
        let server_cert = server_hello.server_certificate;

        // Step 3: Validate Server Certificate
        match server_cert.verify(&cert_pool) {
            Ok(true) => {}
            Ok(false) => return Err(EpfHandshakeError::UntrustedCertificate.into()),
            Err(e) => return Err(EpfHandshakeError::InvalidCertificate(e).into()),
        }
        // Server Cert OK
        let server_public_key = server_cert.details.public_key;
        self.server_cert = Some(server_cert);

        // Step 4: Build the cipher
        // NOTE(review): the first 32 bytes of our Ed25519 keypair (the seed) and the server's
        // certificate key (an Ed25519 public key, as far as this file shows) go straight into
        // `x25519`. Ed25519 public keys are compressed Edwards points, not Montgomery
        // u-coordinates, and the Ed25519 secret scalar is a hash of the seed, so the two peers
        // likely derive different secrets unless keys are converted Ed25519 -> X25519 on BOTH
        // sides. Confirm with an end-to-end test.
        let shared_key = x25519(self.private_key[..32].try_into().unwrap(), server_public_key);
        let cipher = XChaCha20Poly1305::new(&Key::from(shared_key));
        let nonce = XNonce::from_slice(&self.client_random);

        let client_finished = cipher
            .encrypt(nonce, Payload { msg: &KEY_PROOF, aad: &self.server_random })
            .map_err(|_| EpfHandshakeError::EncryptionError)?;
        self.inner.write_all(&encode_packet(PACKET_FINISHED, &EpfFinished {
            protocol_version: PROTOCOL_VERSION,
            encrypted_0x42: client_finished,
        })?).await?;
        self.state = EpfClientState::WaitingForFinished;

        // Step 5: Wait for the server's Finished and check its key proof
        let server_finished: EpfFinished = loop {
            let packet = recv_packet(&mut self.inner).await?;
            if packet.packet_id == PACKET_FINISHED {
                break rmp_serde::from_slice(&packet.packet_data)?;
            }
            self.packet_queue.push(packet);
        };
        // NOTE(review): the server encrypts the same plaintext under the same key, nonce and AAD,
        // so its Finished is byte-identical to ours and echoing our own Finished back would also
        // pass this check. Fixing it needs a coordinated client+server change (e.g. a distinct
        // plaintext or nonce per direction).
        let proof = cipher
            .decrypt(nonce, Payload { msg: &server_finished.encrypted_0x42, aad: &self.server_random })
            .map_err(|_| EpfHandshakeError::EncryptionError)?;
        if proof != KEY_PROOF {
            return Err(EpfHandshakeError::MissingKeyProof.into());
        }

        self.cipher = Some(cipher);
        self.state = EpfClientState::Transport;
        Ok(())
    }

    /// Consumes the handshaker and returns an encrypted [`EpfClientStream`].
    ///
    /// The stream keeps a clone of this handshaker, the packets queued during the handshake,
    /// and uses the server random as AEAD associated data.
    ///
    /// # Panics
    /// Panics if `handshake` has not completed successfully (no server certificate or
    /// cipher has been recorded).
    async fn upgrade(self) -> EpfClientStream<Self, T> where Self: Sized {
        let handshaker = self.clone();
        EpfClientStream {
            inner: handshaker,
            aad: self.server_random,
            client_cert: self.client_cert,
            packet_queue: self.packet_queue,
            server_cert: self
                .server_cert
                .expect("EpfClientHandshaker::upgrade called before a successful handshake (no server certificate)"),
            cipher: self
                .cipher
                .expect("EpfClientHandshaker::upgrade called before a successful handshake (no session cipher)"),
            private_key: self.private_key,
            public_key: self.public_key,
            raw_stream: self.inner,
        }
    }
}
/// Encrypted client-side stream produced by [`EpfClientHandshaker::upgrade`].
pub struct EpfClientStream<T, S>
where
    T: EpfClientHandshaker<S>,
    S: AsyncReadExt + AsyncWriteExt + Unpin,
{
    /// Copy of the handshaker this stream was upgraded from.
    inner: T,
    /// Transport that encrypted application-data packets are read from and written to.
    raw_stream: S,
    /// Associated data bound into every record (the server random from the handshake).
    aad: [u8; 16],
    /// Certificate the client presented, if any.
    client_cert: Option<EPFCertificate>,
    /// Non-application-data packets received so far.
    packet_queue: Vec<EpfMessage>,
    /// Server certificate accepted during the handshake.
    server_cert: EPFCertificate,
    /// Session cipher derived during the handshake.
    cipher: XChaCha20Poly1305,
    /// Client keypair bytes carried over from the handshaker.
    private_key: EpfPrivateKey,
    /// Client public key carried over from the handshaker.
    public_key: EpfPublicKey,
}
/// Record-level I/O over an established EPF session.
#[async_trait]
pub trait EpfStreamOps {
/// Encrypts `data` and sends it as one application-data packet.
async fn write(&mut self, data: &[u8]) -> Result<(), Box<dyn Error>>;
/// Receives the next application-data packet and returns its decrypted payload.
/// In this file's implementations, packets of other types that arrive first are queued, not returned.
async fn read(&mut self) -> Result<Vec<u8>, Box<dyn Error>>;
}
#[async_trait]
impl<T: EpfClientHandshaker<S> + Send, S: AsyncReadExt + AsyncWriteExt + Unpin + Send> EpfStreamOps for EpfClientStream<T, S> {
    /// Seals `data` under the session cipher with a freshly generated random nonce and writes
    /// it to the raw stream as a single application-data packet.
    async fn write(&mut self, data: &[u8]) -> Result<(), Box<dyn Error>> {
        let nonce = XChaCha20Poly1305::generate_nonce(OsRng);
        let sealed = self
            .cipher
            .encrypt(&nonce, Payload { msg: data, aad: &self.aad })
            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Encryption error"))?;
        let record = encode_packet(
            PACKET_APPLICATION_DATA,
            &EpfApplicationData {
                protocol_version: PROTOCOL_VERSION,
                encrypted_application_data: sealed,
                nonce: nonce.try_into().unwrap(),
            },
        )?;
        self.raw_stream.write_all(&record).await?;
        Ok(())
    }

    /// Reads packets until an application-data packet arrives, queueing everything else,
    /// then opens it with the session cipher and the nonce carried in the packet.
    async fn read(&mut self) -> Result<Vec<u8>, Box<dyn Error>> {
        let record: EpfApplicationData = loop {
            let packet = recv_packet(&mut self.raw_stream).await?;
            if packet.packet_id == PACKET_APPLICATION_DATA {
                break rmp_serde::from_slice(&packet.packet_data)?;
            }
            self.packet_queue.push(packet);
        };
        let sealed = Payload {
            msg: &record.encrypted_application_data,
            aad: &self.aad,
        };
        let opened = self
            .cipher
            .decrypt(XNonce::from_slice(&record.nonce), sealed)
            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Decryption error"))?;
        Ok(opened)
    }
}
///// SERVER /////
/// Server-side handshake state wrapped around a raw transport `T`.
///
/// Created by [`EpfServerUpgradable::upgrade`]; run [`EpfServerHandshaker::handshake`] on it
/// and then convert it with [`EpfServerHandshaker::upgrade`].
#[derive(Clone)]
pub struct EpfServerUpgraded<T>
where
    T: AsyncWriteExt + AsyncReadExt,
{
    /// Transport the handshake is carried over.
    inner: T,
    /// Current position in the server handshake.
    state: EpfServerState,
    /// Random value received in the ClientHello; the handshake also uses it as the Finished nonce.
    client_random: [u8; 24],
    /// Random value sent in the ServerHello; used as AEAD associated data.
    server_random: [u8; 16],
    /// Certificate the client presented, if any.
    client_cert: Option<EPFCertificate>,
    /// Packets of unexpected types that arrived while waiting for handshake messages.
    packet_queue: Vec<EpfMessage>,
    /// Session cipher, populated by the handshake.
    cipher: Option<XChaCha20Poly1305>,
    /// Our certificate, sent in the ServerHello.
    cert: EPFCertificate,
    /// Keypair bytes: `[..32]` is passed to `x25519`.
    private_key: EpfPrivateKey,
    /// Public key derived from `private_key` when upgrading.
    public_key: EpfPublicKey,
}
/// Turns a raw async byte stream into an [`EpfServerUpgraded`] ready for the server handshake.
///
/// This file provides a blanket implementation for every `AsyncWriteExt + AsyncReadExt + Send` type.
#[async_trait]
pub trait EpfServerUpgradable {
/// Wraps `self` with the server's certificate and its private key (keypair bytes).
async fn upgrade(self, cert: EPFCertificate, private_key: EpfPrivateKey) -> EpfServerUpgraded<Self> where Self: Sized + AsyncWriteExt + AsyncReadExt + Send;
}
#[async_trait]
impl<T> EpfServerUpgradable for T where T: AsyncWriteExt + AsyncReadExt + Send {
    /// Wraps `self` in a fresh server handshaker waiting for a ClientHello.
    ///
    /// The public key is derived from `private_key` via `SigningKey::from_keypair_bytes`; the
    /// `unwrap` panics if ed25519-dalek rejects those bytes (e.g. a public half that does not
    /// match the secret half). A new 16-byte server random is drawn from `OsRng`.
    async fn upgrade(self, cert: EPFCertificate, private_key: EpfPrivateKey) -> EpfServerUpgraded<Self> where Self: Sized + AsyncWriteExt + AsyncReadExt + Send {
        let signing_key = SigningKey::from_keypair_bytes(&private_key).unwrap();
        let public_key = signing_key.verifying_key().to_bytes();

        EpfServerUpgraded {
            inner: self,
            state: EpfServerState::WaitingForClientHello,
            server_random: OsRng.gen(),
            client_random: [0u8; 24],
            cert,
            client_cert: None,
            packet_queue: Vec::new(),
            cipher: None,
            private_key,
            public_key,
        }
    }
}
/// Drives the server side of the EPF handshake and converts the result into an encrypted stream.
///
/// `S` is the raw transport type being wrapped.
#[async_trait]
pub trait EpfServerHandshaker<S: AsyncWriteExt + AsyncReadExt + Unpin> {
/// Runs the server handshake over the wrapped transport.
async fn handshake(&mut self) -> Result<(), Box<dyn Error>>;
/// Consumes the handshaker and returns an [`EpfServerStream`] over the same transport.
/// The implementation in this file expects `handshake` to have completed first.
async fn upgrade(self) -> EpfServerStream<Self, S> where Self: Sized;
}
#[async_trait]
impl<T: AsyncWriteExt + AsyncReadExt + Send + Unpin + Clone> EpfServerHandshaker<T> for EpfServerUpgraded<T> {
    /// Performs the server side of the EPF handshake on the wrapped stream.
    ///
    /// 1. Wait for a ClientHello and check its protocol version.
    /// 2. If the client sent a certificate, verify it against the pool returned by
    ///    `load_ca_pool` and require that it certifies the public key the client is using.
    /// 3. Send a ServerHello with our certificate and random.
    /// 4. Derive an XChaCha20-Poly1305 key with `x25519` and require the client's Finished to
    ///    decrypt to `0x42` (nonce = client random, AAD = server random).
    /// 5. Reply with our own Finished.
    ///
    /// Packets of any other type received while waiting are appended to `packet_queue`.
    /// The cipher is recorded only once the handshake has fully succeeded, so `upgrade`
    /// cannot be used after a failed handshake.
    ///
    /// # Errors
    /// - `AlreadyTunnelled` if the handshake has already been started.
    /// - `UnsupportedProtocolVersion` if the client uses a different protocol version.
    /// - `InvalidCertificate` / `UntrustedCertificate` if client-certificate verification errors or fails.
    /// - `MissingKeyProof` if the client certificate's key differs from the key the client
    ///   used, or if the client's Finished does not decrypt to `0x42`.
    /// - `EncryptionError` if encrypting or decrypting a Finished payload fails.
    /// - Any error from loading the CA pool, I/O, or packet (de)serialization.
    async fn handshake(&mut self) -> Result<(), Box<dyn Error>> {
        // Plaintext both sides encrypt in their Finished message as proof of key possession.
        const KEY_PROOF: [u8; 1] = [0x42];

        if !matches!(self.state, EpfServerState::WaitingForClientHello) {
            return Err(EpfHandshakeError::AlreadyTunnelled.into());
        }
        // Step 0: Load Trusted Cert Store
        let cert_pool = load_ca_pool()?;

        // Step 1: Wait for Client Hello, queueing anything else that arrives first
        let client_hello: EpfClientHello = loop {
            let packet = recv_packet(&mut self.inner).await?;
            if packet.packet_id == PACKET_CLIENT_HELLO {
                break rmp_serde::from_slice(&packet.packet_data)?;
            }
            self.packet_queue.push(packet);
        };
        self.client_random = client_hello.client_random;
        if client_hello.protocol_version != PROTOCOL_VERSION {
            return Err(EpfHandshakeError::UnsupportedProtocolVersion(client_hello.protocol_version as usize).into());
        }
        self.client_cert = client_hello.client_certificate;
        let client_public_key = client_hello.client_public_key;

        // Step 2: Validate Client Certificate (if present)
        if let Some(client_cert) = &self.client_cert {
            match client_cert.verify(&cert_pool) {
                Ok(true) => {}
                Ok(false) => return Err(EpfHandshakeError::UntrustedCertificate.into()),
                Err(e) => return Err(EpfHandshakeError::InvalidCertificate(e).into()),
            }
            // The key exchange below uses `client_public_key`, so that is the key the client
            // proves possession of. Without this binding any client could present someone
            // else's (public) certificate alongside its own key and be accepted as them.
            if client_cert.details.public_key != client_public_key {
                return Err(EpfHandshakeError::MissingKeyProof.into());
            }
        }
        // Client Cert OK (if present)

        // Step 3: Send Server Hello
        self.inner.write_all(&encode_packet(PACKET_SERVER_HELLO, &EpfServerHello {
            protocol_version: PROTOCOL_VERSION,
            server_certificate: self.cert.clone(),
            server_random: self.server_random,
        })?).await?;
        self.state = EpfServerState::WaitingForFinished;

        // Step 4: Build the cipher and check the client's key proof
        // NOTE(review): the first 32 bytes of our Ed25519 keypair (the seed) and the client's
        // Ed25519 public key go straight into `x25519`. Ed25519 public keys are compressed
        // Edwards points, not Montgomery u-coordinates, so the peers likely derive different
        // secrets unless keys are converted Ed25519 -> X25519 on BOTH sides. Confirm with an
        // end-to-end test.
        let shared_key = x25519(self.private_key[..32].try_into().unwrap(), client_public_key);
        let cipher = XChaCha20Poly1305::new(&Key::from(shared_key));
        let nonce = XNonce::from_slice(&self.client_random);

        let client_finished: EpfFinished = loop {
            let packet = recv_packet(&mut self.inner).await?;
            if packet.packet_id == PACKET_FINISHED {
                break rmp_serde::from_slice(&packet.packet_data)?;
            }
            self.packet_queue.push(packet);
        };
        let proof = cipher
            .decrypt(nonce, Payload { msg: &client_finished.encrypted_0x42, aad: &self.server_random })
            .map_err(|_| EpfHandshakeError::EncryptionError)?;
        if proof != KEY_PROOF {
            return Err(EpfHandshakeError::MissingKeyProof.into());
        }

        // Step 5: Reply with our own Finished
        // NOTE(review): same key, nonce, AAD and plaintext as the client's Finished, so this
        // reply is byte-identical to what the client sent (a reflected message would also be
        // accepted by the client). Fixing it needs a coordinated client+server change.
        let server_finished = cipher
            .encrypt(nonce, Payload { msg: &KEY_PROOF, aad: &self.server_random })
            .map_err(|_| EpfHandshakeError::EncryptionError)?;
        self.inner.write_all(&encode_packet(PACKET_FINISHED, &EpfFinished {
            protocol_version: PROTOCOL_VERSION,
            encrypted_0x42: server_finished,
        })?).await?;

        self.cipher = Some(cipher);
        self.state = EpfServerState::Transport;
        Ok(())
    }

    /// Consumes the handshaker and returns an encrypted [`EpfServerStream`].
    ///
    /// The stream keeps a clone of this handshaker, the packets queued during the handshake,
    /// and uses the server random as AEAD associated data.
    ///
    /// # Panics
    /// Panics if `handshake` has not completed successfully (no cipher has been recorded).
    async fn upgrade(self) -> EpfServerStream<Self, T> where Self: Sized {
        let handshaker = self.clone();
        EpfServerStream {
            inner: handshaker,
            aad: self.server_random,
            server_cert: self.cert,
            packet_queue: self.packet_queue,
            client_cert: self.client_cert,
            cipher: self
                .cipher
                .expect("EpfServerHandshaker::upgrade called before a successful handshake (no session cipher)"),
            private_key: self.private_key,
            public_key: self.public_key,
            raw_stream: self.inner,
        }
    }
}
/// Encrypted server-side stream produced by [`EpfServerHandshaker::upgrade`].
pub struct EpfServerStream<T, S>
where
    T: EpfServerHandshaker<S>,
    S: AsyncReadExt + AsyncWriteExt + Unpin,
{
    /// Copy of the handshaker this stream was upgraded from.
    inner: T,
    /// Transport that encrypted application-data packets are read from and written to.
    raw_stream: S,
    /// Associated data bound into every record (the server random from the handshake).
    aad: [u8; 16],
    /// Certificate the client presented, if any.
    client_cert: Option<EPFCertificate>,
    /// Non-application-data packets received so far.
    packet_queue: Vec<EpfMessage>,
    /// Our own certificate.
    server_cert: EPFCertificate,
    /// Session cipher derived during the handshake.
    cipher: XChaCha20Poly1305,
    /// Server keypair bytes carried over from the handshaker.
    private_key: EpfPrivateKey,
    /// Server public key carried over from the handshaker.
    public_key: EpfPublicKey,
}
#[async_trait]
impl<T: EpfServerHandshaker<S> + Send, S: AsyncReadExt + AsyncWriteExt + Unpin + Send> EpfStreamOps for EpfServerStream<T, S> {
    /// Seals `data` under the session cipher with a freshly generated random nonce and writes
    /// it to the raw stream as a single application-data packet.
    async fn write(&mut self, data: &[u8]) -> Result<(), Box<dyn Error>> {
        let nonce = XChaCha20Poly1305::generate_nonce(OsRng);
        let sealed = self
            .cipher
            .encrypt(&nonce, Payload { msg: data, aad: &self.aad })
            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Encryption error"))?;
        let record = encode_packet(
            PACKET_APPLICATION_DATA,
            &EpfApplicationData {
                protocol_version: PROTOCOL_VERSION,
                encrypted_application_data: sealed,
                nonce: nonce.try_into().unwrap(),
            },
        )?;
        self.raw_stream.write_all(&record).await?;
        Ok(())
    }

    /// Reads packets until an application-data packet arrives, queueing everything else,
    /// then opens it with the session cipher and the nonce carried in the packet.
    async fn read(&mut self) -> Result<Vec<u8>, Box<dyn Error>> {
        let record: EpfApplicationData = loop {
            let packet = recv_packet(&mut self.raw_stream).await?;
            if packet.packet_id == PACKET_APPLICATION_DATA {
                break rmp_serde::from_slice(&packet.packet_data)?;
            }
            self.packet_queue.push(packet);
        };
        let sealed = Payload {
            msg: &record.encrypted_application_data,
            aad: &self.aad,
        };
        let opened = self
            .cipher
            .decrypt(XNonce::from_slice(&record.nonce), sealed)
            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Decryption error"))?;
        Ok(opened)
    }
}
#[cfg(test)]
mod tests {
    // Placeholder only: an end-to-end test needs a CA pool (the handshakes call `load_ca_pool`),
    // issued certificates and an in-memory transport, none of which are set up here yet.
    // TODO: run the client and server handshakes against each other over `tokio::io::duplex`
    // and round-trip a payload through `EpfStreamOps::write` / `EpfStreamOps::read`.
    #[test]
    pub fn stream_test() {}
}