diff --git a/.idea/trifid.iml b/.idea/trifid.iml index 7ee0faa..85e74e1 100644 --- a/.idea/trifid.iml +++ b/.idea/trifid.iml @@ -13,6 +13,7 @@ + diff --git a/Cargo.lock b/Cargo.lock index 47f537a..84fd150 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3103,6 +3103,7 @@ dependencies = [ "rand", "serde", "serde_json", + "serde_yaml", "thiserror", "toml 0.8.5", "totp-rs", diff --git a/trifid-api/Cargo.toml b/trifid-api/Cargo.toml index 56ee661..d06a5af 100644 --- a/trifid-api/Cargo.toml +++ b/trifid-api/Cargo.toml @@ -33,4 +33,5 @@ thiserror = "1" chrono = "0.4" dnapi-rs = { version = "0.2", path = "../dnapi-rs" } nebula-config = { version = "0.1", path = "../nebula-config" } -ipnet = "2.9" \ No newline at end of file +ipnet = "2.9" +serde_yaml = "0.9" \ No newline at end of file diff --git a/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql b/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql index 078cda2..e88608c 100644 --- a/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql +++ b/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql @@ -2,5 +2,5 @@ CREATE TABLE host_overrides ( id VARCHAR NOT NULL PRIMARY KEY, host_id VARCHAR NOT NULL REFERENCES hosts(id), key VARCHAR NOT NULL, - value jsonb NOT NULL + value VARCHAR NOT NULL ); \ No newline at end of file diff --git a/trifid-api/src/config_generator.rs b/trifid-api/src/config_generator.rs index e6212e3..591cf77 100644 --- a/trifid-api/src/config_generator.rs +++ b/trifid-api/src/config_generator.rs @@ -3,8 +3,9 @@ // and the entirety of trifid-pki, as it deals with CA private keys. // Review carefully what you write here! +use std::collections::{Bound, HashMap}; use std::error::Error; -use std::net::Ipv4Addr; +use std::net::{Ipv4Addr, SocketAddrV4}; use std::str::FromStr; use std::time::{Duration, SystemTime}; use actix_web::web::Data; @@ -12,14 +13,15 @@ use diesel::{ExpressionMethods, QueryDsl, SelectableHelper}; use diesel_async::pooled_connection::bb8::RunError; use diesel_async::RunQueryDsl; use ipnet::Ipv4Net; +use serde_yaml::{Mapping, Value}; use thiserror::Error; -use nebula_config::{NebulaConfig, NebulaConfigPki}; +use nebula_config::{NebulaConfig, NebulaConfigCipher, NebulaConfigFirewall, NebulaConfigFirewallRule, NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy, NebulaConfigRelay, NebulaConfigTun}; use trifid_pki::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate, NebulaCertificateDetails}; use trifid_pki::x25519_dalek::PublicKey; use crate::AppState; use crate::crypt::sign_cert_with_ca; -use crate::models::{Host, Network, SigningCA, HostKey}; -use crate::schema::{host_keys, hosts, networks, signing_cas}; +use crate::models::{Host, Network, SigningCA, HostKey, Role, RoleFirewallRule, HostOverride}; +use crate::schema::{host_keys, hosts, networks, role_firewall_rules, roles, signing_cas, host_overrides}; #[derive(Error, Debug)] pub enum ConfigGenError { @@ -44,8 +46,8 @@ pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data>()); + } + + for relay in &all_relays { + static_host_map.insert(Ipv4Addr::from_str(&relay.ip_address).unwrap(), relay.static_addresses.iter().map(|u| SocketAddrV4::from_str(&u.clone().unwrap()).unwrap()).collect::>()); + } + + let lighthouse = Some(NebulaConfigLighthouse { + am_lighthouse: host.is_lighthouse, + serve_dns: false, + dns: None, + interval: 10, + hosts: if host.is_lighthouse { + vec![] + } else { + all_lighthouses.iter().map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()).collect::>() + }, + remote_allow_list: Default::default(), + local_allow_list: Default::default(), + }); + + let relay = Some(NebulaConfigRelay { + am_relay: host.is_relay || (host.is_lighthouse && network.lighthouses_as_relays), + relays: if host.is_relay || (host.is_lighthouse && network.lighthouses_as_relays) { + vec![] + } else if network.lighthouses_as_relays { + all_lighthouses.iter().map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()).collect::>() + } else { + all_relays.iter().map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()).collect::>() + }, + use_relays: true + }); + + let mut inbound_firewall = None; + + if let Some(role_id) = &host.role_id { + let role = roles::dsl::roles.find(role_id).first::(&mut conn).await.map_err(ConfigGenError::DbError)?; + + let firewall_rules = role_firewall_rules::dsl::role_firewall_rules.filter(role_firewall_rules::role_id.eq(role_id)).select(RoleFirewallRule::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?; + + inbound_firewall = Some(firewall_rules.iter().map(|u| NebulaConfigFirewallRule { + port: if let Some((from, to)) = u.port_range { + let start_port = match from { + Bound::Included(u) => u, + Bound::Excluded(u) => u+1, + Bound::Unbounded => 0 + }; + let end_port = match from { + Bound::Included(u) => u, + Bound::Excluded(u) => u-1, + Bound::Unbounded => 65535 + }; + Some(format!("{}-{}", start_port, end_port)) + } else { + Some("any".to_string()) + }, + proto: Some(u.protocol.clone()), + ca_name: None, + ca_sha: None, + host: if u.allowed_role_id.is_some() { + None + } else { + Some("any".to_string()) + }, + group: None, + groups: if u.allowed_role_id.is_some() { + Some(vec![format!("role:{}", u.allowed_role_id.clone().unwrap())]) + } else { + None + }, + cidr: None, + }).collect::>()) + } + + let config = NebulaConfig { + pki, + static_host_map, + lighthouse, + listen: Some(NebulaConfigListen { + host: "0.0.0.0".to_string(), + port: host.listen_port as u16, + batch: 64, + read_buffer: None, + write_buffer: None, + }), + punchy: Some(NebulaConfigPunchy { + punch: true, + respond: true, + delay: "1s".to_string(), + }), + cipher: NebulaConfigCipher::Aes, + preferred_ranges: vec![], + relay, + tun: Some(NebulaConfigTun { + disabled: false, + dev: Some("trifid1".to_string()), + drop_local_broadcast: false, + drop_multicast: false, + tx_queue: 500, + mtu: 1300, + routes: vec![], + unsafe_routes: vec![], + }), + logging: None, + sshd: None, + firewall: Some(NebulaConfigFirewall { + conntrack: None, + inbound: inbound_firewall, + outbound: Some(vec![ + NebulaConfigFirewallRule { + port: Some("any".to_string()), + proto: Some("any".to_string()), + ca_name: None, + ca_sha: None, + host: Some("any".to_string()), + group: None, + groups: None, + cidr: None, + } + ]), + }), + routines: 1, + stats: None, + local_range: None, + }; + + let mut yaml_value = serde_yaml::to_value(&config).unwrap(); + + let all_overrides = host_overrides::dsl::host_overrides.filter(host_overrides::host_id.eq(&host.id)).select(HostOverride::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?; + + // Cursed value overrides + + for h_override in all_overrides { + // split up the key + // a.b.c.d = ['a']['b']['c']['d'] = value + let key_split = h_override.key.split('.').collect::>(); + + let mut current_val = &mut yaml_value; + + for key_iter in &key_split[..key_split.len() - 1] { + current_val = current_val + .as_mapping_mut() + .unwrap() + .entry(Value::String(key_iter.to_string())) + .or_insert(Value::Mapping(Mapping::new())); + } + + current_val.as_mapping_mut().unwrap().insert( + Value::String(key_split[key_split.len() - 1].to_string()), + serde_yaml::from_str(h_override.value)?, + ); + } + + let merged_config: NebulaConfig = serde_yaml::from_value(yaml_value).unwrap(); + + Ok(merged_config) } \ No newline at end of file diff --git a/trifid-api/src/routes/v2/enroll.rs b/trifid-api/src/routes/v2/enroll.rs index 8ee4b07..33d13ab 100644 --- a/trifid-api/src/routes/v2/enroll.rs +++ b/trifid-api/src/routes/v2/enroll.rs @@ -3,7 +3,7 @@ use crate::response::JsonAPIResponse; use crate::{AppState, randid}; use actix_web::web::{Data, Json}; use actix_web::http::StatusCode; -use dnapi_rs::message::{EnrollRequest, EnrollResponse}; +use dnapi_rs::message::{EnrollRequest, EnrollResponse, EnrollResponseData}; use crate::models::{EnrollmentCode, Host, HostKey}; use crate::schema::{enrollment_codes, hosts}; use diesel::{QueryDsl, OptionalExtension, ExpressionMethods}; @@ -78,21 +78,37 @@ pub async fn enroll_req( let (key_lockbox, trusted_key) = handle_error!(create_dnclient_ed_key(&state.config)); - let user_dh_key = PublicKey::from(req.dh_pubkey.try_into().unwrap()).unwrap(); + let fixed_dh_key: [u8; 32] = req.dh_pubkey.clone().try_into().unwrap(); - let config = handle_error!(generate_config(&host, state.clone()).await); + let user_dh_key = PublicKey::from(fixed_dh_key); + + let config = handle_error!(generate_config(&host, user_dh_key, state.clone()).await); let new_counter_1 = HostKey { id: randid!(id "hostkey"), host_id: host.id.clone(), counter: 1, - client_ed_pub: vec![], - client_dh_pub: vec![], - client_cert: vec![], - salt: vec![], - info: vec![], - server_ed_priv: vec![], + client_ed_pub: req.ed_pubkey.clone(), + client_dh_pub: req.dh_pubkey.clone(), + client_cert: config.pki.cert.as_bytes().to_vec(), + salt: key_lockbox.nonce, + info: key_lockbox.info, + server_ed_priv: key_lockbox.key, }; - todo!() + handle_error!( + diesel::insert_into(host_keys::table) + .values(&new_counter_1) + .execute(&mut conn) + .await + ); + + let config_bytes = serde_yaml::to_string(&config).unwrap().as_bytes(); + + ok!(EnrollResponse::Success { + data: EnrollResponseData { + host_id: host.id.clone(), + config: config_bytes.to_vec() + } + }) } diff --git a/trifid-api/src/schema.rs b/trifid-api/src/schema.rs index 3722f1f..4908518 100644 --- a/trifid-api/src/schema.rs +++ b/trifid-api/src/schema.rs @@ -35,7 +35,7 @@ diesel::table! { id -> Varchar, host_id -> Varchar, key -> Varchar, - value -> Jsonb, + value -> Varchar, } }