From a5fb79288b55ff0b762a9999bc07b4d30ea873b7 Mon Sep 17 00:00:00 2001 From: core Date: Wed, 27 Dec 2023 00:03:40 -0500 Subject: [PATCH] dnclient endpoint speedrun: 14 minutes --- Cargo.lock | 1 + trifid-api/Cargo.toml | 3 +- .../up.sql | 4 +- trifid-api/src/config.rs | 2 +- trifid-api/src/config_generator.rs | 261 ++++++++++++------ trifid-api/src/crypt.rs | 25 +- trifid-api/src/main.rs | 2 +- trifid-api/src/models.rs | 5 +- trifid-api/src/routes/v1/dnclient.rs | 201 ++++++++++++++ trifid-api/src/routes/v1/mod.rs | 1 + trifid-api/src/routes/v1/networks.rs | 4 +- trifid-api/src/routes/v2/enroll.rs | 72 +++-- trifid-api/src/schema.rs | 1 + 13 files changed, 456 insertions(+), 126 deletions(-) create mode 100644 trifid-api/src/routes/v1/dnclient.rs diff --git a/Cargo.lock b/Cargo.lock index 84fd150..b939bca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3087,6 +3087,7 @@ version = "0.3.0-alpha1" dependencies = [ "actix-cors", "actix-web", + "base64 0.21.5", "bb8", "chacha20poly1305", "chrono", diff --git a/trifid-api/Cargo.toml b/trifid-api/Cargo.toml index d06a5af..735f07b 100644 --- a/trifid-api/Cargo.toml +++ b/trifid-api/Cargo.toml @@ -34,4 +34,5 @@ chrono = "0.4" dnapi-rs = { version = "0.2", path = "../dnapi-rs" } nebula-config = { version = "0.1", path = "../nebula-config" } ipnet = "2.9" -serde_yaml = "0.9" \ No newline at end of file +serde_yaml = "0.9" +base64 = "0.21" \ No newline at end of file diff --git a/trifid-api/migrations/2023-12-24-220434_create_table_hosts/up.sql b/trifid-api/migrations/2023-12-24-220434_create_table_hosts/up.sql index 6d032e3..d27114f 100644 --- a/trifid-api/migrations/2023-12-24-220434_create_table_hosts/up.sql +++ b/trifid-api/migrations/2023-12-24-220434_create_table_hosts/up.sql @@ -17,5 +17,7 @@ CREATE TABLE hosts ( platform VARCHAR NULL, update_available BOOLEAN NULL, - tags text[] NOT NULL + tags text[] NOT NULL, + + counter INT NULL ); \ No newline at end of file diff --git a/trifid-api/src/config.rs b/trifid-api/src/config.rs index 3cfb14b..9560edc 100644 --- a/trifid-api/src/config.rs +++ b/trifid-api/src/config.rs @@ -44,5 +44,5 @@ pub struct ConfigTokens { pub session_token_expiry_seconds: u64, pub auth_token_expiry_seconds: u64, pub data_encryption_key: String, - pub cert_expiry_time_seconds: u64 + pub cert_expiry_time_seconds: u64, } diff --git a/trifid-api/src/config_generator.rs b/trifid-api/src/config_generator.rs index 591cf77..384b561 100644 --- a/trifid-api/src/config_generator.rs +++ b/trifid-api/src/config_generator.rs @@ -3,25 +3,33 @@ // and the entirety of trifid-pki, as it deals with CA private keys. // Review carefully what you write here! -use std::collections::{Bound, HashMap}; -use std::error::Error; -use std::net::{Ipv4Addr, SocketAddrV4}; -use std::str::FromStr; -use std::time::{Duration, SystemTime}; +use crate::crypt::sign_cert_with_ca; +use crate::models::{Host, HostKey, HostOverride, Network, Role, RoleFirewallRule, SigningCA}; +use crate::schema::{ + host_keys, host_overrides, hosts, networks, role_firewall_rules, roles, signing_cas, +}; +use crate::AppState; use actix_web::web::Data; use diesel::{ExpressionMethods, QueryDsl, SelectableHelper}; use diesel_async::pooled_connection::bb8::RunError; use diesel_async::RunQueryDsl; use ipnet::Ipv4Net; +use nebula_config::{ + NebulaConfig, NebulaConfigCipher, NebulaConfigFirewall, NebulaConfigFirewallRule, + NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy, + NebulaConfigRelay, NebulaConfigTun, +}; use serde_yaml::{Mapping, Value}; +use std::collections::{Bound, HashMap}; +use std::error::Error; +use std::net::{Ipv4Addr, SocketAddrV4}; +use std::str::FromStr; +use std::time::{Duration, SystemTime}; use thiserror::Error; -use nebula_config::{NebulaConfig, NebulaConfigCipher, NebulaConfigFirewall, NebulaConfigFirewallRule, NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy, NebulaConfigRelay, NebulaConfigTun}; -use trifid_pki::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate, NebulaCertificateDetails}; +use trifid_pki::cert::{ + deserialize_nebula_certificate_from_pem, NebulaCertificate, NebulaCertificateDetails, +}; use trifid_pki::x25519_dalek::PublicKey; -use crate::AppState; -use crate::crypt::sign_cert_with_ca; -use crate::models::{Host, Network, SigningCA, HostKey, Role, RoleFirewallRule, HostOverride}; -use crate::schema::{host_keys, hosts, networks, role_firewall_rules, roles, signing_cas, host_overrides}; #[derive(Error, Debug)] pub enum ConfigGenError { @@ -32,44 +40,68 @@ pub enum ConfigGenError { #[error("error parsing a signing CA: {0}")] InvalidCACert(serde_json::Error), #[error("an error occured: {0}")] - GenericError(Box) + GenericError(Box), } -pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data) -> Result { - let mut conn = state.pool.get().await.map_err(ConfigGenError::AcquireError)?; +pub async fn generate_config( + host: &Host, + dh_pubkey: PublicKey, + state: Data, +) -> Result { + let mut conn = state + .pool + .get() + .await + .map_err(ConfigGenError::AcquireError)?; - let cas = signing_cas::dsl::signing_cas.filter(signing_cas::organization_id.eq(&host.organization_id)).select(SigningCA::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?; + let cas = signing_cas::dsl::signing_cas + .filter(signing_cas::organization_id.eq(&host.organization_id)) + .select(SigningCA::as_select()) + .load(&mut conn) + .await + .map_err(ConfigGenError::DbError)?; let mut good_cas = vec![]; let mut ca_string = String::new(); for ca in cas { if ca.expires_at < SystemTime::now() { - let ca_cert: NebulaCertificate = serde_json::from_value(ca.cert.clone()).map_err(ConfigGenError::InvalidCACert)?; - ca_string += &String::from_utf8_lossy(&ca_cert.serialize_to_pem().map_err(ConfigGenError::GenericError)?); + let ca_cert: NebulaCertificate = + serde_json::from_value(ca.cert.clone()).map_err(ConfigGenError::InvalidCACert)?; + ca_string += &String::from_utf8_lossy( + &ca_cert + .serialize_to_pem() + .map_err(ConfigGenError::GenericError)?, + ); good_cas.push((ca, ca_cert)); } } let (signing_ca, ca_cert) = &good_cas[0]; - let network = networks::dsl::networks.find(&host.network_id).first::(&mut conn).await.map_err(ConfigGenError::DbError)?; + let network = networks::dsl::networks + .find(&host.network_id) + .first::(&mut conn) + .await + .map_err(ConfigGenError::DbError)?; let mut cert = NebulaCertificate { details: NebulaCertificateDetails { name: host.name.clone(), - ips: vec![ - Ipv4Net::new( - Ipv4Addr::from_str(&host.ip_address).unwrap(), - Ipv4Net::from_str(&network.cidr).unwrap().prefix_len() - ).unwrap() - ], + ips: vec![Ipv4Net::new( + Ipv4Addr::from_str(&host.ip_address).unwrap(), + Ipv4Net::from_str(&network.cidr).unwrap().prefix_len(), + ) + .unwrap()], subnets: vec![], groups: if let Some(role_id) = &host.role_id { vec![format!("role:{}", role_id)] - } else { vec![] }, + } else { + vec![] + }, not_before: SystemTime::now() - Duration::from_secs(3600), - not_after: SystemTime::now() + Duration::from_secs(state.config.tokens.cert_expiry_time_seconds), + not_after: SystemTime::now() + + Duration::from_secs(state.config.tokens.cert_expiry_time_seconds), public_key: *dh_pubkey.as_bytes(), is_ca: false, issuer: ca_cert.sha256sum().unwrap(), @@ -79,12 +111,23 @@ pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data>()); + static_host_map.insert( + Ipv4Addr::from_str(&lighthouse.ip_address).unwrap(), + lighthouse + .static_addresses + .iter() + .map(|u| SocketAddrV4::from_str(&u.clone().unwrap()).unwrap()) + .collect::>(), + ); } for relay in &all_relays { - static_host_map.insert(Ipv4Addr::from_str(&relay.ip_address).unwrap(), relay.static_addresses.iter().map(|u| SocketAddrV4::from_str(&u.clone().unwrap()).unwrap()).collect::>()); + static_host_map.insert( + Ipv4Addr::from_str(&relay.ip_address).unwrap(), + relay + .static_addresses + .iter() + .map(|u| SocketAddrV4::from_str(&u.clone().unwrap()).unwrap()) + .collect::>(), + ); } let lighthouse = Some(NebulaConfigLighthouse { @@ -121,7 +190,10 @@ pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data>() + all_lighthouses + .iter() + .map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()) + .collect::>() }, remote_allow_list: Default::default(), local_allow_list: Default::default(), @@ -132,52 +204,66 @@ pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data>() + all_lighthouses + .iter() + .map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()) + .collect::>() } else { - all_relays.iter().map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()).collect::>() + all_relays + .iter() + .map(|u| Ipv4Addr::from_str(&u.ip_address).unwrap()) + .collect::>() }, - use_relays: true + use_relays: true, }); let mut inbound_firewall = None; if let Some(role_id) = &host.role_id { - let role = roles::dsl::roles.find(role_id).first::(&mut conn).await.map_err(ConfigGenError::DbError)?; + let firewall_rules = role_firewall_rules::dsl::role_firewall_rules + .filter(role_firewall_rules::role_id.eq(role_id)) + .select(RoleFirewallRule::as_select()) + .load(&mut conn) + .await + .map_err(ConfigGenError::DbError)?; - let firewall_rules = role_firewall_rules::dsl::role_firewall_rules.filter(role_firewall_rules::role_id.eq(role_id)).select(RoleFirewallRule::as_select()).load(&mut conn).await.map_err(ConfigGenError::DbError)?; - - inbound_firewall = Some(firewall_rules.iter().map(|u| NebulaConfigFirewallRule { - port: if let Some((from, to)) = u.port_range { - let start_port = match from { - Bound::Included(u) => u, - Bound::Excluded(u) => u+1, - Bound::Unbounded => 0 - }; - let end_port = match from { - Bound::Included(u) => u, - Bound::Excluded(u) => u-1, - Bound::Unbounded => 65535 - }; - Some(format!("{}-{}", start_port, end_port)) - } else { - Some("any".to_string()) - }, - proto: Some(u.protocol.clone()), - ca_name: None, - ca_sha: None, - host: if u.allowed_role_id.is_some() { - None - } else { - Some("any".to_string()) - }, - group: None, - groups: if u.allowed_role_id.is_some() { - Some(vec![format!("role:{}", u.allowed_role_id.clone().unwrap())]) - } else { - None - }, - cidr: None, - }).collect::>()) + inbound_firewall = Some( + firewall_rules + .iter() + .map(|u| NebulaConfigFirewallRule { + port: if let Some((from, to)) = u.port_range { + let start_port = match from { + Bound::Included(u) => u, + Bound::Excluded(u) => u + 1, + Bound::Unbounded => 0, + }; + let end_port = match to { + Bound::Included(u) => u, + Bound::Excluded(u) => u - 1, + Bound::Unbounded => 65535, + }; + Some(format!("{}-{}", start_port, end_port)) + } else { + Some("any".to_string()) + }, + proto: Some(u.protocol.clone()), + ca_name: None, + ca_sha: None, + host: if u.allowed_role_id.is_some() { + None + } else { + Some("any".to_string()) + }, + group: None, + groups: if u.allowed_role_id.is_some() { + Some(vec![format!("role:{}", u.allowed_role_id.clone().unwrap())]) + } else { + None + }, + cidr: None, + }) + .collect::>(), + ) } let config = NebulaConfig { @@ -214,18 +300,16 @@ pub async fn generate_config(host: &Host, dh_pubkey: PublicKey, state: Data, pub nonce: Vec, - pub key: Vec + pub key: Vec, } pub fn create_dnclient_ed_key( @@ -155,7 +155,7 @@ pub fn create_dnclient_ed_key( &hex::decode(&config.tokens.data_encryption_key) .map_err(|e| CryptographyError::InvalidKey(e))?, ) - .map_err(|_| CryptographyError::InvalidKeyLength)?; + .map_err(|_| CryptographyError::InvalidKeyLength)?; let salt = XChaCha20Poly1305::generate_nonce(&mut OsRng); @@ -169,13 +169,16 @@ pub fn create_dnclient_ed_key( aad: &aad, }, ) - .map_err(|e| CryptographyError::LockingError)?; + .map_err(|_| CryptographyError::LockingError)?; - Ok((DnclientKeyLockbox { - info: aad.to_vec(), - nonce: salt.as_slice().to_vec(), - key: lockbox, - }, key.verifying_key())) + Ok(( + DnclientKeyLockbox { + info: aad.to_vec(), + nonce: salt.as_slice().to_vec(), + key: lockbox, + }, + key.verifying_key(), + )) } pub fn sign_dnclient_with_lockbox( @@ -187,7 +190,7 @@ pub fn sign_dnclient_with_lockbox( &hex::decode(&config.tokens.data_encryption_key) .map_err(|e| CryptographyError::InvalidKey(e))?, ) - .map_err(|_| CryptographyError::InvalidKeyLength)?; + .map_err(|_| CryptographyError::InvalidKeyLength)?; let salt_u24: [u8; 24] = lockbox .nonce @@ -212,7 +215,7 @@ pub fn sign_dnclient_with_lockbox( .try_into() .map_err(|_| CryptographyError::InvalidSigningKeyLength)?, ) - .map_err(|e| CryptographyError::SignatureError(e))?; + .map_err(|e| CryptographyError::SignatureError(e))?; Ok(key.sign(bytes)) -} \ No newline at end of file +} diff --git a/trifid-api/src/main.rs b/trifid-api/src/main.rs index d5730d8..ddcf946 100644 --- a/trifid-api/src/main.rs +++ b/trifid-api/src/main.rs @@ -26,8 +26,8 @@ pub mod auth; pub mod email; #[macro_use] pub mod macros; -pub mod crypt; mod config_generator; +pub mod crypt; #[derive(Clone)] pub struct AppState { diff --git a/trifid-api/src/models.rs b/trifid-api/src/models.rs index 301f911..ea25b79 100644 --- a/trifid-api/src/models.rs +++ b/trifid-api/src/models.rs @@ -290,6 +290,7 @@ pub struct Host { pub platform: Option, pub update_available: Option, pub tags: Vec>, + pub counter: Option, } #[derive( @@ -311,7 +312,7 @@ pub struct HostOverride { pub id: String, pub host_id: String, pub key: String, - pub value: Value, + pub value: String, } #[derive( @@ -340,7 +341,7 @@ pub struct HostKey { pub salt: Vec, pub info: Vec, - pub server_ed_priv: Vec + pub server_ed_priv: Vec, } #[derive( diff --git a/trifid-api/src/routes/v1/dnclient.rs b/trifid-api/src/routes/v1/dnclient.rs new file mode 100644 index 0000000..f36e2d2 --- /dev/null +++ b/trifid-api/src/routes/v1/dnclient.rs @@ -0,0 +1,201 @@ +use crate::config_generator::generate_config; +use crate::crypt::{create_dnclient_ed_key, sign_dnclient_with_lockbox}; +use crate::models::{Host, HostKey}; +use crate::response::JsonAPIResponse; +use crate::schema::{host_keys, hosts}; +use crate::{randid, AppState}; +use actix_web::http::StatusCode; +use actix_web::web::{Data, Json}; +use base64::Engine; +use diesel::{ExpressionMethods, OptionalExtension, QueryDsl}; +use diesel_async::RunQueryDsl; +use dnapi_rs::message::{ + CheckForUpdateResponse, CheckForUpdateResponseWrapper, DoUpdateRequest, DoUpdateResponse, + RequestV1, RequestWrapper, SignedResponseWrapper, SignedResponse +}; +use log::warn; +use serde::Serialize; +use std::time::SystemTime; +use trifid_pki::cert::{ + deserialize_ed25519_public, deserialize_x25519_public, serialize_ed25519_public, +}; +use trifid_pki::ed25519_dalek::{Signature, Verifier, VerifyingKey}; +use trifid_pki::x25519_dalek::PublicKey; + +#[derive(Serialize, Debug)] +pub enum DnclientResponse { + CheckForUpdateResp(CheckForUpdateResponseWrapper), + DoUpdateResp(SignedResponseWrapper), +} + +pub async fn dnclient_req( + req: Json, + state: Data, +) -> JsonAPIResponse { + if req.version != 1 { + err!( + StatusCode::BAD_REQUEST, + make_err!( + "ERR_INVALID_DNCLIENT_VERSION", + "unsupported dnclient api version", + "version" + ) + ); + } + + let mut conn = handle_error!(state.pool.get().await); + + // get key for this host and counter + let maybe_key: Option = handle_error!(host_keys::dsl::host_keys + .filter(host_keys::host_id.eq(&req.host_id)) + .filter(host_keys::counter.eq(req.counter as i32)) + .first::(&mut conn) + .await + .optional()); + + let key = match maybe_key { + Some(k) => k, + None => { + err!( + StatusCode::UNAUTHORIZED, + make_err!( + "ERR_UNAUTHORIZED", + "no key for host/counter pair in keystore" + ) + ) + } + }; + + let signature = handle_error!(Signature::from_slice(&req.signature)); + let key = handle_error!(VerifyingKey::from_bytes( + &key.client_ed_pub.try_into().unwrap() + )); + + if key.verify(req.message.as_bytes(), &signature).is_err() { + warn!("! invalid signature from {}", req.host_id); + err!( + StatusCode::UNAUTHORIZED, + make_err!("ERR_UNAUTHORIZED", "unauthorized") + ) + } + + // Sig OK + + let msg_raw = handle_error!(base64::engine::general_purpose::STANDARD.decode(&req.message)); + + let req_w: RequestWrapper = handle_error!(serde_json::from_slice(&msg_raw)); + + let host = handle_error!( + hosts::table + .find(&req.host_id) + .first::(&mut conn) + .await + ); + + handle_error!( + diesel::update(&host) + .set(hosts::dsl::last_seen_at.eq(SystemTime::now())) + .execute(&mut conn) + .await + ); + + match req_w.message_type.as_str() { + "CheckForUpdate" => { + ok!(DnclientResponse::CheckForUpdateResp( + CheckForUpdateResponseWrapper { + data: CheckForUpdateResponse { + update_available: host.update_available.unwrap() + || req.counter < host.counter.unwrap() as u32 + } + } + )) + } + "DoUpdate" => { + if !host.update_available.unwrap_or(false) { + err!( + StatusCode::BAD_REQUEST, + make_err!("ERR_NO_CONFIG_AVAIL", "no new configuration available") + ); + } + + let do_update_req: DoUpdateRequest = + handle_error!(serde_json::from_slice(&req_w.value)); + + let new_dh_pub_bytes = + handle_error!(deserialize_x25519_public(&do_update_req.dh_pubkey_pem)); + let new_dh_pub_static: [u8; 32] = new_dh_pub_bytes.clone().try_into().unwrap(); + let new_dh_pub = PublicKey::from(new_dh_pub_static); + + let new_config = handle_error!(generate_config(&host, new_dh_pub, state.clone()).await); + + let new_ed_key = + handle_error!(deserialize_ed25519_public(&do_update_req.ed_pubkey_pem)); + + let new_counter = host.counter.unwrap() + 1; + + let (key_lockbox, trusted_key) = handle_error!(create_dnclient_ed_key(&state.config)); + + let new_key = HostKey { + id: randid!(id "hostkey"), + host_id: host.id.clone(), + counter: new_counter, + client_ed_pub: new_ed_key, + client_dh_pub: new_dh_pub_bytes, + client_cert: new_config.pki.cert.as_bytes().to_vec(), + salt: key_lockbox.nonce.clone(), + info: key_lockbox.info.clone(), + server_ed_priv: key_lockbox.key.clone(), + }; + + handle_error!( + diesel::insert_into(host_keys::table) + .values(&new_key) + .execute(&mut conn) + .await + ); + + handle_error!( + diesel::update(&host) + .set(( + hosts::dsl::last_seen_at.eq(SystemTime::now()), + hosts::dsl::update_available.eq(Some(false)), + hosts::dsl::counter.eq(Some(new_counter)) + )) + .execute(&mut conn) + .await + ); + + let config_str = handle_error!(serde_yaml::to_string(&new_config)); + + let msg = DoUpdateResponse { + config: config_str.as_bytes().to_vec(), + counter: new_counter as u32, + nonce: do_update_req.nonce, + trusted_keys: serialize_ed25519_public(trusted_key.as_bytes()), + }; + + let msg_bytes = handle_error!(serde_json::to_vec(&msg)); + + let resp = SignedResponse { + version: 1, + message: msg_bytes.clone(), + signature: sign_dnclient_with_lockbox(&key_lockbox, &msg_bytes, &state.config) + .unwrap().to_vec(), + }; + + let resp_w = SignedResponseWrapper { data: resp }; + + ok!(DnclientResponse::DoUpdateResp(resp_w)); + } + _ => { + err!( + StatusCode::BAD_REQUEST, + make_err!( + "ERR_INVALID_DNCLIENT_REQ_TYPE", + "unsupported dnclient request type", + "message_type" + ) + ); + } + } +} diff --git a/trifid-api/src/routes/v1/mod.rs b/trifid-api/src/routes/v1/mod.rs index 3f7c374..5bd018b 100644 --- a/trifid-api/src/routes/v1/mod.rs +++ b/trifid-api/src/routes/v1/mod.rs @@ -1,4 +1,5 @@ pub mod auth; +pub mod dnclient; pub mod networks; pub mod signup; pub mod totp_authenticators; diff --git a/trifid-api/src/routes/v1/networks.rs b/trifid-api/src/routes/v1/networks.rs index b85ee81..c88c40d 100644 --- a/trifid-api/src/routes/v1/networks.rs +++ b/trifid-api/src/routes/v1/networks.rs @@ -1,5 +1,5 @@ use crate::crypt::create_signing_ca; -use crate::models::{Network, NetworkNormalized, Organization, SigningCA, User}; +use crate::models::{Network, NetworkNormalized, Organization, User}; use crate::response::JsonAPIResponse; use crate::schema::networks::dsl::networks; use crate::schema::organizations::dsl::organizations; @@ -33,7 +33,7 @@ pub async fn create_network_req( let mut conn = handle_error!(state.pool.get().await); let auth_info = auth!(req_info, conn); - let (session_token, auth_token) = enforce!(sess auth auth_info); + let (session_token, _) = enforce!(sess auth auth_info); let user = handle_error!( users::table diff --git a/trifid-api/src/routes/v2/enroll.rs b/trifid-api/src/routes/v2/enroll.rs index 33d13ab..b10d0bb 100644 --- a/trifid-api/src/routes/v2/enroll.rs +++ b/trifid-api/src/routes/v2/enroll.rs @@ -1,17 +1,19 @@ -use std::time::SystemTime; -use crate::response::JsonAPIResponse; -use crate::{AppState, randid}; -use actix_web::web::{Data, Json}; -use actix_web::http::StatusCode; -use dnapi_rs::message::{EnrollRequest, EnrollResponse, EnrollResponseData}; -use crate::models::{EnrollmentCode, Host, HostKey}; -use crate::schema::{enrollment_codes, hosts}; -use diesel::{QueryDsl, OptionalExtension, ExpressionMethods}; -use diesel_async::RunQueryDsl; -use trifid_pki::x25519_dalek::PublicKey; use crate::config_generator::generate_config; use crate::crypt::create_dnclient_ed_key; +use crate::models::Organization; +use crate::models::{EnrollmentCode, Host, HostKey}; +use crate::response::JsonAPIResponse; use crate::schema::host_keys; +use crate::schema::{enrollment_codes, hosts, organizations}; +use crate::{randid, AppState}; +use actix_web::http::StatusCode; +use actix_web::web::{Data, Json}; +use diesel::{ExpressionMethods, OptionalExtension, QueryDsl}; +use diesel_async::RunQueryDsl; +use dnapi_rs::message::{EnrollRequest, EnrollResponse, EnrollResponseData, EnrollResponseDataOrg}; +use std::time::SystemTime; +use trifid_pki::cert::serialize_ed25519_public; +use trifid_pki::x25519_dalek::PublicKey; pub async fn enroll_req( req: Json, @@ -40,13 +42,13 @@ pub async fn enroll_req( if token.expires < SystemTime::now() { err!( - StatusCode::BAD_REQUEST, - make_err!( - "ERR_INVALID_VALUE", - "does not exist (maybe it expired?)", - "magicLinkToken" - ) - ); + StatusCode::BAD_REQUEST, + make_err!( + "ERR_INVALID_VALUE", + "does not exist (maybe it expired?)", + "magicLinkToken" + ) + ); } // valid token @@ -74,7 +76,11 @@ pub async fn enroll_req( // reset the host's key entries - handle_error!(diesel::delete(host_keys::dsl::host_keys.filter(host_keys::host_id.eq(&host.id))).execute(&mut conn).await); + handle_error!( + diesel::delete(host_keys::dsl::host_keys.filter(host_keys::host_id.eq(&host.id))) + .execute(&mut conn) + .await + ); let (key_lockbox, trusted_key) = handle_error!(create_dnclient_ed_key(&state.config)); @@ -103,12 +109,36 @@ pub async fn enroll_req( .await ); - let config_bytes = serde_yaml::to_string(&config).unwrap().as_bytes(); + let org = handle_error!( + organizations::dsl::organizations + .find(host.organization_id.clone()) + .first::(&mut conn) + .await + ); + + let config_bytes = serde_yaml::to_string(&config).unwrap(); + + handle_error!( + diesel::update(&host) + .set(( + hosts::dsl::last_seen_at.eq(SystemTime::now()), + hosts::dsl::update_available.eq(Some(false)), + hosts::dsl::counter.eq(Some(1)) + )) + .execute(&mut conn) + .await + ); ok!(EnrollResponse::Success { data: EnrollResponseData { host_id: host.id.clone(), - config: config_bytes.to_vec() + config: config_bytes.as_bytes().to_vec(), + counter: 1, + trusted_keys: serialize_ed25519_public(trusted_key.as_bytes()), + organization: EnrollResponseDataOrg { + id: org.id, + name: org.name + } } }) } diff --git a/trifid-api/src/schema.rs b/trifid-api/src/schema.rs index 4908518..fbaa696 100644 --- a/trifid-api/src/schema.rs +++ b/trifid-api/src/schema.rs @@ -58,6 +58,7 @@ diesel::table! { platform -> Nullable, update_available -> Nullable, tags -> Array>, + counter -> Nullable, } }