use crate::ca_pool::{EpfCaPool}; use crate::danger_trace; use crate::error::EpfHandshakeError; use crate::pki::{ EPFCertificate, EpfPkiCertificateOps, EpfPrivateKey, EpfPublicKey, }; use crate::protocol::{ encode_packet, recv_packet, EpfApplicationData, EpfClientHello, EpfClientState, EpfFinished, EpfMessage, EpfServerHello, EpfServerState, PACKET_APPLICATION_DATA, PACKET_CLIENT_HELLO, PACKET_FINISHED, PACKET_SERVER_HELLO, PROTOCOL_VERSION, }; use async_trait::async_trait; use chacha20poly1305::aead::{Aead, Payload}; use chacha20poly1305::{AeadCore, Key, KeyInit, XChaCha20Poly1305, XNonce}; use ed25519_dalek::{SigningKey}; use log::{trace}; use rand::rngs::OsRng; use rand::Rng; use std::error::Error; use std::io; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use x25519_dalek::{PublicKey, StaticSecret}; ///// CLIENT ///// pub struct EpfClientUpgraded { inner: T, state: EpfClientState, client_random: [u8; 24], server_random: [u8; 16], client_cert: Option, packet_queue: Vec, server_cert: Option, cipher: Option, private_key: EpfPrivateKey, public_key: PublicKey, } #[derive(Debug)] pub enum ClientAuthentication { Cert(Box, Box), Ephemeral, } #[async_trait] pub trait EpfClientUpgradable { async fn upgrade(self, auth: ClientAuthentication) -> EpfClientUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send; } #[async_trait] impl EpfClientUpgradable for T where T: AsyncWriteExt + AsyncReadExt + Send, { async fn upgrade(self, auth: ClientAuthentication) -> EpfClientUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send, { danger_trace!(target: "EpfClientUpgradable", "upgrade(auth: {:?})", auth); let private_key; let public_key; let cert; match auth { ClientAuthentication::Cert(cert_d, key) => { trace!("----!!!!! CERT AUTHENTICATION !!!!!----"); cert = Some(cert_d); private_key = key; public_key = PublicKey::from(&StaticSecret::from(private_key.to_bytes())); } ClientAuthentication::Ephemeral => { cert = None; let private_key_l: [u8; 32] = OsRng.gen(); let private_key_real = SigningKey::from(private_key_l); public_key = PublicKey::from(&StaticSecret::from(private_key_real.to_bytes())); private_key = Box::new(private_key_real); } } EpfClientUpgraded { inner: self, state: EpfClientState::NotStarted, client_random: OsRng.gen(), server_random: [0u8; 16], client_cert: cert.map(|u| *u), server_cert: None, packet_queue: vec![], cipher: None, private_key: *private_key, public_key, } } } #[async_trait] pub trait EpfClientHandshaker { async fn handshake(&mut self, cert_pool: EpfCaPool) -> Result<(), Box>; async fn upgrade(self) -> EpfClientStream where Self: Sized; } #[async_trait] impl EpfClientHandshaker for EpfClientUpgraded { async fn handshake(&mut self, cert_pool: EpfCaPool) -> Result<(), Box> { match self.state { EpfClientState::NotStarted => (), _ => return Err(EpfHandshakeError::AlreadyTunnelled.into()), } // Step 1: Send Client Hello self.inner .write_all(&encode_packet( PACKET_CLIENT_HELLO, &EpfClientHello { protocol_version: PROTOCOL_VERSION, client_random: self.client_random, client_certificate: self.client_cert.clone(), client_x25519_public_key: self.public_key.to_bytes(), }, )?) .await?; self.inner.flush().await?; trace!("---- !!!!! SENT CLIENT HELLO"); self.state = EpfClientState::WaitingForServerHello; let server_x25519_key; // Step 2: Wait for Server Hello loop { trace!("waiting for server hello"); let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_SERVER_HELLO { self.packet_queue.push(packet); continue; } let server_hello: EpfServerHello = rmp_serde::from_slice(&packet.packet_data)?; self.server_random = server_hello.server_random; if server_hello.protocol_version != PROTOCOL_VERSION { return Err(EpfHandshakeError::UnsupportedProtocolVersion( server_hello.protocol_version as usize, ) .into()); } self.server_cert = Some(server_hello.server_certificate); server_x25519_key = server_hello.server_x25519_public_key; break; } // Step 3: Validate Server Certificate let cert_valid = self.server_cert.as_ref().unwrap().verify(&cert_pool); if let Err(e) = cert_valid { return Err(EpfHandshakeError::InvalidCertificate(e).into()); } if let Ok(false) = cert_valid { return Err(EpfHandshakeError::UntrustedCertificate.into()); } // Server Cert OK // Step 4: Build the cipher let private_key = StaticSecret::from(self.private_key.to_bytes()); let their_public_key = PublicKey::from(server_x25519_key); assert_ne!( their_public_key.to_bytes(), PublicKey::from(&private_key).to_bytes() ); danger_trace!( "pr: {}, their pub: {}, my pub: {}", hex::encode(self.private_key.to_bytes()), hex::encode(self.server_cert.as_ref().unwrap().details.public_key), hex::encode(self.private_key.verifying_key().to_bytes()) ); let shared_key = private_key.diffie_hellman(&their_public_key).to_bytes(); trace!( "server public key: {:x?}", self.server_cert.as_ref().unwrap().details.public_key ); danger_trace!("shared key: {}", hex::encode(shared_key)); let cc20p1305_key = Key::from(shared_key); let cc20p1305 = XChaCha20Poly1305::new(&cc20p1305_key); self.cipher = Some(cc20p1305); let payload = Payload { msg: &[0x42], aad: &self.server_random, }; let nonce = XNonce::from_slice(&self.client_random); trace!("encrypting 0x42"); danger_trace!("aad: {:?} nonce: {:?}", payload.aad, nonce); let encrypted_0x42 = match self.cipher.as_ref().unwrap().encrypt(nonce, payload) { Ok(d) => d, Err(_) => return Err(EpfHandshakeError::EncryptionError.into()), }; self.inner .write_all(&encode_packet( PACKET_FINISHED, &EpfFinished { protocol_version: PROTOCOL_VERSION, encrypted_0x42, }, )?) .await?; self.inner.flush().await?; self.state = EpfClientState::WaitingForFinished; loop { let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_FINISHED { self.packet_queue.push(packet); continue; } let packet_finished: EpfFinished = rmp_serde::from_slice(&packet.packet_data)?; trace!("trying to debug 0x42"); let payload = Payload { msg: &packet_finished.encrypted_0x42, aad: &self.server_random, }; danger_trace!( "ciphertext: {:?}, aad: {:?}, nonce: {:?}", packet_finished.encrypted_0x42, payload.aad, nonce ); let hopefully_0x42 = match self.cipher.as_ref().unwrap().decrypt(nonce, payload) { Ok(d) => d, Err(_) => { return Err(EpfHandshakeError::EncryptionError.into()); } }; if hopefully_0x42 != vec![0x42] { return Err(EpfHandshakeError::MissingKeyProof.into()); } break; } self.state = EpfClientState::Transport; Ok(()) } async fn upgrade(self) -> EpfClientStream where Self: Sized, { let aad = self.server_random; let client_cert = self.client_cert.clone(); let packet_queue = self.packet_queue.clone(); let server_cert = self.server_cert.unwrap(); let cipher = self.cipher.unwrap(); let private_key = self.private_key.clone(); let public_key = self.public_key; let raw_stream = self.inner; EpfClientStream { raw_stream, aad, client_cert, packet_queue, server_cert, cipher, private_key, public_key, } } } #[allow(dead_code)] pub struct EpfClientStream { raw_stream: S, aad: [u8; 16], client_cert: Option, packet_queue: Vec, server_cert: EPFCertificate, cipher: XChaCha20Poly1305, private_key: EpfPrivateKey, public_key: PublicKey, } #[async_trait] pub trait EpfStreamOps { async fn write(&mut self, data: &[u8]) -> Result<(), Box>; async fn read(&mut self) -> Result, Box>; } #[async_trait] impl EpfStreamOps for EpfClientStream { async fn write(&mut self, data: &[u8]) -> Result<(), Box> { let nonce = XChaCha20Poly1305::generate_nonce(OsRng); let payload = Payload { msg: data, aad: &self.aad, }; let ciphertext = match self.cipher.encrypt(&nonce, payload) { Ok(c) => c, Err(_) => return Err(io::Error::new(io::ErrorKind::Other, "Encryption error").into()), }; let application_data = EpfApplicationData { protocol_version: PROTOCOL_VERSION, encrypted_application_data: ciphertext, nonce: nonce.try_into().unwrap(), }; let packet = encode_packet(PACKET_APPLICATION_DATA, &application_data)?; self.raw_stream.write_all(&packet).await?; self.raw_stream.flush().await?; Ok(()) } async fn read(&mut self) -> Result, Box> { loop { let packet = recv_packet(&mut self.raw_stream).await?; if packet.packet_id != PACKET_APPLICATION_DATA { self.packet_queue.push(packet); continue; } let app_data: EpfApplicationData = rmp_serde::from_slice(&packet.packet_data)?; let nonce = XNonce::from_slice(&app_data.nonce); let payload = Payload { msg: &app_data.encrypted_application_data, aad: &self.aad, }; let plaintext = match self.cipher.decrypt(nonce, payload) { Ok(p) => p, Err(_) => { return Err(io::Error::new(io::ErrorKind::Other, "Decryption error").into()) } }; return Ok(plaintext); } } } ///// SERVER ///// pub struct EpfServerUpgraded { inner: T, state: EpfServerState, client_random: [u8; 24], server_random: [u8; 16], client_cert: Option, packet_queue: Vec, cipher: Option, cert: EPFCertificate, private_key: EpfPrivateKey, public_key: EpfPublicKey, } #[async_trait] pub trait EpfServerUpgradable { async fn upgrade( self, cert: EPFCertificate, private_key: EpfPrivateKey, ) -> EpfServerUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send; } #[async_trait] impl EpfServerUpgradable for T where T: AsyncWriteExt + AsyncReadExt + Send, { async fn upgrade( self, cert: EPFCertificate, private_key: EpfPrivateKey, ) -> EpfServerUpgraded where Self: Sized + AsyncWriteExt + AsyncReadExt + Send, { EpfServerUpgraded { inner: self, state: EpfServerState::WaitingForClientHello, server_random: OsRng.gen(), client_random: [0u8; 24], cert, client_cert: None, packet_queue: vec![], cipher: None, private_key: private_key.clone(), public_key: private_key.verifying_key(), } } } #[async_trait] pub trait EpfServerHandshaker { async fn handshake(&mut self, cert_pool: EpfCaPool) -> Result<(), Box>; async fn upgrade(self) -> EpfServerStream where Self: Sized; } #[async_trait] impl EpfServerHandshaker for EpfServerUpgraded { async fn handshake(&mut self, cert_pool: EpfCaPool) -> Result<(), Box> { match self.state { EpfServerState::WaitingForClientHello => (), _ => return Err(EpfHandshakeError::AlreadyTunnelled.into()), } let client_public_key; // Step 1: Wait for Client Hello loop { let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_CLIENT_HELLO { self.packet_queue.push(packet); continue; } trace!("got client hello"); let client_hello: EpfClientHello = rmp_serde::from_slice(&packet.packet_data)?; self.client_random = client_hello.client_random; if client_hello.protocol_version != PROTOCOL_VERSION { return Err(EpfHandshakeError::UnsupportedProtocolVersion( client_hello.protocol_version as usize, ) .into()); } self.client_cert = client_hello.client_certificate; client_public_key = client_hello.client_x25519_public_key; trace!("exiting loop"); break; } // Step 2: Validate Client Certificate (if present) if let Some(client_cert) = &self.client_cert { let cert_valid = client_cert.verify(&cert_pool); if let Err(e) = cert_valid { return Err(EpfHandshakeError::InvalidCertificate(e).into()); } if let Ok(false) = cert_valid { return Err(EpfHandshakeError::UntrustedCertificate.into()); } } // Client Cert OK (if present) trace!("client cert okay"); // Step 3: Send Server Hello self.inner .write_all(&encode_packet( PACKET_SERVER_HELLO, &EpfServerHello { protocol_version: PROTOCOL_VERSION, server_certificate: self.cert.clone(), server_random: self.server_random, server_x25519_public_key: PublicKey::from(&StaticSecret::from( self.private_key.to_bytes(), )) .to_bytes(), }, )?) .await?; self.inner.flush().await?; trace!("sent server hello"); self.state = EpfServerState::WaitingForFinished; // Step 4: Build the cipher let private_key = StaticSecret::from(self.private_key.to_bytes()); let their_public_key = PublicKey::from(client_public_key); assert_ne!( their_public_key.to_bytes(), PublicKey::from(&private_key).to_bytes() ); danger_trace!( "pr: {}, their pub: {}, my pub: {}", hex::encode(self.private_key.to_bytes()), hex::encode(client_public_key), hex::encode(self.private_key.verifying_key().to_bytes()) ); let shared_key = private_key.diffie_hellman(&their_public_key).to_bytes(); trace!("client public key: {:x?}", client_public_key); danger_trace!("shared key: {}", hex::encode(shared_key)); let cc20p1305_key = Key::from(shared_key); let cc20p1305 = XChaCha20Poly1305::new(&cc20p1305_key); self.cipher = Some(cc20p1305); let payload = Payload { msg: &[0x42], aad: &self.server_random, }; let nonce = XNonce::from_slice(&self.client_random); loop { let packet = recv_packet(&mut self.inner).await?; if packet.packet_id != PACKET_FINISHED { self.packet_queue.push(packet); continue; } let packet_finished: EpfFinished = rmp_serde::from_slice(&packet.packet_data)?; let payload = Payload { msg: &packet_finished.encrypted_0x42, aad: &self.server_random, }; trace!("trying to decrypt 0x42"); danger_trace!( "ciphertext: {:?}, nonce: {:?}, aad: {:?}", payload.msg, nonce, payload.aad ); let hopefully_0x42 = match self.cipher.as_ref().unwrap().decrypt(nonce, payload) { Ok(d) => d, Err(_) => { return Err(EpfHandshakeError::EncryptionError.into()); } }; if hopefully_0x42 != vec![0x42] { return Err(EpfHandshakeError::MissingKeyProof.into()); } break; } let encrypted_0x42 = match self.cipher.as_ref().unwrap().encrypt(nonce, payload) { Ok(d) => d, Err(_) => return Err(EpfHandshakeError::EncryptionError.into()), }; self.inner .write_all(&encode_packet( PACKET_FINISHED, &EpfFinished { protocol_version: PROTOCOL_VERSION, encrypted_0x42, }, )?) .await?; self.inner.flush().await?; self.state = EpfServerState::WaitingForFinished; self.state = EpfServerState::Transport; Ok(()) } async fn upgrade(self) -> EpfServerStream where Self: Sized, { EpfServerStream { aad: self.server_random, server_cert: self.cert, packet_queue: self.packet_queue, client_cert: self.client_cert, cipher: self.cipher.unwrap(), private_key: self.private_key, public_key: self.public_key, raw_stream: self.inner, } } } #[allow(dead_code)] pub struct EpfServerStream { raw_stream: S, aad: [u8; 16], client_cert: Option, packet_queue: Vec, server_cert: EPFCertificate, cipher: XChaCha20Poly1305, private_key: EpfPrivateKey, public_key: EpfPublicKey, } #[async_trait] impl EpfStreamOps for EpfServerStream { async fn write(&mut self, data: &[u8]) -> Result<(), Box> { let nonce = XChaCha20Poly1305::generate_nonce(OsRng); let payload = Payload { msg: data, aad: &self.aad, }; let ciphertext = match self.cipher.encrypt(&nonce, payload) { Ok(c) => c, Err(_) => return Err(io::Error::new(io::ErrorKind::Other, "Encryption error").into()), }; let application_data = EpfApplicationData { protocol_version: PROTOCOL_VERSION, encrypted_application_data: ciphertext, nonce: nonce.try_into().unwrap(), }; let packet = encode_packet(PACKET_APPLICATION_DATA, &application_data)?; self.raw_stream.write_all(&packet).await?; self.raw_stream.flush().await?; Ok(()) } async fn read(&mut self) -> Result, Box> { loop { let packet = recv_packet(&mut self.raw_stream).await?; if packet.packet_id != PACKET_APPLICATION_DATA { self.packet_queue.push(packet); continue; } let app_data: EpfApplicationData = rmp_serde::from_slice(&packet.packet_data)?; let nonce = XNonce::from_slice(&app_data.nonce); let payload = Payload { msg: &app_data.encrypted_application_data, aad: &self.aad, }; let plaintext = match self.cipher.decrypt(nonce, payload) { Ok(p) => p, Err(_) => { return Err(io::Error::new(io::ErrorKind::Other, "Decryption error").into()) } }; return Ok(plaintext); } } } #[cfg(test)] mod tests { use crate::ca_pool::{EpfCaPool, EpfCaPoolOps}; use crate::handshake_stream::{ ClientAuthentication, EpfClientHandshaker, EpfClientUpgradable, EpfClientUpgraded, EpfServerHandshaker, EpfServerUpgradable, EpfServerUpgraded, EpfStreamOps, }; use crate::pki::{EPFCertificate, EPFCertificateDetails, EpfPkiCertificateOps}; use ed25519_dalek::{SigningKey}; use log::{debug, trace}; use std::net::SocketAddr; use std::str::FromStr; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::join; use tokio::net::{TcpListener, TcpSocket, TcpStream}; use x25519_dalek::{PublicKey, StaticSecret}; #[tokio::test] pub async fn stream_test() { simple_logger::init().unwrap(); let tcp_listener = TcpListener::bind("0.0.0.0:36116").await.unwrap(); let tcp_client_future = TcpSocket::new_v4() .unwrap() .connect(SocketAddr::from_str("127.0.0.1:36116").unwrap()); let (a, b) = join![tcp_listener.accept(), tcp_client_future]; let (s, _) = a.unwrap(); let c = b.unwrap(); let server_private_key = SigningKey::from([1u8; 32]); let client_private_key = SigningKey::from([2u8; 32]); let mut server_cert = EPFCertificate { details: EPFCertificateDetails { name: "Testing Server Certificate".to_string(), not_before: 0, not_after: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() + 30, public_key: server_private_key.verifying_key().to_bytes(), issuer_public_key: [0u8; 32], claims: Default::default(), }, fingerprint: "".to_string(), signature: [0u8; 64], }; server_cert.sign(&server_private_key).unwrap(); debug!( "{}", hex::encode(server_private_key.verifying_key().to_bytes()) ); let mut client_cert = EPFCertificate { details: EPFCertificateDetails { name: "Testing Client Certificate".to_string(), not_before: 0, not_after: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() + 30, public_key: client_private_key.verifying_key().to_bytes(), issuer_public_key: [0u8; 32], claims: Default::default(), }, fingerprint: "".to_string(), signature: [0u8; 64], }; client_cert.sign(&client_private_key).unwrap(); let mut cert_pool = EpfCaPool::new(); let mut cert_pool_2 = EpfCaPool::new(); cert_pool.insert(&server_cert); cert_pool.insert(&client_cert); cert_pool_2.insert(&client_cert); cert_pool_2.insert(&server_cert); let mut c: EpfClientUpgraded = EpfClientUpgradable::upgrade( c, ClientAuthentication::Cert(Box::new(client_cert), Box::new(client_private_key)), ) .await; let mut s: EpfServerUpgraded = EpfServerUpgradable::upgrade(s, server_cert, server_private_key).await; let server_handshake_accept_task = tokio::spawn(async move { trace!("starting server handshake listener"); s.handshake(cert_pool_2).await.unwrap(); let mut upgraded = s.upgrade().await; assert_eq!(upgraded.read().await.unwrap(), vec![0x42, 0x42]) }); let client_handshake_send_task = tokio::spawn(async move { trace!("starting client handshake sender"); c.handshake(cert_pool).await.unwrap(); let mut upgraded = EpfClientHandshaker::upgrade(c).await; upgraded.write(&[0x42, 0x42]).await.unwrap(); }); let (a, b) = join![server_handshake_accept_task, client_handshake_send_task]; a.unwrap(); b.unwrap(); } #[test] pub fn x25519_sanity_check() { let bob_key = StaticSecret::from([1u8; 32]); let bob_pub = PublicKey::from(&bob_key); let alice_key = StaticSecret::from([2u8; 32]); let alice_pub = PublicKey::from(&alice_key); let ss_1 = bob_key.diffie_hellman(&alice_pub); let ss_2 = alice_key.diffie_hellman(&bob_pub); assert_eq!(ss_1.to_bytes(), ss_2.to_bytes()); println!( "SS: {}, B_p: {}, A_p: {}", hex::encode(ss_1.to_bytes()), hex::encode(bob_pub.to_bytes()), hex::encode(alice_pub.to_bytes()) ); } }