//! Structs to represent a pool of CA's and blacklisted certificates use std::collections::HashMap; use std::error::Error; use std::fmt::{Display, Formatter}; use std::time::SystemTime; use ed25519_dalek::VerifyingKey; use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate}; /// A pool of trusted CA certificates, and certificates that should be blocked. /// This is equivalent to the `pki` section in a typical Nebula config.yml. #[derive(Default)] pub struct NebulaCAPool { /// The list of CA root certificates that should be trusted. pub cas: HashMap, /// The list of blocklisted certificate fingerprints pub cert_blocklist: Vec, /// True if any of the member CAs certificates are expired. Must be handled. pub expired: bool } impl NebulaCAPool { /// Create a new, blank CA pool pub fn new() -> Self { Self::default() } /// Create a new CA pool from a set of PEM encoded CA certificates. /// If any of the certificates are expired, the pool will **still be returned**, with the expired flag set. /// This must be handled properly. /// # Errors /// This function will return an error if PEM data provided was invalid. pub fn new_from_pem(bytes: &[u8]) -> Result> { let pems = pem::parse_many(bytes)?; let mut pool = Self::new(); for cert in pems { match pool.add_ca_certificate(pem::encode(&cert).as_bytes()) { Ok(did_expire) => if did_expire { pool.expired = true }, Err(e) => return Err(e) } } Ok(pool) } /// Add a given CA certificate to the CA pool. If the certificate is expired, it will **still be added** - the return value will be `true` instead of `false` /// # Errors /// This function will return an error if the certificate is invalid in any way. pub fn add_ca_certificate(&mut self, bytes: &[u8]) -> Result> { let cert = deserialize_nebula_certificate_from_pem(bytes)?; if !cert.details.is_ca { return Err(CaPoolError::NotACA.into()) } if !cert.check_signature(&VerifyingKey::from_bytes(&cert.details.public_key)?)? { return Err(CaPoolError::NotSelfSigned.into()) } let fingerprint = cert.sha256sum()?; let expired = cert.expired(SystemTime::now()); if expired { self.expired = true } self.cas.insert(fingerprint, cert); Ok(expired) } /// Blocklist the given certificate in the CA pool pub fn blocklist_fingerprint(&mut self, fingerprint: &str) { self.cert_blocklist.push(fingerprint.to_string()); } /// Clears the list of blocklisted fingerprints pub fn reset_blocklist(&mut self) { self.cert_blocklist = vec![]; } /// Checks if the given certificate is blocklisted pub fn is_blocklisted(&self, cert: &NebulaCertificate) -> bool { let Ok(h) = cert.sha256sum() else { return false }; self.cert_blocklist.contains(&h) } /// Gets the CA certificate used to sign the given certificate /// # Errors /// This function will return an error if the certificate does not have an issuer attached (it is self-signed) pub fn get_ca_for_cert(&self, cert: &NebulaCertificate) -> Result, Box> { if cert.details.issuer == String::new() { return Err(CaPoolError::NoIssuer.into()) } Ok(self.cas.get(&cert.details.issuer)) } /// Get a list of trusted CA fingerprints pub fn get_fingerprints(&self) -> Vec<&String> { self.cas.keys().collect() } } #[derive(Debug)] /// A list of errors that can happen when working with a CA Pool pub enum CaPoolError { /// Tried to add a non-CA cert to the CA pool NotACA, /// Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates) NotSelfSigned, /// Tried to look up a certificate that does not have an issuer field NoIssuer } impl Error for CaPoolError {} #[cfg(not(tarpaulin_include))] impl Display for CaPoolError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::NotACA => write!(f, "Tried to add a non-CA cert to the CA pool"), Self::NotSelfSigned => write!(f, "Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates)"), Self::NoIssuer => write!(f, "Tried to look up a certificate with a null issuer field") } } }