some work on config generation

This commit is contained in:
core 2023-12-25 23:23:57 -05:00
parent f7de7ff592
commit 13dec2963c
Signed by: core
GPG Key ID: FDBF740DADDCEECF
12 changed files with 830 additions and 7 deletions

10
Cargo.lock generated
View File

@ -1831,6 +1831,14 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "nebula-config"
version = "0.1.0"
dependencies = [
"ipnet",
"serde",
]
[[package]]
name = "nebula-ffi"
version = "1.7.3"
@ -3088,8 +3096,10 @@ dependencies = [
"dnapi-rs",
"env_logger",
"hex",
"ipnet",
"log",
"mail-send",
"nebula-config",
"rand",
"serde",
"serde_json",

View File

@ -5,7 +5,7 @@ members = [
"dnapi-rs",
"tfcli",
"nebula-ffi",
"nebula-config",
"trifid-api",
"trifid-api-derive"
]

10
nebula-config/Cargo.toml Normal file
View File

@ -0,0 +1,10 @@
[package]
name = "nebula-config"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serde = { version = "1", features = ["derive"] }
ipnet = { version = "2.9", features = ["serde"] }

532
nebula-config/src/lib.rs Normal file
View File

@ -0,0 +1,532 @@
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddrV4};
use ipnet::{IpNet, Ipv4Net};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfig {
pub pki: NebulaConfigPki,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub static_host_map: HashMap<Ipv4Addr, Vec<SocketAddrV4>>,
#[serde(skip_serializing_if = "is_none")]
pub lighthouse: Option<NebulaConfigLighthouse>,
#[serde(skip_serializing_if = "is_none")]
pub listen: Option<NebulaConfigListen>,
#[serde(skip_serializing_if = "is_none")]
pub punchy: Option<NebulaConfigPunchy>,
#[serde(default = "cipher_aes")]
#[serde(skip_serializing_if = "is_cipher_aes")]
pub cipher: NebulaConfigCipher,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub preferred_ranges: Vec<IpNet>,
#[serde(skip_serializing_if = "is_none")]
pub relay: Option<NebulaConfigRelay>,
#[serde(skip_serializing_if = "is_none")]
pub tun: Option<NebulaConfigTun>,
#[serde(skip_serializing_if = "is_none")]
pub logging: Option<NebulaConfigLogging>,
#[serde(skip_serializing_if = "is_none")]
pub sshd: Option<NebulaConfigSshd>,
#[serde(skip_serializing_if = "is_none")]
pub firewall: Option<NebulaConfigFirewall>,
#[serde(default = "u64_1")]
#[serde(skip_serializing_if = "is_u64_1")]
pub routines: u64,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub stats: Option<NebulaConfigStats>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub local_range: Option<Ipv4Net>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigPki {
pub ca: String,
pub cert: String,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub key: Option<String>,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub blocklist: Vec<String>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disconnect_invalid: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLighthouse {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub am_lighthouse: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub serve_dns: bool,
#[serde(skip_serializing_if = "is_none")]
pub dns: Option<NebulaConfigLighthouseDns>,
#[serde(default = "u32_10")]
#[serde(skip_serializing_if = "is_u32_10")]
pub interval: u32,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub hosts: Vec<Ipv4Addr>,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub remote_allow_list: HashMap<Ipv4Net, bool>,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub local_allow_list: HashMap<Ipv4Net, bool>, // `interfaces` is not supported
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLighthouseDns {
#[serde(default = "string_empty")]
#[serde(skip_serializing_if = "is_string_empty")]
pub host: String,
#[serde(default = "u16_53")]
#[serde(skip_serializing_if = "is_u16_53")]
pub port: u16,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigListen {
#[serde(default = "string_empty")]
#[serde(skip_serializing_if = "is_string_empty")]
pub host: String,
#[serde(default = "u16_0")]
#[serde(skip_serializing_if = "is_u16_0")]
pub port: u16,
#[serde(default = "u32_64")]
#[serde(skip_serializing_if = "is_u32_64")]
pub batch: u32,
#[serde(skip_serializing_if = "is_none")]
pub read_buffer: Option<u32>,
#[serde(skip_serializing_if = "is_none")]
pub write_buffer: Option<u32>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigPunchy {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub punch: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub respond: bool,
#[serde(default = "string_1s")]
#[serde(skip_serializing_if = "is_string_1s")]
pub delay: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigCipher {
#[serde(rename = "aes")]
Aes,
#[serde(rename = "chachapoly")]
ChaChaPoly,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigRelay {
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub relays: Vec<Ipv4Addr>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub am_relay: bool,
#[serde(default = "bool_true")]
#[serde(skip_serializing_if = "is_bool_true")]
pub use_relays: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTun {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disabled: bool,
#[serde(skip_serializing_if = "is_none")]
pub dev: Option<String>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub drop_local_broadcast: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub drop_multicast: bool,
#[serde(default = "u64_500")]
#[serde(skip_serializing_if = "is_u64_500")]
pub tx_queue: u64,
#[serde(default = "u64_1300")]
#[serde(skip_serializing_if = "is_u64_1300")]
pub mtu: u64,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub routes: Vec<NebulaConfigTunRouteOverride>,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub unsafe_routes: Vec<NebulaConfigTunUnsafeRoute>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunRouteOverride {
pub mtu: u64,
pub route: Ipv4Net,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunUnsafeRoute {
pub route: Ipv4Net,
pub via: Ipv4Addr,
#[serde(default = "u64_1300")]
#[serde(skip_serializing_if = "is_u64_1300")]
pub mtu: u64,
#[serde(default = "i64_100")]
#[serde(skip_serializing_if = "is_i64_100")]
pub metric: i64,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLogging {
#[serde(default = "loglevel_info")]
#[serde(skip_serializing_if = "is_loglevel_info")]
pub level: NebulaConfigLoggingLevel,
#[serde(default = "format_text")]
#[serde(skip_serializing_if = "is_format_text")]
pub format: NebulaConfigLoggingFormat,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disable_timestamp: bool,
#[serde(default = "timestamp")]
#[serde(skip_serializing_if = "is_timestamp")]
pub timestamp_format: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigLoggingLevel {
#[serde(rename = "panic")]
Panic,
#[serde(rename = "fatal")]
Fatal,
#[serde(rename = "error")]
Error,
#[serde(rename = "warning")]
Warning,
#[serde(rename = "info")]
Info,
#[serde(rename = "debug")]
Debug,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigLoggingFormat {
#[serde(rename = "json")]
Json,
#[serde(rename = "text")]
Text,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigSshd {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub enabled: bool,
pub listen: SocketAddrV4,
pub host_key: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub authorized_users: Vec<NebulaConfigSshdAuthorizedUser>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigSshdAuthorizedUser {
pub user: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub keys: Vec<String>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "type")]
pub enum NebulaConfigStats {
#[serde(rename = "graphite")]
Graphite(NebulaConfigStatsGraphite),
#[serde(rename = "prometheus")]
Prometheus(NebulaConfigStatsPrometheus),
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigStatsGraphite {
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub prefix: String,
#[serde(default = "protocol_tcp")]
#[serde(skip_serializing_if = "is_protocol_tcp")]
pub protocol: NebulaConfigStatsGraphiteProtocol,
pub host: SocketAddrV4,
pub interval: String,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigStatsGraphiteProtocol {
#[serde(rename = "tcp")]
Tcp,
#[serde(rename = "udp")]
Udp,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigStatsPrometheus {
pub listen: String,
pub path: String,
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub namespace: String,
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub subsystem: String,
pub interval: String,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewall {
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub conntrack: Option<NebulaConfigFirewallConntrack>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub inbound: Option<Vec<NebulaConfigFirewallRule>>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub outbound: Option<Vec<NebulaConfigFirewallRule>>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewallConntrack {
#[serde(default = "string_12m")]
#[serde(skip_serializing_if = "is_string_12m")]
pub tcp_timeout: String,
#[serde(default = "string_3m")]
#[serde(skip_serializing_if = "is_string_3m")]
pub udp_timeout: String,
#[serde(default = "string_10m")]
#[serde(skip_serializing_if = "is_string_10m")]
pub default_timeout: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewallRule {
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub port: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub proto: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub ca_name: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub ca_sha: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub host: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub group: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub groups: Option<Vec<String>>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub cidr: Option<String>,
}
// Default values for serde
fn string_12m() -> String {
"12m".to_string()
}
fn is_string_12m(s: &str) -> bool {
s == "12m"
}
fn string_3m() -> String {
"3m".to_string()
}
fn is_string_3m(s: &str) -> bool {
s == "3m"
}
fn string_10m() -> String {
"10m".to_string()
}
fn is_string_10m(s: &str) -> bool {
s == "10m"
}
fn empty_vec<T>() -> Vec<T> {
vec![]
}
fn is_empty_vec<T>(v: &Vec<T>) -> bool {
v.is_empty()
}
fn empty_hashmap<A, B>() -> HashMap<A, B> {
HashMap::new()
}
fn is_empty_hashmap<A, B>(h: &HashMap<A, B>) -> bool {
h.is_empty()
}
fn bool_false() -> bool {
false
}
fn is_bool_false(b: &bool) -> bool {
!*b
}
fn bool_true() -> bool {
true
}
fn is_bool_true(b: &bool) -> bool {
*b
}
fn u16_53() -> u16 {
53
}
fn is_u16_53(u: &u16) -> bool {
*u == 53
}
fn u32_10() -> u32 {
10
}
fn is_u32_10(u: &u32) -> bool {
*u == 10
}
fn u16_0() -> u16 {
0
}
fn is_u16_0(u: &u16) -> bool {
*u == 0
}
fn u32_64() -> u32 {
64
}
fn is_u32_64(u: &u32) -> bool {
*u == 64
}
fn string_1s() -> String {
"1s".to_string()
}
fn is_string_1s(s: &str) -> bool {
s == "1s"
}
fn cipher_aes() -> NebulaConfigCipher {
NebulaConfigCipher::Aes
}
fn is_cipher_aes(c: &NebulaConfigCipher) -> bool {
matches!(c, NebulaConfigCipher::Aes)
}
fn u64_500() -> u64 {
500
}
fn is_u64_500(u: &u64) -> bool {
*u == 500
}
fn u64_1300() -> u64 {
1300
}
fn is_u64_1300(u: &u64) -> bool {
*u == 1300
}
fn i64_100() -> i64 {
100
}
fn is_i64_100(i: &i64) -> bool {
*i == 100
}
fn loglevel_info() -> NebulaConfigLoggingLevel {
NebulaConfigLoggingLevel::Info
}
fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool {
matches!(l, NebulaConfigLoggingLevel::Info)
}
fn format_text() -> NebulaConfigLoggingFormat {
NebulaConfigLoggingFormat::Text
}
fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool {
matches!(f, NebulaConfigLoggingFormat::Text)
}
fn timestamp() -> String {
"2006-01-02T15:04:05Z07:00".to_string()
}
fn is_timestamp(s: &str) -> bool {
s == "2006-01-02T15:04:05Z07:00"
}
fn u64_1() -> u64 {
1
}
fn is_u64_1(u: &u64) -> bool {
*u == 1
}
fn string_nebula() -> String {
"nebula".to_string()
}
fn is_string_nebula(s: &str) -> bool {
s == "nebula"
}
fn string_empty() -> String {
String::new()
}
fn is_string_empty(s: &str) -> bool {
s.is_empty()
}
fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol {
NebulaConfigStatsGraphiteProtocol::Tcp
}
fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool {
matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp)
}
fn none<T>() -> Option<T> {
None
}
fn is_none<T>(o: &Option<T>) -> bool {
o.is_none()
}

View File

@ -31,4 +31,6 @@ chacha20poly1305 = "0.10"
hex = "0.4"
thiserror = "1"
chrono = "0.4"
dnapi-rs = { version = "0.2", path = "../dnapi-rs" }
dnapi-rs = { version = "0.2", path = "../dnapi-rs" }
nebula-config = { version = "0.1", path = "../nebula-config" }
ipnet = "2.9"

View File

@ -50,3 +50,6 @@ auth_token_expiry_seconds = 86400 # 24 hours
# It is INCREDIBLY IMPORTANT that you change this value! It should be a 32-byte/256-bit hex-encoded randomly generated
# key.
data_encryption_key = "dd5aa62f0fd9b7fb4ff65567493f889557212f3a8e9587a79268161f9ae070a6"
# (Required) How long should client certs be valid for?
# Clients will require a config update after this period.
cert_expiry_time_seconds = 31536000 # ~1 year

View File

@ -44,4 +44,5 @@ pub struct ConfigTokens {
pub session_token_expiry_seconds: u64,
pub auth_token_expiry_seconds: u64,
pub data_encryption_key: String,
pub cert_expiry_time_seconds: u64
}

View File

@ -0,0 +1,102 @@
// The main event! This file handles generation of actual config files.
// This file is part of the sensitive stack, along with `crypt.rs`
// and the entirety of trifid-pki, as it deals with CA private keys.
// Review carefully what you write here!
use std::error::Error;
use std::net::Ipv4Addr;
use std::str::FromStr;
use std::time::{Duration, SystemTime};
use actix_web::web::Data;
use diesel::{ExpressionMethods, QueryDsl, SelectableHelper};
use diesel_async::pooled_connection::bb8::RunError;
use diesel_async::RunQueryDsl;
use ipnet::Ipv4Net;
use thiserror::Error;
use nebula_config::{NebulaConfig, NebulaConfigPki};
use trifid_pki::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate, NebulaCertificateDetails};
use trifid_pki::x25519_dalek::PublicKey;
use crate::AppState;
use crate::crypt::sign_cert_with_ca;
use crate::models::{Host, Network, SigningCA, HostKey};
use crate::schema::{host_keys, hosts, networks, signing_cas};
#[derive(Error, Debug)]
pub enum ConfigGenError {
#[error("error acquiring connection from pool: {0}")]
AcquireError(RunError),
#[error("error in database: {0}")]
DbError(diesel::result::Error),
#[error("error parsing a signing CA: {0}")]
InvalidCACert(serde_json::Error),
#[error("an error occured: {0}")]
GenericError(Box<dyn Error>)
}
pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data<AppState>) -> Result<NebulaConfig, ConfigGenError> {
let mut conn = state.pool.get().await.map_err(ConfigGenError::AcquireError)?;
let cas = signing_cas::dsl::signing_cas.filter(signing_cas::organization_id.eq(&host.organization_id)).select(SigningCA::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?;
let mut good_cas = vec![];
let mut ca_string = String::new();
for ca in cas {
if ca.expires_at < SystemTime::now() {
let ca_cert: NebulaCertificate = serde_json::from_value(ca.cert.clone()).map_err(ConfigGenError::InvalidCACert)?;
good_cas.push((ca, ca_cert));
ca_string += &String::from_utf8_lossy(&ca_cert.serialize_to_pem().map_err(ConfigGenError::GenericError)?);
}
}
let (signing_ca, ca_cert) = &good_cas[0];
let network = networks::dsl::networks.find(&host.network_id).first::<Network>(&mut conn).await.map_err(ConfigGenError::DbError)?;
let mut cert = NebulaCertificate {
details: NebulaCertificateDetails {
name: host.name.clone(),
ips: vec![
Ipv4Net::new(
Ipv4Addr::from_str(&host.ip_address).unwrap(),
Ipv4Net::from_str(&network.cidr).unwrap().prefix_len()
).unwrap()
],
subnets: vec![],
groups: if let Some(role_id) = &host.role_id {
vec![format!("role:{}", role_id)]
} else { vec![] },
not_before: SystemTime::now() - Duration::from_secs(3600),
not_after: SystemTime::now() + Duration::from_secs(state.config.tokens.cert_expiry_time_seconds),
public_key: *dh_pubkey.as_bytes(),
is_ca: false,
issuer: ca_cert.sha256sum().unwrap(),
},
signature: vec![],
};
sign_cert_with_ca(signing_ca, &mut cert, &state.config).unwrap();
let all_blocked_hosts = hosts::dsl::hosts.filter(hosts::network_id.eq(&host.network_id)).filter(hosts::is_blocked.eq(true)).select(Host::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?;
let mut all_blocked_fingerprints = vec![];
for blocked_host in all_blocked_hosts {
let hosts_blocked_key_entries = host_keys::dsl::host_keys.filter(host_keys::host_id.eq(&blocked_host.id)).select(HostKey::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?;
for blocked_key in hosts_blocked_key_entries {
let cert = deserialize_nebula_certificate_from_pem(&blocked_key.client_cert).unwrap();
let fingerprint = cert.sha256sum().unwrap();
all_blocked_fingerprints.push(fingerprint);
}
}
let pki = NebulaConfigPki {
ca: ca_string,
cert: String::from_utf8(cert.serialize_to_pem().unwrap()).unwrap(),
key: None,
blocklist: all_blocked_fingerprints,
disconnect_invalid: true,
};
todo!()
}

View File

@ -2,7 +2,7 @@ use crate::config::Config;
use crate::models::SigningCA;
use actix_web::cookie::time::Duration;
use chacha20poly1305::aead::{Aead, Payload};
use chacha20poly1305::{AeadCore, KeyInit, Nonce, XChaCha20Poly1305, XNonce};
use chacha20poly1305::{AeadCore, KeyInit, XChaCha20Poly1305, XNonce};
use log::error;
use rand::rngs::OsRng;
use rand::Rng;
@ -10,7 +10,7 @@ use std::error::Error;
use std::time::SystemTime;
use thiserror::Error;
use trifid_pki::cert::{NebulaCertificate, NebulaCertificateDetails};
use trifid_pki::ed25519_dalek::{SignatureError, SigningKey};
use trifid_pki::ed25519_dalek::{Signature, SignatureError, Signer, SigningKey, VerifyingKey};
#[derive(Error, Debug)]
pub enum CryptographyError {
@ -139,3 +139,80 @@ pub fn sign_cert_with_ca(
cert.sign(&key)
.map_err(|e| CryptographyError::CertificateSigningError(e))
}
pub struct DnclientKeyLockbox {
pub info: Vec<u8>,
pub nonce: Vec<u8>,
pub key: Vec<u8>
}
pub fn create_dnclient_ed_key(
config: &Config,
) -> Result<(DnclientKeyLockbox, VerifyingKey), CryptographyError> {
let key = SigningKey::generate(&mut OsRng);
let lockbox_key = XChaCha20Poly1305::new_from_slice(
&hex::decode(&config.tokens.data_encryption_key)
.map_err(|e| CryptographyError::InvalidKey(e))?,
)
.map_err(|_| CryptographyError::InvalidKeyLength)?;
let salt = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let aad: [u8; 16] = OsRng.gen();
let lockbox = lockbox_key
.encrypt(
&salt,
Payload {
msg: &key.to_keypair_bytes(),
aad: &aad,
},
)
.map_err(|e| CryptographyError::LockingError)?;
Ok((DnclientKeyLockbox {
info: aad.to_vec(),
nonce: salt.as_slice().to_vec(),
key: lockbox,
}, key.verifying_key()))
}
pub fn sign_dnclient_with_lockbox(
lockbox: &DnclientKeyLockbox,
bytes: &[u8],
config: &Config,
) -> Result<Signature, CryptographyError> {
let lockbox_key = XChaCha20Poly1305::new_from_slice(
&hex::decode(&config.tokens.data_encryption_key)
.map_err(|e| CryptographyError::InvalidKey(e))?,
)
.map_err(|_| CryptographyError::InvalidKeyLength)?;
let salt_u24: [u8; 24] = lockbox
.nonce
.clone()
.try_into()
.map_err(|_| CryptographyError::InvalidSaltLength)?;
let salt = XNonce::from(salt_u24);
let plaintext = lockbox_key
.decrypt(
&salt,
Payload {
msg: &lockbox.key,
aad: &lockbox.info,
},
)
.map_err(|_| CryptographyError::DecryptFailed)?;
let key = SigningKey::from_keypair_bytes(
&plaintext
.try_into()
.map_err(|_| CryptographyError::InvalidSigningKeyLength)?,
)
.map_err(|e| CryptographyError::SignatureError(e))?;
Ok(key.sign(bytes))
}

View File

@ -26,7 +26,8 @@ pub mod auth;
pub mod email;
#[macro_use]
pub mod macros;
pub mod ca;
pub mod crypt;
mod config_generator;
#[derive(Clone)]
pub struct AppState {

View File

@ -1,4 +1,4 @@
use crate::ca::create_signing_ca;
use crate::crypt::create_signing_ca;
use crate::models::{Network, NetworkNormalized, Organization, SigningCA, User};
use crate::response::JsonAPIResponse;
use crate::schema::networks::dsl::networks;

View File

@ -1,7 +1,17 @@
use std::time::SystemTime;
use crate::response::JsonAPIResponse;
use crate::AppState;
use crate::{AppState, randid};
use actix_web::web::{Data, Json};
use actix_web::http::StatusCode;
use dnapi_rs::message::{EnrollRequest, EnrollResponse};
use crate::models::{EnrollmentCode, Host, HostKey};
use crate::schema::{enrollment_codes, hosts};
use diesel::{QueryDsl, OptionalExtension, ExpressionMethods};
use diesel_async::RunQueryDsl;
use trifid_pki::x25519_dalek::PublicKey;
use crate::config_generator::generate_config;
use crate::crypt::create_dnclient_ed_key;
use crate::schema::host_keys;
pub async fn enroll_req(
req: Json<EnrollRequest>,
@ -9,5 +19,80 @@ pub async fn enroll_req(
) -> JsonAPIResponse<EnrollResponse> {
let mut conn = handle_error!(state.pool.get().await);
let token: EnrollmentCode = match handle_error!(enrollment_codes::table
.find(&req.code)
.first::<EnrollmentCode>(&mut conn)
.await
.optional())
{
Some(t) => t,
None => {
err!(
StatusCode::BAD_REQUEST,
make_err!(
"ERR_INVALID_VALUE",
"does not exist (maybe it expired?)",
"magicLinkToken"
)
)
}
};
if token.expires < SystemTime::now() {
err!(
StatusCode::BAD_REQUEST,
make_err!(
"ERR_INVALID_VALUE",
"does not exist (maybe it expired?)",
"magicLinkToken"
)
);
}
// valid token
let host: Host = match handle_error!(hosts::table
.find(&token.host_id)
.first::<Host>(&mut conn)
.await
.optional())
{
Some(t) => t,
None => {
err!(
StatusCode::BAD_REQUEST,
make_err!(
"ERR_INVALID_VALUE",
"does not exist (maybe it expired?)",
"magicLinkToken.host_id"
)
)
}
};
handle_error!(diesel::delete(&token).execute(&mut conn).await);
// reset the host's key entries
handle_error!(diesel::delete(host_keys::dsl::host_keys.filter(host_keys::host_id.eq(&host.id))).execute(&mut conn).await);
let (key_lockbox, trusted_key) = handle_error!(create_dnclient_ed_key(&state.config));
let user_dh_key = PublicKey::from(req.dh_pubkey.try_into().unwrap()).unwrap();
let config = handle_error!(generate_config(&host, state.clone()).await);
let new_counter_1 = HostKey {
id: randid!(id "hostkey"),
host_id: host.id.clone(),
counter: 1,
client_ed_pub: vec![],
client_dh_pub: vec![],
client_cert: vec![],
salt: vec![],
info: vec![],
server_ed_priv: vec![],
};
todo!()
}