use std::error::Error; use std::io; use async_trait::async_trait; use chacha20poly1305::{AeadCore, Key, KeyInit, XChaCha20Poly1305, XNonce}; use chacha20poly1305::aead::{Aead, Payload}; use ed25519_dalek::{SigningKey}; use rand::Rng; use rand::rngs::OsRng; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use x25519_dalek::x25519; use crate::ca_pool::{load_ca_pool}; use crate::error::EpfHandshakeError; use crate::pki::{EPFCertificate, EpfPkiCertificateOps, EpfPrivateKey, EpfPublicKey}; use crate::protocol::{encode_packet, EpfApplicationData, EpfClientHello, EpfClientState, EpfFinished, EpfMessage, EpfServerHello, EpfServerState, PACKET_APPLICATION_DATA, PACKET_CLIENT_HELLO, PACKET_FINISHED, PACKET_SERVER_HELLO, PROTOCOL_VERSION, recv_packet}; ///// CLIENT ///// #[derive(Clone)] pub struct EpfClientUpgraded { inner: T, state: EpfClientState, client_random: [u8; 24], server_random: [u8; 16], client_cert: Option, packet_queue: Vec, server_cert: Option, cipher: Option, private_key: EpfPrivateKey, public_key: EpfPublicKey } pub enum ClientAuthentication { Cert(Box, EpfPrivateKey), Ephemeral } #[async_trait] pub trait EpfClientUpgradable { async fn upgrade(self, auth: ClientAuthentication) -> EpfClientUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send; } #[async_trait] impl EpfClientUpgradable for T where T: AsyncWriteExt + AsyncReadExt + Send { async fn upgrade(self, auth: ClientAuthentication) -> EpfClientUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send { let private_key; let public_key: [u8; 32]; let cert; match auth { ClientAuthentication::Cert(cert_d, key) => { cert = Some(cert_d); private_key = key; public_key = key[32..].try_into().unwrap(); }, ClientAuthentication::Ephemeral => { cert = None; let private_key_l: [u8; 32] = OsRng.gen(); let private_key_real = SigningKey::from(private_key_l); public_key = *private_key_real.verifying_key().as_bytes(); private_key = private_key_real.to_keypair_bytes(); } } EpfClientUpgraded { inner: self, state: EpfClientState::NotStarted, client_random: OsRng.gen(), server_random: [0u8; 16], client_cert: cert.map(|u| *u), server_cert: None, packet_queue: vec![], cipher: None, private_key, public_key, } } } #[async_trait] pub trait EpfClientHandshaker { async fn handshake(&mut self) -> Result<(), Box>; async fn upgrade(self) -> EpfClientStream where Self: Sized; } #[async_trait] impl EpfClientHandshaker for EpfClientUpgraded { async fn handshake(&mut self) -> Result<(), Box> { match self.state { EpfClientState::NotStarted => (), _ => return Err(EpfHandshakeError::AlreadyTunnelled.into()) } // Step 0: Load Trusted Cert Store let cert_pool = load_ca_pool()?; // Step 1: Send Client Hello self.inner.write_all(&encode_packet(PACKET_CLIENT_HELLO, &EpfClientHello { protocol_version: PROTOCOL_VERSION, client_random: self.client_random, client_certificate: self.client_cert.clone(), client_public_key: self.public_key, })?).await?; self.state = EpfClientState::WaitingForServerHello; // Step 2: Wait for Server Hello loop { let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_SERVER_HELLO { self.packet_queue.push(packet); continue; } let server_hello: EpfServerHello = rmp_serde::from_slice(&packet.packet_data)?; self.server_random = server_hello.server_random; if server_hello.protocol_version != PROTOCOL_VERSION { return Err(EpfHandshakeError::UnsupportedProtocolVersion(server_hello.protocol_version as usize).into()); } self.server_cert = Some(server_hello.server_certificate); break; } // Step 3: Validate Server Certificate let cert_valid = self.server_cert.as_ref().unwrap().verify(&cert_pool); if let Err(e) = cert_valid { return Err(EpfHandshakeError::InvalidCertificate(e).into()) } if let Ok(false) = cert_valid { return Err(EpfHandshakeError::UntrustedCertificate.into()) } // Server Cert OK // Step 4: Build the cipher let shared_key = x25519(self.private_key[..32].try_into().unwrap(), self.server_cert.as_ref().unwrap().details.public_key); let cc20p1305_key = Key::from(shared_key); let cc20p1305 = XChaCha20Poly1305::new(&cc20p1305_key); self.cipher = Some(cc20p1305); let payload = Payload { msg: &[0x42], aad: &self.server_random, }; let nonce = XNonce::from_slice(&self.client_random); let encrypted_0x42 = match self.cipher.as_ref().unwrap().encrypt(nonce, payload) { Ok(d) => d, Err(_) => { return Err(EpfHandshakeError::EncryptionError.into()) } }; self.inner.write_all(&encode_packet(PACKET_FINISHED, &EpfFinished { protocol_version: PROTOCOL_VERSION, encrypted_0x42 })?).await?; self.state = EpfClientState::WaitingForFinished; loop { let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_FINISHED { self.packet_queue.push(packet); continue; } let packet_finished: EpfFinished = rmp_serde::from_slice(&packet.packet_data)?; let payload = Payload { msg: &packet_finished.encrypted_0x42, aad: &self.server_random, }; let hopefully_0x42 = match self.cipher.as_ref().unwrap().decrypt(nonce, payload) { Ok(d) => d, Err(_) => { return Err(EpfHandshakeError::EncryptionError.into()); } }; if hopefully_0x42 != vec![0x42] { return Err(EpfHandshakeError::MissingKeyProof.into()) } break; } self.state = EpfClientState::Transport; Ok(()) } async fn upgrade(self) -> EpfClientStream where Self: Sized { EpfClientStream { inner: self.clone(), aad: self.server_random, client_cert: self.client_cert, packet_queue: self.packet_queue, server_cert: self.server_cert.unwrap(), cipher: self.cipher.unwrap(), private_key: self.private_key, public_key: self.public_key, raw_stream: self.inner } } } pub struct EpfClientStream, S: AsyncReadExt + AsyncWriteExt + Unpin> { inner: T, raw_stream: S, aad: [u8; 16], client_cert: Option, packet_queue: Vec, server_cert: EPFCertificate, cipher: XChaCha20Poly1305, private_key: EpfPrivateKey, public_key: EpfPublicKey } #[async_trait] pub trait EpfStreamOps { async fn write(&mut self, data: &[u8]) -> Result<(), Box>; async fn read(&mut self) -> Result, Box>; } #[async_trait] impl + Send, S: AsyncReadExt + AsyncWriteExt + Unpin + Send> EpfStreamOps for EpfClientStream { async fn write(&mut self, data: &[u8]) -> Result<(), Box> { let nonce = XChaCha20Poly1305::generate_nonce(OsRng); let payload = Payload { msg: data, aad: &self.aad, }; let ciphertext = match self.cipher.encrypt(&nonce, payload) { Ok(c) => c, Err(_) => { return Err(io::Error::new(io::ErrorKind::Other, "Encryption error").into()) } }; let application_data = EpfApplicationData { protocol_version: PROTOCOL_VERSION, encrypted_application_data: ciphertext, nonce: nonce.try_into().unwrap(), }; let packet = encode_packet(PACKET_APPLICATION_DATA, &application_data)?; self.raw_stream.write_all(&packet).await?; Ok(()) } async fn read(&mut self) -> Result, Box> { loop { let packet = recv_packet(&mut self.raw_stream).await?; if packet.packet_id != PACKET_APPLICATION_DATA { self.packet_queue.push(packet); continue; } let app_data: EpfApplicationData = rmp_serde::from_slice(&packet.packet_data)?; let nonce = XNonce::from_slice(&app_data.nonce); let payload = Payload { msg: &app_data.encrypted_application_data, aad: &self.aad, }; let plaintext = match self.cipher.decrypt(nonce, payload) { Ok(p) => p, Err(_) => { return Err(io::Error::new(io::ErrorKind::Other, "Decryption error").into()) } }; return Ok(plaintext); } } } ///// SERVER ///// #[derive(Clone)] pub struct EpfServerUpgraded { inner: T, state: EpfServerState, client_random: [u8; 24], server_random: [u8; 16], client_cert: Option, packet_queue: Vec, cipher: Option, cert: EPFCertificate, private_key: EpfPrivateKey, public_key: EpfPublicKey } #[async_trait] pub trait EpfServerUpgradable { async fn upgrade(self, cert: EPFCertificate, private_key: EpfPrivateKey) -> EpfServerUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send; } #[async_trait] impl EpfServerUpgradable for T where T: AsyncWriteExt + AsyncReadExt + Send { async fn upgrade(self, cert: EPFCertificate, private_key: EpfPrivateKey) -> EpfServerUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send { EpfServerUpgraded { inner: self, state: EpfServerState::WaitingForClientHello, server_random: OsRng.gen(), client_random: [0u8; 24], cert, client_cert: None, packet_queue: vec![], cipher: None, private_key, public_key: SigningKey::from_keypair_bytes(&private_key).unwrap().verifying_key().to_bytes(), } } } #[async_trait] pub trait EpfServerHandshaker { async fn handshake(&mut self) -> Result<(), Box>; async fn upgrade(self) -> EpfServerStream where Self: Sized; } #[async_trait] impl EpfServerHandshaker for EpfServerUpgraded { async fn handshake(&mut self) -> Result<(), Box> { match self.state { EpfServerState::WaitingForClientHello => (), _ => return Err(EpfHandshakeError::AlreadyTunnelled.into()) } // Step 0: Load Trusted Cert Store let cert_pool = load_ca_pool()?; let client_public_key; // Step 1: Wait for Client Hello loop { let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_CLIENT_HELLO { self.packet_queue.push(packet); continue; } let client_hello: EpfClientHello = rmp_serde::from_slice(&packet.packet_data)?; self.client_random = client_hello.client_random; if client_hello.protocol_version != PROTOCOL_VERSION { return Err(EpfHandshakeError::UnsupportedProtocolVersion(client_hello.protocol_version as usize).into()); } self.client_cert = client_hello.client_certificate; client_public_key = client_hello.client_public_key; break; } // Step 2: Validate Client Certificate (if present) if let Some(client_cert) = &self.client_cert { let cert_valid = client_cert.verify(&cert_pool); if let Err(e) = cert_valid { return Err(EpfHandshakeError::InvalidCertificate(e).into()) } if let Ok(false) = cert_valid { return Err(EpfHandshakeError::UntrustedCertificate.into()) } } // Client Cert OK (if present) // Step 3: Send Server Hello self.inner.write_all(&encode_packet(PACKET_SERVER_HELLO, &EpfServerHello { protocol_version: PROTOCOL_VERSION, server_certificate: self.cert.clone(), server_random: self.server_random, })?).await?; self.state = EpfServerState::WaitingForFinished; // Step 4: Build the cipher let shared_key = x25519(self.private_key[..32].try_into().unwrap(), client_public_key); let cc20p1305_key = Key::from(shared_key); let cc20p1305 = XChaCha20Poly1305::new(&cc20p1305_key); self.cipher = Some(cc20p1305); let payload = Payload { msg: &[0x42], aad: &self.server_random, }; let nonce = XNonce::from_slice(&self.client_random); loop { let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_FINISHED { self.packet_queue.push(packet); continue; } let packet_finished: EpfFinished = rmp_serde::from_slice(&packet.packet_data)?; let payload = Payload { msg: &packet_finished.encrypted_0x42, aad: &self.server_random, }; let hopefully_0x42 = match self.cipher.as_ref().unwrap().decrypt(nonce, payload) { Ok(d) => d, Err(_) => { return Err(EpfHandshakeError::EncryptionError.into()); } }; if hopefully_0x42 != vec![0x42] { return Err(EpfHandshakeError::MissingKeyProof.into()) } break; } let encrypted_0x42 = match self.cipher.as_ref().unwrap().encrypt(nonce, payload) { Ok(d) => d, Err(_) => { return Err(EpfHandshakeError::EncryptionError.into()) } }; self.inner.write_all(&encode_packet(PACKET_FINISHED, &EpfFinished { protocol_version: PROTOCOL_VERSION, encrypted_0x42 })?).await?; self.state = EpfServerState::WaitingForFinished; self.state = EpfServerState::Transport; Ok(()) } async fn upgrade(self) -> EpfServerStream where Self: Sized { EpfServerStream { inner: self.clone(), aad: self.server_random, server_cert: self.cert, packet_queue: self.packet_queue, client_cert: self.client_cert, cipher: self.cipher.unwrap(), private_key: self.private_key, public_key: self.public_key, raw_stream: self.inner } } } pub struct EpfServerStream, S: AsyncReadExt + AsyncWriteExt + Unpin> { inner: T, raw_stream: S, aad: [u8; 16], client_cert: Option, packet_queue: Vec, server_cert: EPFCertificate, cipher: XChaCha20Poly1305, private_key: EpfPrivateKey, public_key: EpfPublicKey } #[async_trait] impl + Send, S: AsyncReadExt + AsyncWriteExt + Unpin + Send> EpfStreamOps for EpfServerStream { async fn write(&mut self, data: &[u8]) -> Result<(), Box> { let nonce = XChaCha20Poly1305::generate_nonce(OsRng); let payload = Payload { msg: data, aad: &self.aad, }; let ciphertext = match self.cipher.encrypt(&nonce, payload) { Ok(c) => c, Err(_) => { return Err(io::Error::new(io::ErrorKind::Other, "Encryption error").into()) } }; let application_data = EpfApplicationData { protocol_version: PROTOCOL_VERSION, encrypted_application_data: ciphertext, nonce: nonce.try_into().unwrap(), }; let packet = encode_packet(PACKET_APPLICATION_DATA, &application_data)?; self.raw_stream.write_all(&packet).await?; Ok(()) } async fn read(&mut self) -> Result, Box> { loop { let packet = recv_packet(&mut self.raw_stream).await?; if packet.packet_id != PACKET_APPLICATION_DATA { self.packet_queue.push(packet); continue; } let app_data: EpfApplicationData = rmp_serde::from_slice(&packet.packet_data)?; let nonce = XNonce::from_slice(&app_data.nonce); let payload = Payload { msg: &app_data.encrypted_application_data, aad: &self.aad, }; let plaintext = match self.cipher.decrypt(nonce, payload) { Ok(p) => p, Err(_) => { return Err(io::Error::new(io::ErrorKind::Other, "Decryption error").into()) } }; return Ok(plaintext); } } } #[cfg(test)] mod tests { use std::io::Cursor; #[test] pub fn stream_test() { } }