From 13dec2963ced55a81a979625f01042f202dd25b6 Mon Sep 17 00:00:00 2001 From: core Date: Mon, 25 Dec 2023 23:23:57 -0500 Subject: [PATCH] some work on config generation --- Cargo.lock | 10 + Cargo.toml | 2 +- nebula-config/Cargo.toml | 10 + nebula-config/src/lib.rs | 532 +++++++++++++++++++++++++++ trifid-api/Cargo.toml | 4 +- trifid-api/config.toml | 3 + trifid-api/src/config.rs | 1 + trifid-api/src/config_generator.rs | 102 +++++ trifid-api/src/{ca.rs => crypt.rs} | 81 +++- trifid-api/src/main.rs | 3 +- trifid-api/src/routes/v1/networks.rs | 2 +- trifid-api/src/routes/v2/enroll.rs | 87 ++++- 12 files changed, 830 insertions(+), 7 deletions(-) create mode 100644 nebula-config/Cargo.toml create mode 100644 nebula-config/src/lib.rs create mode 100644 trifid-api/src/config_generator.rs rename trifid-api/src/{ca.rs => crypt.rs} (65%) diff --git a/Cargo.lock b/Cargo.lock index bfc5f76..47f537a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1831,6 +1831,14 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "nebula-config" +version = "0.1.0" +dependencies = [ + "ipnet", + "serde", +] + [[package]] name = "nebula-ffi" version = "1.7.3" @@ -3088,8 +3096,10 @@ dependencies = [ "dnapi-rs", "env_logger", "hex", + "ipnet", "log", "mail-send", + "nebula-config", "rand", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index ec80311..082148f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ "dnapi-rs", "tfcli", "nebula-ffi", - + "nebula-config", "trifid-api", "trifid-api-derive" ] diff --git a/nebula-config/Cargo.toml b/nebula-config/Cargo.toml new file mode 100644 index 0000000..cf72f43 --- /dev/null +++ b/nebula-config/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "nebula-config" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde = { version = "1", features = ["derive"] } +ipnet = { version = "2.9", features = ["serde"] } \ No newline at end of file diff --git a/nebula-config/src/lib.rs b/nebula-config/src/lib.rs new file mode 100644 index 0000000..81bb19e --- /dev/null +++ b/nebula-config/src/lib.rs @@ -0,0 +1,532 @@ +use std::collections::HashMap; +use std::net::{Ipv4Addr, SocketAddrV4}; +use ipnet::{IpNet, Ipv4Net}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfig { + pub pki: NebulaConfigPki, + #[serde(default = "empty_hashmap")] + #[serde(skip_serializing_if = "is_empty_hashmap")] + pub static_host_map: HashMap>, + #[serde(skip_serializing_if = "is_none")] + pub lighthouse: Option, + #[serde(skip_serializing_if = "is_none")] + pub listen: Option, + #[serde(skip_serializing_if = "is_none")] + pub punchy: Option, + #[serde(default = "cipher_aes")] + #[serde(skip_serializing_if = "is_cipher_aes")] + pub cipher: NebulaConfigCipher, + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub preferred_ranges: Vec, + #[serde(skip_serializing_if = "is_none")] + pub relay: Option, + #[serde(skip_serializing_if = "is_none")] + pub tun: Option, + #[serde(skip_serializing_if = "is_none")] + pub logging: Option, + #[serde(skip_serializing_if = "is_none")] + pub sshd: Option, + + #[serde(skip_serializing_if = "is_none")] + pub firewall: Option, + + #[serde(default = "u64_1")] + #[serde(skip_serializing_if = "is_u64_1")] + pub routines: u64, + + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub stats: Option, + + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub local_range: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigPki { + pub ca: String, + pub cert: String, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub key: Option, + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub blocklist: Vec, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub disconnect_invalid: bool, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigLighthouse { + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub am_lighthouse: bool, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub serve_dns: bool, + #[serde(skip_serializing_if = "is_none")] + pub dns: Option, + #[serde(default = "u32_10")] + #[serde(skip_serializing_if = "is_u32_10")] + pub interval: u32, + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub hosts: Vec, + #[serde(default = "empty_hashmap")] + #[serde(skip_serializing_if = "is_empty_hashmap")] + pub remote_allow_list: HashMap, + #[serde(default = "empty_hashmap")] + #[serde(skip_serializing_if = "is_empty_hashmap")] + pub local_allow_list: HashMap, // `interfaces` is not supported +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigLighthouseDns { + #[serde(default = "string_empty")] + #[serde(skip_serializing_if = "is_string_empty")] + pub host: String, + #[serde(default = "u16_53")] + #[serde(skip_serializing_if = "is_u16_53")] + pub port: u16, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigListen { + #[serde(default = "string_empty")] + #[serde(skip_serializing_if = "is_string_empty")] + pub host: String, + #[serde(default = "u16_0")] + #[serde(skip_serializing_if = "is_u16_0")] + pub port: u16, + #[serde(default = "u32_64")] + #[serde(skip_serializing_if = "is_u32_64")] + pub batch: u32, + #[serde(skip_serializing_if = "is_none")] + pub read_buffer: Option, + #[serde(skip_serializing_if = "is_none")] + pub write_buffer: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigPunchy { + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub punch: bool, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub respond: bool, + #[serde(default = "string_1s")] + #[serde(skip_serializing_if = "is_string_1s")] + pub delay: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum NebulaConfigCipher { + #[serde(rename = "aes")] + Aes, + #[serde(rename = "chachapoly")] + ChaChaPoly, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigRelay { + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub relays: Vec, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub am_relay: bool, + #[serde(default = "bool_true")] + #[serde(skip_serializing_if = "is_bool_true")] + pub use_relays: bool, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigTun { + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub disabled: bool, + #[serde(skip_serializing_if = "is_none")] + pub dev: Option, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub drop_local_broadcast: bool, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub drop_multicast: bool, + #[serde(default = "u64_500")] + #[serde(skip_serializing_if = "is_u64_500")] + pub tx_queue: u64, + #[serde(default = "u64_1300")] + #[serde(skip_serializing_if = "is_u64_1300")] + pub mtu: u64, + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub routes: Vec, + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub unsafe_routes: Vec, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigTunRouteOverride { + pub mtu: u64, + pub route: Ipv4Net, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigTunUnsafeRoute { + pub route: Ipv4Net, + pub via: Ipv4Addr, + #[serde(default = "u64_1300")] + #[serde(skip_serializing_if = "is_u64_1300")] + pub mtu: u64, + #[serde(default = "i64_100")] + #[serde(skip_serializing_if = "is_i64_100")] + pub metric: i64, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigLogging { + #[serde(default = "loglevel_info")] + #[serde(skip_serializing_if = "is_loglevel_info")] + pub level: NebulaConfigLoggingLevel, + #[serde(default = "format_text")] + #[serde(skip_serializing_if = "is_format_text")] + pub format: NebulaConfigLoggingFormat, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub disable_timestamp: bool, + #[serde(default = "timestamp")] + #[serde(skip_serializing_if = "is_timestamp")] + pub timestamp_format: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum NebulaConfigLoggingLevel { + #[serde(rename = "panic")] + Panic, + #[serde(rename = "fatal")] + Fatal, + #[serde(rename = "error")] + Error, + #[serde(rename = "warning")] + Warning, + #[serde(rename = "info")] + Info, + #[serde(rename = "debug")] + Debug, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum NebulaConfigLoggingFormat { + #[serde(rename = "json")] + Json, + #[serde(rename = "text")] + Text, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigSshd { + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub enabled: bool, + pub listen: SocketAddrV4, + pub host_key: String, + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub authorized_users: Vec, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigSshdAuthorizedUser { + pub user: String, + #[serde(default = "empty_vec")] + #[serde(skip_serializing_if = "is_empty_vec")] + pub keys: Vec, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(tag = "type")] +pub enum NebulaConfigStats { + #[serde(rename = "graphite")] + Graphite(NebulaConfigStatsGraphite), + #[serde(rename = "prometheus")] + Prometheus(NebulaConfigStatsPrometheus), +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigStatsGraphite { + #[serde(default = "string_nebula")] + #[serde(skip_serializing_if = "is_string_nebula")] + pub prefix: String, + #[serde(default = "protocol_tcp")] + #[serde(skip_serializing_if = "is_protocol_tcp")] + pub protocol: NebulaConfigStatsGraphiteProtocol, + pub host: SocketAddrV4, + pub interval: String, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub message_metrics: bool, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub lighthouse_metrics: bool, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum NebulaConfigStatsGraphiteProtocol { + #[serde(rename = "tcp")] + Tcp, + #[serde(rename = "udp")] + Udp, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigStatsPrometheus { + pub listen: String, + pub path: String, + #[serde(default = "string_nebula")] + #[serde(skip_serializing_if = "is_string_nebula")] + pub namespace: String, + #[serde(default = "string_nebula")] + #[serde(skip_serializing_if = "is_string_nebula")] + pub subsystem: String, + pub interval: String, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub message_metrics: bool, + #[serde(default = "bool_false")] + #[serde(skip_serializing_if = "is_bool_false")] + pub lighthouse_metrics: bool, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigFirewall { + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub conntrack: Option, + + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub inbound: Option>, + + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub outbound: Option>, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigFirewallConntrack { + #[serde(default = "string_12m")] + #[serde(skip_serializing_if = "is_string_12m")] + pub tcp_timeout: String, + #[serde(default = "string_3m")] + #[serde(skip_serializing_if = "is_string_3m")] + pub udp_timeout: String, + #[serde(default = "string_10m")] + #[serde(skip_serializing_if = "is_string_10m")] + pub default_timeout: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct NebulaConfigFirewallRule { + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub port: Option, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub proto: Option, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub ca_name: Option, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub ca_sha: Option, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub host: Option, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub group: Option, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub groups: Option>, + #[serde(default = "none")] + #[serde(skip_serializing_if = "is_none")] + pub cidr: Option, +} + +// Default values for serde + +fn string_12m() -> String { + "12m".to_string() +} +fn is_string_12m(s: &str) -> bool { + s == "12m" +} + +fn string_3m() -> String { + "3m".to_string() +} +fn is_string_3m(s: &str) -> bool { + s == "3m" +} + +fn string_10m() -> String { + "10m".to_string() +} +fn is_string_10m(s: &str) -> bool { + s == "10m" +} + +fn empty_vec() -> Vec { + vec![] +} +fn is_empty_vec(v: &Vec) -> bool { + v.is_empty() +} + +fn empty_hashmap() -> HashMap { + HashMap::new() +} +fn is_empty_hashmap(h: &HashMap) -> bool { + h.is_empty() +} + +fn bool_false() -> bool { + false +} +fn is_bool_false(b: &bool) -> bool { + !*b +} + +fn bool_true() -> bool { + true +} +fn is_bool_true(b: &bool) -> bool { + *b +} + +fn u16_53() -> u16 { + 53 +} +fn is_u16_53(u: &u16) -> bool { + *u == 53 +} + +fn u32_10() -> u32 { + 10 +} +fn is_u32_10(u: &u32) -> bool { + *u == 10 +} + +fn u16_0() -> u16 { + 0 +} +fn is_u16_0(u: &u16) -> bool { + *u == 0 +} + +fn u32_64() -> u32 { + 64 +} +fn is_u32_64(u: &u32) -> bool { + *u == 64 +} + +fn string_1s() -> String { + "1s".to_string() +} +fn is_string_1s(s: &str) -> bool { + s == "1s" +} + +fn cipher_aes() -> NebulaConfigCipher { + NebulaConfigCipher::Aes +} +fn is_cipher_aes(c: &NebulaConfigCipher) -> bool { + matches!(c, NebulaConfigCipher::Aes) +} + +fn u64_500() -> u64 { + 500 +} +fn is_u64_500(u: &u64) -> bool { + *u == 500 +} + +fn u64_1300() -> u64 { + 1300 +} +fn is_u64_1300(u: &u64) -> bool { + *u == 1300 +} + +fn i64_100() -> i64 { + 100 +} +fn is_i64_100(i: &i64) -> bool { + *i == 100 +} + +fn loglevel_info() -> NebulaConfigLoggingLevel { + NebulaConfigLoggingLevel::Info +} +fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool { + matches!(l, NebulaConfigLoggingLevel::Info) +} + +fn format_text() -> NebulaConfigLoggingFormat { + NebulaConfigLoggingFormat::Text +} +fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool { + matches!(f, NebulaConfigLoggingFormat::Text) +} + +fn timestamp() -> String { + "2006-01-02T15:04:05Z07:00".to_string() +} +fn is_timestamp(s: &str) -> bool { + s == "2006-01-02T15:04:05Z07:00" +} + +fn u64_1() -> u64 { + 1 +} +fn is_u64_1(u: &u64) -> bool { + *u == 1 +} + +fn string_nebula() -> String { + "nebula".to_string() +} +fn is_string_nebula(s: &str) -> bool { + s == "nebula" +} + +fn string_empty() -> String { + String::new() +} +fn is_string_empty(s: &str) -> bool { + s.is_empty() +} + +fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol { + NebulaConfigStatsGraphiteProtocol::Tcp +} +fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool { + matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp) +} + +fn none() -> Option { + None +} +fn is_none(o: &Option) -> bool { + o.is_none() +} \ No newline at end of file diff --git a/trifid-api/Cargo.toml b/trifid-api/Cargo.toml index 5feb8fb..56ee661 100644 --- a/trifid-api/Cargo.toml +++ b/trifid-api/Cargo.toml @@ -31,4 +31,6 @@ chacha20poly1305 = "0.10" hex = "0.4" thiserror = "1" chrono = "0.4" -dnapi-rs = { version = "0.2", path = "../dnapi-rs" } \ No newline at end of file +dnapi-rs = { version = "0.2", path = "../dnapi-rs" } +nebula-config = { version = "0.1", path = "../nebula-config" } +ipnet = "2.9" \ No newline at end of file diff --git a/trifid-api/config.toml b/trifid-api/config.toml index 2602008..a2cd0c7 100644 --- a/trifid-api/config.toml +++ b/trifid-api/config.toml @@ -50,3 +50,6 @@ auth_token_expiry_seconds = 86400 # 24 hours # It is INCREDIBLY IMPORTANT that you change this value! It should be a 32-byte/256-bit hex-encoded randomly generated # key. data_encryption_key = "dd5aa62f0fd9b7fb4ff65567493f889557212f3a8e9587a79268161f9ae070a6" +# (Required) How long should client certs be valid for? +# Clients will require a config update after this period. +cert_expiry_time_seconds = 31536000 # ~1 year \ No newline at end of file diff --git a/trifid-api/src/config.rs b/trifid-api/src/config.rs index ddd0efc..3cfb14b 100644 --- a/trifid-api/src/config.rs +++ b/trifid-api/src/config.rs @@ -44,4 +44,5 @@ pub struct ConfigTokens { pub session_token_expiry_seconds: u64, pub auth_token_expiry_seconds: u64, pub data_encryption_key: String, + pub cert_expiry_time_seconds: u64 } diff --git a/trifid-api/src/config_generator.rs b/trifid-api/src/config_generator.rs new file mode 100644 index 0000000..e6212e3 --- /dev/null +++ b/trifid-api/src/config_generator.rs @@ -0,0 +1,102 @@ +// The main event! This file handles generation of actual config files. +// This file is part of the sensitive stack, along with `crypt.rs` +// and the entirety of trifid-pki, as it deals with CA private keys. +// Review carefully what you write here! + +use std::error::Error; +use std::net::Ipv4Addr; +use std::str::FromStr; +use std::time::{Duration, SystemTime}; +use actix_web::web::Data; +use diesel::{ExpressionMethods, QueryDsl, SelectableHelper}; +use diesel_async::pooled_connection::bb8::RunError; +use diesel_async::RunQueryDsl; +use ipnet::Ipv4Net; +use thiserror::Error; +use nebula_config::{NebulaConfig, NebulaConfigPki}; +use trifid_pki::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate, NebulaCertificateDetails}; +use trifid_pki::x25519_dalek::PublicKey; +use crate::AppState; +use crate::crypt::sign_cert_with_ca; +use crate::models::{Host, Network, SigningCA, HostKey}; +use crate::schema::{host_keys, hosts, networks, signing_cas}; + +#[derive(Error, Debug)] +pub enum ConfigGenError { + #[error("error acquiring connection from pool: {0}")] + AcquireError(RunError), + #[error("error in database: {0}")] + DbError(diesel::result::Error), + #[error("error parsing a signing CA: {0}")] + InvalidCACert(serde_json::Error), + #[error("an error occured: {0}")] + GenericError(Box) +} + +pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data) -> Result { + let mut conn = state.pool.get().await.map_err(ConfigGenError::AcquireError)?; + + let cas = signing_cas::dsl::signing_cas.filter(signing_cas::organization_id.eq(&host.organization_id)).select(SigningCA::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?; + + let mut good_cas = vec![]; + let mut ca_string = String::new(); + + for ca in cas { + if ca.expires_at < SystemTime::now() { + let ca_cert: NebulaCertificate = serde_json::from_value(ca.cert.clone()).map_err(ConfigGenError::InvalidCACert)?; + good_cas.push((ca, ca_cert)); + ca_string += &String::from_utf8_lossy(&ca_cert.serialize_to_pem().map_err(ConfigGenError::GenericError)?); + } + } + + let (signing_ca, ca_cert) = &good_cas[0]; + + let network = networks::dsl::networks.find(&host.network_id).first::(&mut conn).await.map_err(ConfigGenError::DbError)?; + + let mut cert = NebulaCertificate { + details: NebulaCertificateDetails { + name: host.name.clone(), + ips: vec![ + Ipv4Net::new( + Ipv4Addr::from_str(&host.ip_address).unwrap(), + Ipv4Net::from_str(&network.cidr).unwrap().prefix_len() + ).unwrap() + ], + subnets: vec![], + groups: if let Some(role_id) = &host.role_id { + vec![format!("role:{}", role_id)] + } else { vec![] }, + not_before: SystemTime::now() - Duration::from_secs(3600), + not_after: SystemTime::now() + Duration::from_secs(state.config.tokens.cert_expiry_time_seconds), + public_key: *dh_pubkey.as_bytes(), + is_ca: false, + issuer: ca_cert.sha256sum().unwrap(), + }, + signature: vec![], + }; + + sign_cert_with_ca(signing_ca, &mut cert, &state.config).unwrap(); + + let all_blocked_hosts = hosts::dsl::hosts.filter(hosts::network_id.eq(&host.network_id)).filter(hosts::is_blocked.eq(true)).select(Host::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?; + + let mut all_blocked_fingerprints = vec![]; + + for blocked_host in all_blocked_hosts { + let hosts_blocked_key_entries = host_keys::dsl::host_keys.filter(host_keys::host_id.eq(&blocked_host.id)).select(HostKey::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?; + for blocked_key in hosts_blocked_key_entries { + let cert = deserialize_nebula_certificate_from_pem(&blocked_key.client_cert).unwrap(); + let fingerprint = cert.sha256sum().unwrap(); + all_blocked_fingerprints.push(fingerprint); + } + } + + let pki = NebulaConfigPki { + ca: ca_string, + cert: String::from_utf8(cert.serialize_to_pem().unwrap()).unwrap(), + key: None, + blocklist: all_blocked_fingerprints, + disconnect_invalid: true, + }; + + todo!() +} \ No newline at end of file diff --git a/trifid-api/src/ca.rs b/trifid-api/src/crypt.rs similarity index 65% rename from trifid-api/src/ca.rs rename to trifid-api/src/crypt.rs index 45fd023..3f2310a 100644 --- a/trifid-api/src/ca.rs +++ b/trifid-api/src/crypt.rs @@ -2,7 +2,7 @@ use crate::config::Config; use crate::models::SigningCA; use actix_web::cookie::time::Duration; use chacha20poly1305::aead::{Aead, Payload}; -use chacha20poly1305::{AeadCore, KeyInit, Nonce, XChaCha20Poly1305, XNonce}; +use chacha20poly1305::{AeadCore, KeyInit, XChaCha20Poly1305, XNonce}; use log::error; use rand::rngs::OsRng; use rand::Rng; @@ -10,7 +10,7 @@ use std::error::Error; use std::time::SystemTime; use thiserror::Error; use trifid_pki::cert::{NebulaCertificate, NebulaCertificateDetails}; -use trifid_pki::ed25519_dalek::{SignatureError, SigningKey}; +use trifid_pki::ed25519_dalek::{Signature, SignatureError, Signer, SigningKey, VerifyingKey}; #[derive(Error, Debug)] pub enum CryptographyError { @@ -139,3 +139,80 @@ pub fn sign_cert_with_ca( cert.sign(&key) .map_err(|e| CryptographyError::CertificateSigningError(e)) } + +pub struct DnclientKeyLockbox { + pub info: Vec, + pub nonce: Vec, + pub key: Vec +} + +pub fn create_dnclient_ed_key( + config: &Config, +) -> Result<(DnclientKeyLockbox, VerifyingKey), CryptographyError> { + let key = SigningKey::generate(&mut OsRng); + + let lockbox_key = XChaCha20Poly1305::new_from_slice( + &hex::decode(&config.tokens.data_encryption_key) + .map_err(|e| CryptographyError::InvalidKey(e))?, + ) + .map_err(|_| CryptographyError::InvalidKeyLength)?; + + let salt = XChaCha20Poly1305::generate_nonce(&mut OsRng); + + let aad: [u8; 16] = OsRng.gen(); + + let lockbox = lockbox_key + .encrypt( + &salt, + Payload { + msg: &key.to_keypair_bytes(), + aad: &aad, + }, + ) + .map_err(|e| CryptographyError::LockingError)?; + + Ok((DnclientKeyLockbox { + info: aad.to_vec(), + nonce: salt.as_slice().to_vec(), + key: lockbox, + }, key.verifying_key())) +} + +pub fn sign_dnclient_with_lockbox( + lockbox: &DnclientKeyLockbox, + bytes: &[u8], + config: &Config, +) -> Result { + let lockbox_key = XChaCha20Poly1305::new_from_slice( + &hex::decode(&config.tokens.data_encryption_key) + .map_err(|e| CryptographyError::InvalidKey(e))?, + ) + .map_err(|_| CryptographyError::InvalidKeyLength)?; + + let salt_u24: [u8; 24] = lockbox + .nonce + .clone() + .try_into() + .map_err(|_| CryptographyError::InvalidSaltLength)?; + + let salt = XNonce::from(salt_u24); + + let plaintext = lockbox_key + .decrypt( + &salt, + Payload { + msg: &lockbox.key, + aad: &lockbox.info, + }, + ) + .map_err(|_| CryptographyError::DecryptFailed)?; + + let key = SigningKey::from_keypair_bytes( + &plaintext + .try_into() + .map_err(|_| CryptographyError::InvalidSigningKeyLength)?, + ) + .map_err(|e| CryptographyError::SignatureError(e))?; + + Ok(key.sign(bytes)) +} \ No newline at end of file diff --git a/trifid-api/src/main.rs b/trifid-api/src/main.rs index 886bf7f..d5730d8 100644 --- a/trifid-api/src/main.rs +++ b/trifid-api/src/main.rs @@ -26,7 +26,8 @@ pub mod auth; pub mod email; #[macro_use] pub mod macros; -pub mod ca; +pub mod crypt; +mod config_generator; #[derive(Clone)] pub struct AppState { diff --git a/trifid-api/src/routes/v1/networks.rs b/trifid-api/src/routes/v1/networks.rs index c8baf8e..b85ee81 100644 --- a/trifid-api/src/routes/v1/networks.rs +++ b/trifid-api/src/routes/v1/networks.rs @@ -1,4 +1,4 @@ -use crate::ca::create_signing_ca; +use crate::crypt::create_signing_ca; use crate::models::{Network, NetworkNormalized, Organization, SigningCA, User}; use crate::response::JsonAPIResponse; use crate::schema::networks::dsl::networks; diff --git a/trifid-api/src/routes/v2/enroll.rs b/trifid-api/src/routes/v2/enroll.rs index 9a78cd0..8ee4b07 100644 --- a/trifid-api/src/routes/v2/enroll.rs +++ b/trifid-api/src/routes/v2/enroll.rs @@ -1,7 +1,17 @@ +use std::time::SystemTime; use crate::response::JsonAPIResponse; -use crate::AppState; +use crate::{AppState, randid}; use actix_web::web::{Data, Json}; +use actix_web::http::StatusCode; use dnapi_rs::message::{EnrollRequest, EnrollResponse}; +use crate::models::{EnrollmentCode, Host, HostKey}; +use crate::schema::{enrollment_codes, hosts}; +use diesel::{QueryDsl, OptionalExtension, ExpressionMethods}; +use diesel_async::RunQueryDsl; +use trifid_pki::x25519_dalek::PublicKey; +use crate::config_generator::generate_config; +use crate::crypt::create_dnclient_ed_key; +use crate::schema::host_keys; pub async fn enroll_req( req: Json, @@ -9,5 +19,80 @@ pub async fn enroll_req( ) -> JsonAPIResponse { let mut conn = handle_error!(state.pool.get().await); + let token: EnrollmentCode = match handle_error!(enrollment_codes::table + .find(&req.code) + .first::(&mut conn) + .await + .optional()) + { + Some(t) => t, + None => { + err!( + StatusCode::BAD_REQUEST, + make_err!( + "ERR_INVALID_VALUE", + "does not exist (maybe it expired?)", + "magicLinkToken" + ) + ) + } + }; + + if token.expires < SystemTime::now() { + err!( + StatusCode::BAD_REQUEST, + make_err!( + "ERR_INVALID_VALUE", + "does not exist (maybe it expired?)", + "magicLinkToken" + ) + ); + } + + // valid token + + let host: Host = match handle_error!(hosts::table + .find(&token.host_id) + .first::(&mut conn) + .await + .optional()) + { + Some(t) => t, + None => { + err!( + StatusCode::BAD_REQUEST, + make_err!( + "ERR_INVALID_VALUE", + "does not exist (maybe it expired?)", + "magicLinkToken.host_id" + ) + ) + } + }; + + handle_error!(diesel::delete(&token).execute(&mut conn).await); + + // reset the host's key entries + + handle_error!(diesel::delete(host_keys::dsl::host_keys.filter(host_keys::host_id.eq(&host.id))).execute(&mut conn).await); + + let (key_lockbox, trusted_key) = handle_error!(create_dnclient_ed_key(&state.config)); + + let user_dh_key = PublicKey::from(req.dh_pubkey.try_into().unwrap()).unwrap(); + + let config = handle_error!(generate_config(&host, state.clone()).await); + + let new_counter_1 = HostKey { + id: randid!(id "hostkey"), + host_id: host.id.clone(), + counter: 1, + client_ed_pub: vec![], + client_dh_pub: vec![], + client_cert: vec![], + salt: vec![], + info: vec![], + server_ed_priv: vec![], + }; + todo!() }