// trifid-api, an open source reimplementation of the Defined Networking nebula management server.
// Copyright (C) 2023 c0repwn3r
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use std::fmt::{Display, Formatter};
use crate::format::PEMValidationError::{IncorrectSegmentLength, InvalidBase64Data, MissingStartSentinel};
use crate::util::base64decode;
/// Opening sentinel line that `validate_ed_pubkey_pem` expects as the first segment.
pub const ED_PUBKEY_START_STR: &str = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----";
/// Closing sentinel line that `validate_ed_pubkey_pem` expects as the third segment.
pub const ED_PUBKEY_END_STR: &str = "-----END NEBULA ED25519 PUBLIC KEY-----";
/// Opening sentinel line that `validate_dh_pubkey_pem` expects as the first segment.
pub const DH_PUBKEY_START_STR: &str = "-----BEGIN NEBULA X25519 PUBLIC KEY-----";
/// Closing sentinel line that `validate_dh_pubkey_pem` expects as the third segment.
pub const DH_PUBKEY_END_STR: &str = "-----END NEBULA X25519 PUBLIC KEY-----";
/// Reasons a PEM-encoded public key can be rejected by the validators in this module.
///
/// Implements [`Display`] for human-readable messages and [`std::error::Error`] so it
/// can be boxed, propagated with `?` into generic error types, and printed with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PEMValidationError {
    /// The opening sentinel line was missing or did not match.
    MissingStartSentinel,
    /// The payload could not be decoded (invalid base64, or decoded bytes that are not UTF-8).
    InvalidBase64Data,
    /// The closing sentinel line was missing or did not match.
    MissingEndSentinel,
    /// The input had too few newline-separated segments: `(expected, got)`.
    IncorrectSegmentLength(usize, usize),
}

impl Display for PEMValidationError {
    /// Writes a short, human-readable description of the validation failure.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingEndSentinel => write!(f, "Missing ending sentinel"),
            Self::MissingStartSentinel => write!(f, "Missing starting sentinel"),
            Self::InvalidBase64Data => write!(f, "invalid base64 data"),
            Self::IncorrectSegmentLength(expected, got) => write!(
                f,
                "incorrect number of segments, expected {} got {}",
                expected, got
            ),
        }
    }
}

// No `source()` override: none of the variants wrap an underlying error.
impl std::error::Error for PEMValidationError {}
pub fn validate_ed_pubkey_pem(pubkey: &str) -> Result<(), PEMValidationError> {
let segments = pubkey.split('\n');
let segs = segments.collect::>();
if segs.len() < 3 {
return Err(IncorrectSegmentLength(3, segs.len()))
}
if segs[0] != ED_PUBKEY_START_STR {
return Err(MissingStartSentinel)
}
if base64decode(segs[1]).is_err() {
return Err(InvalidBase64Data)
}
if segs[2] != ED_PUBKEY_END_STR {
return Err(MissingStartSentinel)
}
Ok(())
}
pub fn validate_dh_pubkey_pem(pubkey: &str) -> Result<(), PEMValidationError> {
let segments = pubkey.split('\n');
let segs = segments.collect::>();
if segs.len() < 3 {
return Err(IncorrectSegmentLength(3, segs.len()))
}
if segs[0] != DH_PUBKEY_START_STR {
return Err(MissingStartSentinel)
}
if base64decode(segs[1]).is_err() {
return Err(InvalidBase64Data)
}
if segs[2] != DH_PUBKEY_END_STR {
return Err(MissingStartSentinel)
}
Ok(())
}
/// Validates an Ed25519 public key whose PEM text has itself been base64-encoded.
///
/// `pubkey` is decoded with `base64decode`, the resulting bytes are interpreted as
/// UTF-8, and that text is handed to [`validate_ed_pubkey_pem`].
///
/// # Errors
///
/// Returns `InvalidBase64Data` if the outer decode fails or the decoded bytes are not
/// valid UTF-8; otherwise returns whatever [`validate_ed_pubkey_pem`] reports.
pub fn validate_ed_pubkey_base64(pubkey: &str) -> Result<(), PEMValidationError> {
    let decoded = base64decode(pubkey).map_err(|_| InvalidBase64Data)?;
    let pem_text = std::str::from_utf8(decoded.as_ref()).map_err(|_| InvalidBase64Data)?;
    validate_ed_pubkey_pem(pem_text)
}
/// Validates an X25519 public key whose PEM text has itself been base64-encoded.
///
/// `pubkey` is decoded with `base64decode`, the resulting bytes are interpreted as
/// UTF-8, and that text is handed to [`validate_dh_pubkey_pem`].
///
/// # Errors
///
/// Returns `InvalidBase64Data` if the outer decode fails or the decoded bytes are not
/// valid UTF-8; otherwise returns whatever [`validate_dh_pubkey_pem`] reports.
pub fn validate_dh_pubkey_base64(pubkey: &str) -> Result<(), PEMValidationError> {
    let decoded = base64decode(pubkey).map_err(|_| InvalidBase64Data)?;
    let pem_text = std::str::from_utf8(decoded.as_ref()).map_err(|_| InvalidBase64Data)?;
    validate_dh_pubkey_pem(pem_text)
}