diff --git a/.idea/trifid.iml b/.idea/trifid.iml
index 7ee0faa..85e74e1 100644
--- a/.idea/trifid.iml
+++ b/.idea/trifid.iml
@@ -13,6 +13,7 @@
+
diff --git a/Cargo.lock b/Cargo.lock
index 47f537a..84fd150 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3103,6 +3103,7 @@ dependencies = [
"rand",
"serde",
"serde_json",
+ "serde_yaml",
"thiserror",
"toml 0.8.5",
"totp-rs",
diff --git a/trifid-api/Cargo.toml b/trifid-api/Cargo.toml
index 56ee661..d06a5af 100644
--- a/trifid-api/Cargo.toml
+++ b/trifid-api/Cargo.toml
@@ -33,4 +33,5 @@ thiserror = "1"
chrono = "0.4"
dnapi-rs = { version = "0.2", path = "../dnapi-rs" }
nebula-config = { version = "0.1", path = "../nebula-config" }
-ipnet = "2.9"
\ No newline at end of file
+ipnet = "2.9"
+serde_yaml = "0.9"
\ No newline at end of file
diff --git a/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql b/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql
index 078cda2..e88608c 100644
--- a/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql
+++ b/trifid-api/migrations/2023-12-24-222225_create_table_host_overrides/up.sql
@@ -2,5 +2,5 @@ CREATE TABLE host_overrides (
id VARCHAR NOT NULL PRIMARY KEY,
host_id VARCHAR NOT NULL REFERENCES hosts(id),
key VARCHAR NOT NULL,
- value jsonb NOT NULL
+ value VARCHAR NOT NULL
);
\ No newline at end of file
diff --git a/trifid-api/src/config_generator.rs b/trifid-api/src/config_generator.rs
index e6212e3..591cf77 100644
--- a/trifid-api/src/config_generator.rs
+++ b/trifid-api/src/config_generator.rs
@@ -3,8 +3,9 @@
// and the entirety of trifid-pki, as it deals with CA private keys.
// Review carefully what you write here!
+use std::collections::{Bound, HashMap};
use std::error::Error;
-use std::net::Ipv4Addr;
+use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use std::time::{Duration, SystemTime};
use actix_web::web::Data;
@@ -12,14 +13,15 @@ use diesel::{ExpressionMethods, QueryDsl, SelectableHelper};
use diesel_async::pooled_connection::bb8::RunError;
use diesel_async::RunQueryDsl;
use ipnet::Ipv4Net;
+use serde_yaml::{Mapping, Value};
use thiserror::Error;
-use nebula_config::{NebulaConfig, NebulaConfigPki};
+use nebula_config::{NebulaConfig, NebulaConfigCipher, NebulaConfigFirewall, NebulaConfigFirewallRule, NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy, NebulaConfigRelay, NebulaConfigTun};
use trifid_pki::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate, NebulaCertificateDetails};
use trifid_pki::x25519_dalek::PublicKey;
use crate::AppState;
use crate::crypt::sign_cert_with_ca;
-use crate::models::{Host, Network, SigningCA, HostKey};
-use crate::schema::{host_keys, hosts, networks, signing_cas};
+use crate::models::{Host, Network, SigningCA, HostKey, Role, RoleFirewallRule, HostOverride};
+use crate::schema::{host_keys, hosts, networks, role_firewall_rules, roles, signing_cas, host_overrides};
#[derive(Error, Debug)]
pub enum ConfigGenError {
@@ -44,8 +46,8 @@ pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data>());
+ }
+
+ for relay in &all_relays {
+ static_host_map.insert(Ipv4Addr::from_str(&relay.ip_address).unwrap(), relay.static_addresses.iter().map(|u| SocketAddrV4::from_str(&u.clone().unwrap()).unwrap()).collect::>());
+ }
+
+ let lighthouse = Some(NebulaConfigLighthouse {
+ am_lighthouse: host.is_lighthouse,
+ serve_dns: false,
+ dns: None,
+ interval: 10,
+ hosts: if host.is_lighthouse {
+ vec![]
+ } else {
+ all_lighthouses.iter().map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()).collect::>()
+ },
+ remote_allow_list: Default::default(),
+ local_allow_list: Default::default(),
+ });
+
+ let relay = Some(NebulaConfigRelay {
+ am_relay: host.is_relay || (host.is_lighthouse && network.lighthouses_as_relays),
+ relays: if host.is_relay || (host.is_lighthouse && network.lighthouses_as_relays) {
+ vec![]
+ } else if network.lighthouses_as_relays {
+ all_lighthouses.iter().map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()).collect::>()
+ } else {
+ all_relays.iter().map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()).collect::>()
+ },
+ use_relays: true
+ });
+
+ let mut inbound_firewall = None;
+
+ if let Some(role_id) = &host.role_id {
+ let role = roles::dsl::roles.find(role_id).first::(&mut conn).await.map_err(ConfigGenError::DbError)?;
+
+ let firewall_rules = role_firewall_rules::dsl::role_firewall_rules.filter(role_firewall_rules::role_id.eq(role_id)).select(RoleFirewallRule::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?;
+
+ inbound_firewall = Some(firewall_rules.iter().map(|u| NebulaConfigFirewallRule {
+ port: if let Some((from, to)) = u.port_range {
+ let start_port = match from {
+ Bound::Included(u) => u,
+ Bound::Excluded(u) => u+1,
+ Bound::Unbounded => 0
+ };
+ let end_port = match from {
+ Bound::Included(u) => u,
+ Bound::Excluded(u) => u-1,
+ Bound::Unbounded => 65535
+ };
+ Some(format!("{}-{}", start_port, end_port))
+ } else {
+ Some("any".to_string())
+ },
+ proto: Some(u.protocol.clone()),
+ ca_name: None,
+ ca_sha: None,
+ host: if u.allowed_role_id.is_some() {
+ None
+ } else {
+ Some("any".to_string())
+ },
+ group: None,
+ groups: if u.allowed_role_id.is_some() {
+ Some(vec![format!("role:{}", u.allowed_role_id.clone().unwrap())])
+ } else {
+ None
+ },
+ cidr: None,
+ }).collect::>())
+ }
+
+ let config = NebulaConfig {
+ pki,
+ static_host_map,
+ lighthouse,
+ listen: Some(NebulaConfigListen {
+ host: "0.0.0.0".to_string(),
+ port: host.listen_port as u16,
+ batch: 64,
+ read_buffer: None,
+ write_buffer: None,
+ }),
+ punchy: Some(NebulaConfigPunchy {
+ punch: true,
+ respond: true,
+ delay: "1s".to_string(),
+ }),
+ cipher: NebulaConfigCipher::Aes,
+ preferred_ranges: vec![],
+ relay,
+ tun: Some(NebulaConfigTun {
+ disabled: false,
+ dev: Some("trifid1".to_string()),
+ drop_local_broadcast: false,
+ drop_multicast: false,
+ tx_queue: 500,
+ mtu: 1300,
+ routes: vec![],
+ unsafe_routes: vec![],
+ }),
+ logging: None,
+ sshd: None,
+ firewall: Some(NebulaConfigFirewall {
+ conntrack: None,
+ inbound: inbound_firewall,
+ outbound: Some(vec![
+ NebulaConfigFirewallRule {
+ port: Some("any".to_string()),
+ proto: Some("any".to_string()),
+ ca_name: None,
+ ca_sha: None,
+ host: Some("any".to_string()),
+ group: None,
+ groups: None,
+ cidr: None,
+ }
+ ]),
+ }),
+ routines: 1,
+ stats: None,
+ local_range: None,
+ };
+
+ let mut yaml_value = serde_yaml::to_value(&config).unwrap();
+
+ let all_overrides = host_overrides::dsl::host_overrides.filter(host_overrides::host_id.eq(&host.id)).select(HostOverride::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?;
+
+ // Cursed value overrides
+
+ for h_override in all_overrides {
+ // split up the key
+ // a.b.c.d = ['a']['b']['c']['d'] = value
+ let key_split = h_override.key.split('.').collect::>();
+
+ let mut current_val = &mut yaml_value;
+
+ for key_iter in &key_split[..key_split.len() - 1] {
+ current_val = current_val
+ .as_mapping_mut()
+ .unwrap()
+ .entry(Value::String(key_iter.to_string()))
+ .or_insert(Value::Mapping(Mapping::new()));
+ }
+
+ current_val.as_mapping_mut().unwrap().insert(
+ Value::String(key_split[key_split.len() - 1].to_string()),
+ serde_yaml::from_str(h_override.value)?,
+ );
+ }
+
+ let merged_config: NebulaConfig = serde_yaml::from_value(yaml_value).unwrap();
+
+ Ok(merged_config)
}
\ No newline at end of file
diff --git a/trifid-api/src/routes/v2/enroll.rs b/trifid-api/src/routes/v2/enroll.rs
index 8ee4b07..33d13ab 100644
--- a/trifid-api/src/routes/v2/enroll.rs
+++ b/trifid-api/src/routes/v2/enroll.rs
@@ -3,7 +3,7 @@ use crate::response::JsonAPIResponse;
use crate::{AppState, randid};
use actix_web::web::{Data, Json};
use actix_web::http::StatusCode;
-use dnapi_rs::message::{EnrollRequest, EnrollResponse};
+use dnapi_rs::message::{EnrollRequest, EnrollResponse, EnrollResponseData};
use crate::models::{EnrollmentCode, Host, HostKey};
use crate::schema::{enrollment_codes, hosts};
use diesel::{QueryDsl, OptionalExtension, ExpressionMethods};
@@ -78,21 +78,37 @@ pub async fn enroll_req(
let (key_lockbox, trusted_key) = handle_error!(create_dnclient_ed_key(&state.config));
- let user_dh_key = PublicKey::from(req.dh_pubkey.try_into().unwrap()).unwrap();
+ let fixed_dh_key: [u8; 32] = req.dh_pubkey.clone().try_into().unwrap();
- let config = handle_error!(generate_config(&host, state.clone()).await);
+ let user_dh_key = PublicKey::from(fixed_dh_key);
+
+ let config = handle_error!(generate_config(&host, user_dh_key, state.clone()).await);
let new_counter_1 = HostKey {
id: randid!(id "hostkey"),
host_id: host.id.clone(),
counter: 1,
- client_ed_pub: vec![],
- client_dh_pub: vec![],
- client_cert: vec![],
- salt: vec![],
- info: vec![],
- server_ed_priv: vec![],
+ client_ed_pub: req.ed_pubkey.clone(),
+ client_dh_pub: req.dh_pubkey.clone(),
+ client_cert: config.pki.cert.as_bytes().to_vec(),
+ salt: key_lockbox.nonce,
+ info: key_lockbox.info,
+ server_ed_priv: key_lockbox.key,
};
- todo!()
+ handle_error!(
+ diesel::insert_into(host_keys::table)
+ .values(&new_counter_1)
+ .execute(&mut conn)
+ .await
+ );
+
+ let config_bytes = serde_yaml::to_string(&config).unwrap().as_bytes();
+
+ ok!(EnrollResponse::Success {
+ data: EnrollResponseData {
+ host_id: host.id.clone(),
+ config: config_bytes.to_vec()
+ }
+ })
}
diff --git a/trifid-api/src/schema.rs b/trifid-api/src/schema.rs
index 3722f1f..4908518 100644
--- a/trifid-api/src/schema.rs
+++ b/trifid-api/src/schema.rs
@@ -35,7 +35,7 @@ diesel::table! {
id -> Varchar,
host_id -> Varchar,
key -> Varchar,
- value -> Jsonb,
+ value -> Varchar,
}
}