423 lines
14 KiB
Rust
423 lines
14 KiB
Rust
use std::collections::HashMap;
|
|
use std::error::Error;
|
|
use std::net::{Ipv4Addr, SocketAddrV4};
|
|
use std::str::FromStr;
|
|
use std::time::{Duration, SystemTime};
|
|
|
|
use actix_web::web::Data;
|
|
|
|
use crate::config::{
|
|
NebulaConfig, NebulaConfigCipher, NebulaConfigFirewall, NebulaConfigFirewallRule,
|
|
NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy,
|
|
NebulaConfigRelay, NebulaConfigTun, CONFIG,
|
|
};
|
|
use crate::crypto::{decrypt_with_nonce, get_cipher_from_config};
|
|
use crate::keystore::keystore_init;
|
|
use crate::AppState;
|
|
use ed25519_dalek::SigningKey;
|
|
use ipnet::Ipv4Net;
|
|
use log::{debug, error};
|
|
use sea_orm::{ColumnTrait, Condition, EntityTrait, QueryFilter};
|
|
use serde_yaml::{Mapping, Value};
|
|
use trifid_api_entities::entity::{
|
|
firewall_rule, host, host_config_override, host_static_address, network, organization,
|
|
signing_ca,
|
|
};
|
|
use trifid_pki::cert::{
|
|
deserialize_ed25519_private, deserialize_nebula_certificate_from_pem, NebulaCertificate,
|
|
NebulaCertificateDetails,
|
|
};
|
|
|
|
pub struct CodegenRequiredInfo {
|
|
pub host: host::Model,
|
|
pub host_static_addresses: HashMap<String, Vec<SocketAddrV4>>,
|
|
pub network: network::Model,
|
|
pub organization: organization::Model,
|
|
pub dh_pubkey: Vec<u8>,
|
|
pub ca: signing_ca::Model,
|
|
pub other_cas: Vec<signing_ca::Model>,
|
|
pub relay_ips: Vec<Ipv4Addr>,
|
|
pub lighthouse_ips: Vec<Ipv4Addr>,
|
|
pub blocked_hosts: Vec<String>,
|
|
pub firewall_rules: Vec<NebulaConfigFirewallRule>,
|
|
pub config_overrides: Vec<(String, String)>
|
|
}
|
|
|
|
pub async fn generate_config(
|
|
_data: &Data<AppState>,
|
|
info: &CodegenRequiredInfo,
|
|
) -> Result<(NebulaConfig, NebulaCertificate), Box<dyn Error>> {
|
|
debug!(
|
|
"chk: deserialize CA cert {:x?}",
|
|
hex::decode(&info.ca.cert)?
|
|
);
|
|
// decode the CA data
|
|
let ca_cert = deserialize_nebula_certificate_from_pem(&hex::decode(&info.ca.cert)?)?;
|
|
|
|
// generate the client's new cert
|
|
let mut cert = NebulaCertificate {
|
|
details: NebulaCertificateDetails {
|
|
name: info.host.name.clone(),
|
|
ips: vec![Ipv4Net::new(
|
|
Ipv4Addr::from_str(&info.host.ip).unwrap(),
|
|
Ipv4Net::from_str(&info.network.cidr).unwrap().prefix_len(),
|
|
)
|
|
.unwrap()],
|
|
subnets: vec![],
|
|
groups: vec![format!("role:{}", info.host.role)],
|
|
not_before: SystemTime::now(),
|
|
not_after: SystemTime::now() + Duration::from_secs(CONFIG.crypto.certs_expiry_time),
|
|
public_key: info.dh_pubkey.clone().try_into().unwrap(),
|
|
is_ca: false,
|
|
issuer: ca_cert.sha256sum()?,
|
|
},
|
|
signature: vec![],
|
|
};
|
|
|
|
// decrypt the private key
|
|
let private_pem = decrypt_with_nonce(
|
|
&hex::decode(&info.ca.key)?,
|
|
hex::decode(&info.ca.nonce)?.try_into().unwrap(),
|
|
&get_cipher_from_config(&CONFIG)?,
|
|
)
|
|
.map_err(|_| "Encryption error")?;
|
|
|
|
let private_key = deserialize_ed25519_private(&private_pem)?;
|
|
let signing_key = SigningKey::from_keypair_bytes(&private_key.try_into().unwrap()).unwrap();
|
|
|
|
cert.sign(&signing_key)?;
|
|
|
|
// cas
|
|
let mut cas = String::new();
|
|
for ca in &info.other_cas {
|
|
cas += &String::from_utf8(hex::decode(&ca.cert)?)?;
|
|
}
|
|
|
|
let ks = keystore_init()?;
|
|
|
|
// blocked hosts
|
|
let mut blocked_hosts_fingerprints = vec![];
|
|
for host in &info.blocked_hosts {
|
|
if let Some(host) = ks.hosts.iter().find(|u| &u.id == host) {
|
|
for cert in &host.certs {
|
|
blocked_hosts_fingerprints.push(cert.cert.sha256sum()?);
|
|
}
|
|
}
|
|
}
|
|
|
|
let nebula_config = NebulaConfig {
|
|
pki: NebulaConfigPki {
|
|
ca: cas,
|
|
cert: String::from_utf8(cert.serialize_to_pem()?)?,
|
|
key: None,
|
|
blocklist: blocked_hosts_fingerprints,
|
|
disconnect_invalid: true,
|
|
},
|
|
static_host_map: info
|
|
.host_static_addresses
|
|
.iter()
|
|
.map(|(u, addrs)| (Ipv4Addr::from_str(u).unwrap(), addrs.clone()))
|
|
.collect(),
|
|
lighthouse: match info.host.is_lighthouse {
|
|
true => Some(NebulaConfigLighthouse {
|
|
am_lighthouse: true,
|
|
serve_dns: false,
|
|
dns: None,
|
|
interval: 60,
|
|
hosts: vec![],
|
|
remote_allow_list: HashMap::new(),
|
|
local_allow_list: HashMap::new(),
|
|
}),
|
|
false => Some(NebulaConfigLighthouse {
|
|
am_lighthouse: false,
|
|
serve_dns: false,
|
|
dns: None,
|
|
interval: 60,
|
|
hosts: info.lighthouse_ips.to_vec(),
|
|
remote_allow_list: HashMap::new(),
|
|
local_allow_list: HashMap::new(),
|
|
}),
|
|
},
|
|
listen: match info.host.is_lighthouse || info.host.is_relay {
|
|
true => Some(NebulaConfigListen {
|
|
host: "[::]".to_string(),
|
|
port: info.host.listen_port as u16,
|
|
batch: 64,
|
|
read_buffer: Some(10485760),
|
|
write_buffer: Some(10485760),
|
|
}),
|
|
false => Some(NebulaConfigListen {
|
|
host: "[::]".to_string(),
|
|
port: 0u16,
|
|
batch: 64,
|
|
read_buffer: Some(10485760),
|
|
write_buffer: Some(10485760),
|
|
}),
|
|
},
|
|
punchy: Some(NebulaConfigPunchy {
|
|
punch: true,
|
|
respond: true,
|
|
delay: "1s".to_string(),
|
|
}),
|
|
cipher: NebulaConfigCipher::Aes,
|
|
preferred_ranges: vec![],
|
|
relay: Some(NebulaConfigRelay {
|
|
relays: info.relay_ips.to_vec(),
|
|
am_relay: info.host.is_relay,
|
|
use_relays: true,
|
|
}),
|
|
tun: Some(NebulaConfigTun {
|
|
disabled: false,
|
|
dev: Some("trifid1".to_string()),
|
|
drop_local_broadcast: true,
|
|
drop_multicast: true,
|
|
tx_queue: 500,
|
|
mtu: 1300,
|
|
routes: vec![],
|
|
unsafe_routes: vec![],
|
|
}),
|
|
logging: None,
|
|
sshd: None,
|
|
firewall: Some(NebulaConfigFirewall {
|
|
conntrack: None,
|
|
inbound: Some(info.firewall_rules.clone()),
|
|
outbound: Some(vec![NebulaConfigFirewallRule {
|
|
port: Some("any".to_string()),
|
|
proto: Some("any".to_string()),
|
|
ca_name: None,
|
|
ca_sha: None,
|
|
host: Some("any".to_string()),
|
|
group: None,
|
|
groups: None,
|
|
cidr: None,
|
|
}]),
|
|
}),
|
|
routines: 1,
|
|
stats: None,
|
|
local_range: None,
|
|
};
|
|
|
|
let mut val = Mapping::new();
|
|
|
|
for (k, v) in &info.config_overrides {
|
|
let key_split = k.split('.').collect::<Vec<&str>>();
|
|
|
|
let mut value = &mut val;
|
|
|
|
for ks_k in &key_split[..key_split.len()-1] {
|
|
if !value.contains_key(ks_k) {
|
|
value.insert(Value::String(ks_k.to_string()), Value::Mapping(Mapping::new()));
|
|
}
|
|
|
|
value = value.get_mut(ks_k).ok_or("Invalid key-value pair")?.as_mapping_mut().unwrap();
|
|
}
|
|
|
|
value.insert(Value::String(key_split[key_split.len()-1].to_string()), serde_yaml::from_str(&v)?);
|
|
}
|
|
|
|
let overrides_value = Value::Mapping(val);
|
|
|
|
debug!("{:?}", overrides_value);
|
|
|
|
let mut value = serde_yaml::to_value(nebula_config)?;
|
|
|
|
debug!("{:?}", value);
|
|
|
|
merge_yaml(&mut value, overrides_value);
|
|
|
|
debug!("{:?}", value);
|
|
|
|
let nebula_config = serde_yaml::from_value(value)?;
|
|
|
|
Ok((nebula_config, cert))
|
|
}
|
|
|
|
// This cursed abomination credit https://stackoverflow.com/questions/67727239/how-to-combine-including-nested-array-values-two-serde-yamlvalue-objects
|
|
fn merge_yaml(a: &mut serde_yaml::Value, b: serde_yaml::Value) {
|
|
match (a, b) {
|
|
(a @ &mut serde_yaml::Value::Mapping(_), serde_yaml::Value::Mapping(b)) => {
|
|
let a = a.as_mapping_mut().unwrap();
|
|
for (k, v) in b {
|
|
if v.is_sequence() && a.contains_key(&k) && a[&k].is_sequence() {
|
|
let mut _b = a.get(&k).unwrap().as_sequence().unwrap().to_owned();
|
|
_b.append(&mut v.as_sequence().unwrap().to_owned());
|
|
a[&k] = serde_yaml::Value::from(_b);
|
|
continue;
|
|
}
|
|
if !a.contains_key(&k) {a.insert(k.to_owned(), v.to_owned());}
|
|
else { merge_yaml(&mut a[&k], v); }
|
|
|
|
}
|
|
|
|
}
|
|
(a, b) => *a = b,
|
|
}
|
|
}
|
|
|
|
pub async fn collect_info<'a>(
|
|
db: &'a Data<AppState>,
|
|
host: &'a str,
|
|
dh_pubkey: &'a [u8],
|
|
) -> Result<CodegenRequiredInfo, Box<dyn Error>> {
|
|
// load host info
|
|
let host = trifid_api_entities::entity::host::Entity::find()
|
|
.filter(host::Column::Id.eq(host))
|
|
.one(&db.conn)
|
|
.await?;
|
|
let host = match host {
|
|
Some(host) => host,
|
|
None => return Err("Host does not exist".into()),
|
|
};
|
|
|
|
let host_config_overrides = trifid_api_entities::entity::host_config_override::Entity::find()
|
|
.filter(host_config_override::Column::Host.eq(&host.id))
|
|
.all(&db.conn)
|
|
.await?;
|
|
|
|
let _host_static_addresses = trifid_api_entities::entity::host_static_address::Entity::find()
|
|
.filter(host_static_address::Column::Host.eq(&host.id))
|
|
.all(&db.conn)
|
|
.await?;
|
|
|
|
// load network info
|
|
let network = trifid_api_entities::entity::network::Entity::find()
|
|
.filter(network::Column::Id.eq(&host.network))
|
|
.one(&db.conn)
|
|
.await?;
|
|
|
|
let network = match network {
|
|
Some(network) => network,
|
|
None => {
|
|
return Err("Network does not exist".into());
|
|
}
|
|
};
|
|
|
|
// get all lighthouses and relays and get all of their static addresses, and get internal addresses of relays
|
|
let mut host_x_static_addresses = HashMap::new();
|
|
let mut relays = vec![];
|
|
let mut lighthouses = vec![];
|
|
let mut blocked_hosts = vec![];
|
|
|
|
let hosts = trifid_api_entities::entity::host::Entity::find()
|
|
.filter(host::Column::Network.eq(&network.id))
|
|
.filter(Condition::any().add(host::Column::IsRelay.eq(true)).add(host::Column::IsLighthouse.eq(true)))
|
|
.all(&db.conn)
|
|
.await?;
|
|
|
|
for host in hosts {
|
|
if host.is_relay {
|
|
relays.push(Ipv4Addr::from_str(&host.ip).unwrap());
|
|
} else if host.is_lighthouse {
|
|
lighthouses.push(Ipv4Addr::from_str(&host.ip).unwrap());
|
|
}
|
|
|
|
if host.is_blocked {
|
|
blocked_hosts.push(host.id.clone());
|
|
}
|
|
|
|
let static_addresses = trifid_api_entities::entity::host_static_address::Entity::find()
|
|
.filter(host_static_address::Column::Host.eq(host.id))
|
|
.all(&db.conn)
|
|
.await?;
|
|
let static_addresses: Vec<SocketAddrV4> = static_addresses
|
|
.iter()
|
|
.map(|u| SocketAddrV4::from_str(&u.address).unwrap())
|
|
.collect();
|
|
|
|
host_x_static_addresses.insert(host.ip.clone(), static_addresses);
|
|
}
|
|
|
|
// load org info
|
|
let org = trifid_api_entities::entity::organization::Entity::find()
|
|
.filter(organization::Column::Id.eq(&network.organization))
|
|
.one(&db.conn)
|
|
.await?;
|
|
let org = match org {
|
|
Some(org) => org,
|
|
None => {
|
|
return Err("Organization does not exist".into());
|
|
}
|
|
};
|
|
|
|
// get the CA that is closest to expiry, but *not* expired
|
|
let available_cas = trifid_api_entities::entity::signing_ca::Entity::find()
|
|
.filter(signing_ca::Column::Organization.eq(&org.id))
|
|
.all(&db.conn)
|
|
.await?;
|
|
|
|
let mut best_ca: Option<signing_ca::Model> = None;
|
|
let mut all_cas = vec![];
|
|
|
|
for ca in available_cas {
|
|
if let Some(existing_best) = &best_ca {
|
|
if ca.expires < existing_best.expires {
|
|
best_ca = Some(ca.clone());
|
|
}
|
|
} else {
|
|
best_ca = Some(ca.clone());
|
|
}
|
|
all_cas.push(ca);
|
|
}
|
|
|
|
if best_ca.is_none() {
|
|
error!(
|
|
"!!! NO AVAILABLE CAS !!! while trying to sign cert for {}",
|
|
org.id
|
|
);
|
|
return Err("No signing CAs available".into());
|
|
}
|
|
|
|
let best_ca = best_ca.unwrap();
|
|
|
|
// pull our host's config overrides
|
|
let config_overrides = host_config_overrides.iter().map(|u| {
|
|
(u.key.clone(), u.value.clone())
|
|
}).collect();
|
|
|
|
|
|
// pull our role's firewall rules
|
|
let firewall_rules = trifid_api_entities::entity::firewall_rule::Entity::find()
|
|
.filter(firewall_rule::Column::Role.eq(&host.role))
|
|
.all(&db.conn)
|
|
.await?;
|
|
let firewall_rules = firewall_rules
|
|
.iter()
|
|
.map(|u| NebulaConfigFirewallRule {
|
|
port: Some(if u.port_range_from == 0 && u.port_range_to == 65535 {
|
|
"any".to_string()
|
|
} else {
|
|
format!("{}-{}", u.port_range_from, u.port_range_to)
|
|
}),
|
|
proto: Some(u.protocol.clone().to_lowercase()),
|
|
ca_name: None,
|
|
ca_sha: None,
|
|
host: if u.allowed_role_id.is_some() {
|
|
None
|
|
} else {
|
|
Some("any".to_string())
|
|
},
|
|
groups: if u.allowed_role_id.is_some() {
|
|
Some(vec![format!("role:{}", u.allowed_role_id.clone().unwrap())])
|
|
} else {
|
|
None
|
|
},
|
|
group: None,
|
|
cidr: None,
|
|
})
|
|
.collect();
|
|
|
|
Ok(CodegenRequiredInfo {
|
|
host,
|
|
host_static_addresses: host_x_static_addresses,
|
|
network,
|
|
organization: org,
|
|
dh_pubkey: dh_pubkey.to_vec(),
|
|
ca: best_ca,
|
|
other_cas: all_cas,
|
|
relay_ips: relays,
|
|
lighthouse_ips: lighthouses,
|
|
blocked_hosts,
|
|
firewall_rules,
|
|
config_overrides
|
|
})
|
|
}
|