trifid/trifid-pki/src/ca.rs

140 lines
4.9 KiB
Rust

//! Structs to represent a pool of trusted CAs and blocklisted certificates
use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate};
use ed25519_dalek::VerifyingKey;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::time::SystemTime;
#[cfg(feature = "serde_derive")]
use serde::{Deserialize, Serialize};
/// A set of trusted CA root certificates together with a blocklist of certificate
/// fingerprints that must be rejected.
///
/// This corresponds to the `pki` section of a typical Nebula `config.yml`.
#[derive(Clone, Default)]
#[cfg_attr(feature = "serde_derive", derive(Serialize, Deserialize))]
pub struct NebulaCAPool {
    /// Trusted CA root certificates. When populated through `add_ca_certificate`,
    /// each entry is keyed by the certificate's `sha256sum()` fingerprint.
    pub cas: HashMap<String, NebulaCertificate>,
    /// Fingerprints of certificates that must be treated as blocked.
    pub cert_blocklist: Vec<String>,
    /// Set to `true` once any CA added to the pool was found to be expired.
    /// Callers are expected to check and handle this flag.
    pub expired: bool,
}
impl NebulaCAPool {
    /// Create a new, empty CA pool: no trusted CAs, an empty blocklist, and `expired` unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a new CA pool from a set of PEM encoded CA certificates.
    ///
    /// Every PEM block in `bytes` is added with [`NebulaCAPool::add_ca_certificate`].
    /// If any of the certificates are expired, the pool will **still be returned**, with the
    /// `expired` flag set. This must be handled properly.
    ///
    /// # Errors
    /// Returns an error if the PEM data provided was invalid, or if any certificate is rejected
    /// by [`NebulaCAPool::add_ca_certificate`]. Processing stops at the first such error.
    pub fn new_from_pem(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
        let mut pool = Self::new();
        for pem_block in pem::parse_many(bytes)? {
            // `add_ca_certificate` already sets `pool.expired` when the certificate is
            // expired, so its boolean result does not need to be inspected here.
            pool.add_ca_certificate(pem::encode(&pem_block).as_bytes())?;
        }
        Ok(pool)
    }

    /// Add a given CA certificate to the CA pool.
    ///
    /// If the certificate is expired, it will **still be added**. In that case the return value
    /// is `true` instead of `false`, and the pool's `expired` flag is set.
    ///
    /// # Errors
    /// Returns an error if the certificate cannot be deserialized, is not marked as a CA
    /// ([`CaPoolError::NotACA`]), does not verify against its own public key
    /// ([`CaPoolError::NotSelfSigned`]), or if key parsing, signature checking or
    /// fingerprinting fails.
    pub fn add_ca_certificate(&mut self, bytes: &[u8]) -> Result<bool, Box<dyn Error>> {
        let cert = deserialize_nebula_certificate_from_pem(bytes)?;

        if !cert.details.is_ca {
            return Err(CaPoolError::NotACA.into());
        }

        // Only root CAs are accepted: the signature must verify against the cert's own key.
        let own_key = VerifyingKey::from_bytes(&cert.details.public_key)?;
        if !cert.check_signature(&own_key)? {
            return Err(CaPoolError::NotSelfSigned.into());
        }

        let fingerprint = cert.sha256sum()?;
        let expired = cert.expired(SystemTime::now());
        // The pool-level flag is sticky: once any member CA is expired it stays set.
        self.expired |= expired;
        self.cas.insert(fingerprint, cert);
        Ok(expired)
    }

    /// Blocklist the certificate with the given fingerprint.
    ///
    /// Adding a fingerprint that is already blocklisted is a no-op, so repeated calls
    /// (e.g. on config reload) do not grow the list.
    pub fn blocklist_fingerprint(&mut self, fingerprint: &str) {
        if !self.cert_blocklist.iter().any(|f| f.as_str() == fingerprint) {
            self.cert_blocklist.push(fingerprint.to_string());
        }
    }

    /// Clears the list of blocklisted fingerprints.
    pub fn reset_blocklist(&mut self) {
        self.cert_blocklist.clear();
    }

    /// Checks if the given certificate's fingerprint is on the blocklist.
    ///
    /// NOTE(review): if the certificate's fingerprint cannot be computed, this returns `false`
    /// (the certificate is treated as not blocklisted).
    pub fn is_blocklisted(&self, cert: &NebulaCertificate) -> bool {
        let Ok(fingerprint) = cert.sha256sum() else {
            return false;
        };
        self.cert_blocklist.contains(&fingerprint)
    }

    /// Gets the CA certificate whose fingerprint matches the given certificate's issuer.
    ///
    /// Returns `Ok(None)` if no CA in the pool has that fingerprint.
    ///
    /// # Errors
    /// Returns [`CaPoolError::NoIssuer`] if the certificate's issuer field is empty
    /// (as is the case for a self-signed certificate).
    pub fn get_ca_for_cert(
        &self,
        cert: &NebulaCertificate,
    ) -> Result<Option<&NebulaCertificate>, Box<dyn Error>> {
        if cert.details.issuer.is_empty() {
            return Err(CaPoolError::NoIssuer.into());
        }
        Ok(self.cas.get(&cert.details.issuer))
    }

    /// Get a list of trusted CA fingerprints (the keys of `cas`), in no particular order.
    pub fn get_fingerprints(&self) -> Vec<&String> {
        self.cas.keys().collect()
    }
}
/// Errors that can happen when building or querying a CA pool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde_derive", derive(Serialize, Deserialize))]
pub enum CaPoolError {
    /// Tried to add a certificate that is not marked as a CA to the CA pool.
    NotACA,
    /// Tried to add a certificate whose signature does not verify against its own public key
    /// (all CAs in the pool must be self-signed root certificates).
    NotSelfSigned,
    /// Tried to look up the CA for a certificate whose issuer field is empty.
    NoIssuer,
}

impl Error for CaPoolError {}

#[cfg(not(tarpaulin_include))]
impl Display for CaPoolError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotACA => "Tried to add a non-CA cert to the CA pool",
            Self::NotSelfSigned => "Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates)",
            Self::NoIssuer => "Tried to look up a certificate with a null issuer field",
        };
        f.write_str(message)
    }
}