trifid/trifid-api/src/codegen/mod.rs

423 lines
14 KiB
Rust

use std::collections::HashMap;
use std::error::Error;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use std::time::{Duration, SystemTime};
use actix_web::web::Data;
use crate::config::{
NebulaConfig, NebulaConfigCipher, NebulaConfigFirewall, NebulaConfigFirewallRule,
NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy,
NebulaConfigRelay, NebulaConfigTun, CONFIG,
};
use crate::crypto::{decrypt_with_nonce, get_cipher_from_config};
use crate::keystore::keystore_init;
use crate::AppState;
use ed25519_dalek::SigningKey;
use ipnet::Ipv4Net;
use log::{debug, error};
use sea_orm::{ColumnTrait, Condition, EntityTrait, QueryFilter};
use serde_yaml::{Mapping, Value};
use trifid_api_entities::entity::{
firewall_rule, host, host_config_override, host_static_address, network, organization,
signing_ca,
};
use trifid_pki::cert::{
deserialize_ed25519_private, deserialize_nebula_certificate_from_pem, NebulaCertificate,
NebulaCertificateDetails,
};
/// Inputs for rendering one host's Nebula configuration, as gathered by [`collect_info`] and
/// consumed by [`generate_config`].
pub struct CodegenRequiredInfo {
    /// The host the configuration and certificate are generated for.
    pub host: host::Model,
    /// Static addresses of the network's relays and lighthouses, keyed by each host's overlay IP
    /// as stored in the database (unparsed).
    pub host_static_addresses: HashMap<String, Vec<SocketAddrV4>>,
    /// The network `host` belongs to; its CIDR supplies the certificate's prefix length.
    pub network: network::Model,
    /// The organization that owns `network`.
    pub organization: organization::Model,
    /// Public key embedded into the host's newly issued certificate.
    pub dh_pubkey: Vec<u8>,
    /// The CA whose (encrypted) private key signs the host certificate.
    pub ca: signing_ca::Model,
    /// CAs whose certificates are concatenated into the config's `pki.ca` bundle.
    pub other_cas: Vec<signing_ca::Model>,
    /// Overlay IPs of the network's relays.
    pub relay_ips: Vec<Ipv4Addr>,
    /// Overlay IPs of the network's lighthouses.
    pub lighthouse_ips: Vec<Ipv4Addr>,
    /// IDs of hosts whose certificate fingerprints are put on the PKI blocklist.
    pub blocked_hosts: Vec<String>,
    /// Inbound firewall rules for the host's role.
    pub firewall_rules: Vec<NebulaConfigFirewallRule>,
    /// Config overrides as `(dotted.key.path, yaml_value)` pairs.
    pub config_overrides: Vec<(String, String)>,
}
pub async fn generate_config(
_data: &Data<AppState>,
info: &CodegenRequiredInfo,
) -> Result<(NebulaConfig, NebulaCertificate), Box<dyn Error>> {
debug!(
"chk: deserialize CA cert {:x?}",
hex::decode(&info.ca.cert)?
);
// decode the CA data
let ca_cert = deserialize_nebula_certificate_from_pem(&hex::decode(&info.ca.cert)?)?;
// generate the client's new cert
let mut cert = NebulaCertificate {
details: NebulaCertificateDetails {
name: info.host.name.clone(),
ips: vec![Ipv4Net::new(
Ipv4Addr::from_str(&info.host.ip).unwrap(),
Ipv4Net::from_str(&info.network.cidr).unwrap().prefix_len(),
)
.unwrap()],
subnets: vec![],
groups: vec![format!("role:{}", info.host.role)],
not_before: SystemTime::now(),
not_after: SystemTime::now() + Duration::from_secs(CONFIG.crypto.certs_expiry_time),
public_key: info.dh_pubkey.clone().try_into().unwrap(),
is_ca: false,
issuer: ca_cert.sha256sum()?,
},
signature: vec![],
};
// decrypt the private key
let private_pem = decrypt_with_nonce(
&hex::decode(&info.ca.key)?,
hex::decode(&info.ca.nonce)?.try_into().unwrap(),
&get_cipher_from_config(&CONFIG)?,
)
.map_err(|_| "Encryption error")?;
let private_key = deserialize_ed25519_private(&private_pem)?;
let signing_key = SigningKey::from_keypair_bytes(&private_key.try_into().unwrap()).unwrap();
cert.sign(&signing_key)?;
// cas
let mut cas = String::new();
for ca in &info.other_cas {
cas += &String::from_utf8(hex::decode(&ca.cert)?)?;
}
let ks = keystore_init()?;
// blocked hosts
let mut blocked_hosts_fingerprints = vec![];
for host in &info.blocked_hosts {
if let Some(host) = ks.hosts.iter().find(|u| &u.id == host) {
for cert in &host.certs {
blocked_hosts_fingerprints.push(cert.cert.sha256sum()?);
}
}
}
let nebula_config = NebulaConfig {
pki: NebulaConfigPki {
ca: cas,
cert: String::from_utf8(cert.serialize_to_pem()?)?,
key: None,
blocklist: blocked_hosts_fingerprints,
disconnect_invalid: true,
},
static_host_map: info
.host_static_addresses
.iter()
.map(|(u, addrs)| (Ipv4Addr::from_str(u).unwrap(), addrs.clone()))
.collect(),
lighthouse: match info.host.is_lighthouse {
true => Some(NebulaConfigLighthouse {
am_lighthouse: true,
serve_dns: false,
dns: None,
interval: 60,
hosts: vec![],
remote_allow_list: HashMap::new(),
local_allow_list: HashMap::new(),
}),
false => Some(NebulaConfigLighthouse {
am_lighthouse: false,
serve_dns: false,
dns: None,
interval: 60,
hosts: info.lighthouse_ips.to_vec(),
remote_allow_list: HashMap::new(),
local_allow_list: HashMap::new(),
}),
},
listen: match info.host.is_lighthouse || info.host.is_relay {
true => Some(NebulaConfigListen {
host: "[::]".to_string(),
port: info.host.listen_port as u16,
batch: 64,
read_buffer: Some(10485760),
write_buffer: Some(10485760),
}),
false => Some(NebulaConfigListen {
host: "[::]".to_string(),
port: 0u16,
batch: 64,
read_buffer: Some(10485760),
write_buffer: Some(10485760),
}),
},
punchy: Some(NebulaConfigPunchy {
punch: true,
respond: true,
delay: "1s".to_string(),
}),
cipher: NebulaConfigCipher::Aes,
preferred_ranges: vec![],
relay: Some(NebulaConfigRelay {
relays: info.relay_ips.to_vec(),
am_relay: info.host.is_relay,
use_relays: true,
}),
tun: Some(NebulaConfigTun {
disabled: false,
dev: Some("trifid1".to_string()),
drop_local_broadcast: true,
drop_multicast: true,
tx_queue: 500,
mtu: 1300,
routes: vec![],
unsafe_routes: vec![],
}),
logging: None,
sshd: None,
firewall: Some(NebulaConfigFirewall {
conntrack: None,
inbound: Some(info.firewall_rules.clone()),
outbound: Some(vec![NebulaConfigFirewallRule {
port: Some("any".to_string()),
proto: Some("any".to_string()),
ca_name: None,
ca_sha: None,
host: Some("any".to_string()),
group: None,
groups: None,
cidr: None,
}]),
}),
routines: 1,
stats: None,
local_range: None,
};
let mut val = Mapping::new();
for (k, v) in &info.config_overrides {
let key_split = k.split('.').collect::<Vec<&str>>();
let mut value = &mut val;
for ks_k in &key_split[..key_split.len()-1] {
if !value.contains_key(ks_k) {
value.insert(Value::String(ks_k.to_string()), Value::Mapping(Mapping::new()));
}
value = value.get_mut(ks_k).ok_or("Invalid key-value pair")?.as_mapping_mut().unwrap();
}
value.insert(Value::String(key_split[key_split.len()-1].to_string()), serde_yaml::from_str(&v)?);
}
let overrides_value = Value::Mapping(val);
debug!("{:?}", overrides_value);
let mut value = serde_yaml::to_value(nebula_config)?;
debug!("{:?}", value);
merge_yaml(&mut value, overrides_value);
debug!("{:?}", value);
let nebula_config = serde_yaml::from_value(value)?;
Ok((nebula_config, cert))
}
// This cursed abomination credit https://stackoverflow.com/questions/67727239/how-to-combine-including-nested-array-values-two-serde-yamlvalue-objects
fn merge_yaml(a: &mut serde_yaml::Value, b: serde_yaml::Value) {
match (a, b) {
(a @ &mut serde_yaml::Value::Mapping(_), serde_yaml::Value::Mapping(b)) => {
let a = a.as_mapping_mut().unwrap();
for (k, v) in b {
if v.is_sequence() && a.contains_key(&k) && a[&k].is_sequence() {
let mut _b = a.get(&k).unwrap().as_sequence().unwrap().to_owned();
_b.append(&mut v.as_sequence().unwrap().to_owned());
a[&k] = serde_yaml::Value::from(_b);
continue;
}
if !a.contains_key(&k) {a.insert(k.to_owned(), v.to_owned());}
else { merge_yaml(&mut a[&k], v); }
}
}
(a, b) => *a = b,
}
}
/// Loads everything `generate_config` needs for the host with ID `host`. The result is wrapped,
/// together with the host's `dh_pubkey`, in a [`CodegenRequiredInfo`].
///
/// Besides the host itself, this gathers:
/// - its network and organization;
/// - the overlay IPs and static addresses of the network's relays and lighthouses;
/// - the IDs of every blocked host in the network;
/// - the organization's signing CAs, picking the unexpired one closest to expiry for signing;
/// - the host's config overrides;
/// - the inbound firewall rules of its role.
///
/// # Errors
///
/// Returns an error if:
/// - the host, its network or its organization does not exist;
/// - the organization has no unexpired signing CA;
/// - a stored IP or static address cannot be parsed;
/// - a database query fails.
pub async fn collect_info<'a>(
    db: &'a Data<AppState>,
    host: &'a str,
    dh_pubkey: &'a [u8],
) -> Result<CodegenRequiredInfo, Box<dyn Error>> {
    // load host info
    let host = trifid_api_entities::entity::host::Entity::find()
        .filter(host::Column::Id.eq(host))
        .one(&db.conn)
        .await?;
    let host = match host {
        Some(host) => host,
        None => return Err("Host does not exist".into()),
    };
    let host_config_overrides = trifid_api_entities::entity::host_config_override::Entity::find()
        .filter(host_config_override::Column::Host.eq(&host.id))
        .all(&db.conn)
        .await?;

    // load network info
    let network = trifid_api_entities::entity::network::Entity::find()
        .filter(network::Column::Id.eq(&host.network))
        .one(&db.conn)
        .await?;
    let network = match network {
        Some(network) => network,
        None => {
            return Err("Network does not exist".into());
        }
    };

    let parse_ip = |h: &host::Model| {
        Ipv4Addr::from_str(&h.ip)
            .map_err(|e| format!("Invalid IP {:?} on host {}: {}", h.ip, h.id, e))
    };

    // get all lighthouses and relays of the network: their overlay IPs and static addresses
    let mut host_x_static_addresses = HashMap::new();
    let mut relays = vec![];
    let mut lighthouses = vec![];
    let infra_hosts = trifid_api_entities::entity::host::Entity::find()
        .filter(host::Column::Network.eq(&network.id))
        .filter(
            Condition::any()
                .add(host::Column::IsRelay.eq(true))
                .add(host::Column::IsLighthouse.eq(true)),
        )
        .all(&db.conn)
        .await?;
    for infra_host in infra_hosts {
        // a host flagged as both relay and lighthouse is only listed as a relay
        if infra_host.is_relay {
            relays.push(parse_ip(&infra_host)?);
        } else if infra_host.is_lighthouse {
            lighthouses.push(parse_ip(&infra_host)?);
        }
        let static_addresses = trifid_api_entities::entity::host_static_address::Entity::find()
            .filter(host_static_address::Column::Host.eq(&infra_host.id))
            .all(&db.conn)
            .await?
            .iter()
            .map(|addr| {
                SocketAddrV4::from_str(&addr.address).map_err(|e| {
                    format!(
                        "Invalid static address {:?} on host {}: {}",
                        addr.address, infra_host.id, e
                    )
                })
            })
            .collect::<Result<Vec<SocketAddrV4>, String>>()?;
        host_x_static_addresses.insert(infra_host.ip, static_addresses);
    }

    // every blocked host in the network, not only relays/lighthouses, goes on the blocklist
    let blocked_hosts: Vec<String> = trifid_api_entities::entity::host::Entity::find()
        .filter(host::Column::Network.eq(&network.id))
        .filter(host::Column::IsBlocked.eq(true))
        .all(&db.conn)
        .await?
        .into_iter()
        .map(|blocked| blocked.id)
        .collect();

    // load org info
    let org = trifid_api_entities::entity::organization::Entity::find()
        .filter(organization::Column::Id.eq(&network.organization))
        .one(&db.conn)
        .await?;
    let org = match org {
        Some(org) => org,
        None => {
            return Err("Organization does not exist".into());
        }
    };

    // get the CA that is closest to expiry, but *not* expired
    // NOTE(review): assumes `signing_ca.expires` is a Unix timestamp in seconds — confirm
    let available_cas = trifid_api_entities::entity::signing_ca::Entity::find()
        .filter(signing_ca::Column::Organization.eq(&org.id))
        .all(&db.conn)
        .await?;
    let now = i64::try_from(
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)?
            .as_secs(),
    )?;
    let best_ca = match available_cas
        .iter()
        .filter(|ca| ca.expires > now)
        .min_by_key(|ca| ca.expires)
    {
        Some(ca) => ca.clone(),
        None => {
            error!(
                "!!! NO AVAILABLE CAS !!! while trying to sign cert for {}",
                org.id
            );
            return Err("No signing CAs available".into());
        }
    };
    // all of the organization's CAs, expired or not, are passed on as `other_cas`
    let all_cas = available_cas;

    // pull our host's config overrides
    let config_overrides = host_config_overrides
        .into_iter()
        .map(|o| (o.key, o.value))
        .collect();

    // pull our role's firewall rules
    let firewall_rules = trifid_api_entities::entity::firewall_rule::Entity::find()
        .filter(firewall_rule::Column::Role.eq(&host.role))
        .all(&db.conn)
        .await?
        .iter()
        .map(nebula_firewall_rule_from_model)
        .collect();

    Ok(CodegenRequiredInfo {
        host,
        host_static_addresses: host_x_static_addresses,
        network,
        organization: org,
        dh_pubkey: dh_pubkey.to_vec(),
        ca: best_ca,
        other_cas: all_cas,
        relay_ips: relays,
        lighthouse_ips: lighthouses,
        blocked_hosts,
        firewall_rules,
        config_overrides,
    })
}

/// Converts a stored firewall rule into a Nebula inbound firewall rule.
///
/// The port range `0-65535` becomes `"any"`; any other range is rendered as `"from-to"`. The
/// protocol is lowercased. A rule with an `allowed_role_id` is restricted to the
/// `role:<allowed_role_id>` group; otherwise it allows any host.
fn nebula_firewall_rule_from_model(rule: &firewall_rule::Model) -> NebulaConfigFirewallRule {
    let port = if rule.port_range_from == 0 && rule.port_range_to == 65535 {
        "any".to_string()
    } else {
        format!("{}-{}", rule.port_range_from, rule.port_range_to)
    };
    NebulaConfigFirewallRule {
        port: Some(port),
        proto: Some(rule.protocol.to_lowercase()),
        ca_name: None,
        ca_sha: None,
        host: match rule.allowed_role_id {
            Some(_) => None,
            None => Some("any".to_string()),
        },
        groups: rule
            .allowed_role_id
            .as_ref()
            .map(|role| vec![format!("role:{}", role)]),
        group: None,
        cidr: None,
    }
}