enrollment

This commit is contained in:
c0repwn3r 2023-05-14 13:47:49 -04:00
parent 224e3680e0
commit b8c6ddd123
Signed by: core
GPG Key ID: FDBF740DADDCEECF
36 changed files with 3365 additions and 1201 deletions

1135
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[package] [package]
name = "dnapi-rs" name = "dnapi-rs"
version = "0.1.9" version = "0.1.11"
edition = "2021" edition = "2021"
description = "A rust client for the Defined Networking API" description = "A rust client for the Defined Networking API"
license = "AGPL-3.0-or-later" license = "AGPL-3.0-or-later"

View File

@ -1,17 +1,21 @@
//! Client structs to handle communication with the Defined Networking API. This is the async client API - if you want blocking instead, enable the blocking (or default) feature instead. //! Client structs to handle communication with the Defined Networking API. This is the async client API - if you want blocking instead, enable the blocking (or default) feature instead.
use std::error::Error; use crate::credentials::{ed25519_public_keys_from_pem, Credentials};
use crate::crypto::{new_keys, nonce};
use crate::message::{
CheckForUpdateResponseWrapper, DoUpdateRequest, DoUpdateResponse, EnrollRequest,
EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper, CHECK_FOR_UPDATE, DO_UPDATE,
ENDPOINT_V1, ENROLL_ENDPOINT,
};
use base64::Engine;
use chrono::Local; use chrono::Local;
use log::{debug, error}; use log::{debug, error};
use reqwest::StatusCode; use reqwest::StatusCode;
use url::Url; use serde::{Deserialize, Serialize};
use std::error::Error;
use trifid_pki::cert::serialize_ed25519_public; use trifid_pki::cert::serialize_ed25519_public;
use trifid_pki::ed25519_dalek::{Signature, Signer, SigningKey, Verifier}; use trifid_pki::ed25519_dalek::{Signature, Signer, SigningKey, Verifier};
use crate::credentials::{Credentials, ed25519_public_keys_from_pem}; use url::Url;
use crate::crypto::{new_keys, nonce};
use crate::message::{CHECK_FOR_UPDATE, CheckForUpdateResponseWrapper, DO_UPDATE, DoUpdateRequest, DoUpdateResponse, ENDPOINT_V1, ENROLL_ENDPOINT, EnrollRequest, EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper};
use serde::{Serialize, Deserialize};
use base64::Engine;
/// A type alias to abstract return types /// A type alias to abstract return types
pub type NebulaConfig = Vec<u8>; pub type NebulaConfig = Vec<u8>;
@ -22,7 +26,7 @@ pub type DHPrivateKeyPEM = Vec<u8>;
/// A combination of persistent data and HTTP client used for communicating with the API. /// A combination of persistent data and HTTP client used for communicating with the API.
pub struct Client { pub struct Client {
http_client: reqwest::Client, http_client: reqwest::Client,
server_url: Url server_url: Url,
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
@ -31,7 +35,7 @@ pub struct EnrollMeta {
/// The server organization ID this node is now a member of /// The server organization ID this node is now a member of
pub organization_id: String, pub organization_id: String,
/// The server organization name this node is now a member of /// The server organization name this node is now a member of
pub organization_name: String pub organization_name: String,
} }
impl Client { impl Client {
@ -42,7 +46,7 @@ impl Client {
let client = reqwest::Client::builder().user_agent(user_agent).build()?; let client = reqwest::Client::builder().user_agent(user_agent).build()?;
Ok(Self { Ok(Self {
http_client: client, http_client: client,
server_url: api_base server_url: api_base,
}) })
} }
@ -59,8 +63,14 @@ impl Client {
/// - the server returns an error /// - the server returns an error
/// - the server returns invalid JSON /// - the server returns invalid JSON
/// - the `trusted_keys` field is invalid /// - the `trusted_keys` field is invalid
pub async fn enroll(&self, code: &str) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> { pub async fn enroll(
debug!("making enrollment request to API {{server: {}, code: {}}}", self.server_url, code); &self,
code: &str,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> {
debug!(
"making enrollment request to API {{server: {}, code: {}}}",
self.server_url, code
);
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys(); let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
@ -71,9 +81,19 @@ impl Client {
timestamp: Local::now().format("%Y-%m-%dT%H:%M:%S.%f%:z").to_string(), timestamp: Local::now().format("%Y-%m-%dT%H:%M:%S.%f%:z").to_string(),
})?; })?;
let resp = self.http_client.post(self.server_url.join(ENROLL_ENDPOINT)?).body(req_json).send().await?; let resp = self
.http_client
.post(self.server_url.join(ENROLL_ENDPOINT)?)
.body(req_json)
.header("Content-Type", "application/json")
.send()
.await?;
let req_id = resp.headers().get("X-Request-ID").ok_or("Response missing X-Request-ID")?.to_str()?; let req_id = resp
.headers()
.get("X-Request-ID")
.ok_or("Response missing X-Request-ID")?
.to_str()?;
debug!("enrollment request complete {{req_id: {}}}", req_id); debug!("enrollment request complete {{req_id: {}}}", req_id);
let resp: EnrollResponse = resp.json().await?; let resp: EnrollResponse = resp.json().await?;
@ -107,7 +127,15 @@ impl Client {
/// # Errors /// # Errors
/// This function returns an error if the dnclient request fails, or the server returns invalid data. /// This function returns an error if the dnclient request fails, or the server returns invalid data.
pub async fn check_for_update(&self, creds: &Credentials) -> Result<bool, Box<dyn Error>> { pub async fn check_for_update(&self, creds: &Credentials) -> Result<bool, Box<dyn Error>> {
let body = self.post_dnclient(CHECK_FOR_UPDATE, &[], &creds.host_id, creds.counter, &creds.ed_privkey).await?; let body = self
.post_dnclient(
CHECK_FOR_UPDATE,
&[],
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)
.await?;
let result: CheckForUpdateResponseWrapper = serde_json::from_slice(&body)?; let result: CheckForUpdateResponseWrapper = serde_json::from_slice(&body)?;
@ -125,7 +153,10 @@ impl Client {
/// - if the response could not be deserialized /// - if the response could not be deserialized
/// - if the signature is invalid /// - if the signature is invalid
/// - if the keys are invalid /// - if the keys are invalid
pub async fn do_update(&self, creds: &Credentials) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> { pub async fn do_update(
&self,
creds: &Credentials,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> {
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys(); let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
let update_keys = DoUpdateRequest { let update_keys = DoUpdateRequest {
@ -136,28 +167,45 @@ impl Client {
let update_keys_blob = serde_json::to_vec(&update_keys)?; let update_keys_blob = serde_json::to_vec(&update_keys)?;
let resp = self.post_dnclient(DO_UPDATE, &update_keys_blob, &creds.host_id, creds.counter, &creds.ed_privkey).await?; let resp = self
.post_dnclient(
DO_UPDATE,
&update_keys_blob,
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)
.await?;
let result_wrapper: SignedResponseWrapper = serde_json::from_slice(&resp)?; let result_wrapper: SignedResponseWrapper = serde_json::from_slice(&resp)?;
let mut valid = false; let mut valid = false;
for ca_pubkey in &creds.trusted_keys { for ca_pubkey in &creds.trusted_keys {
if ca_pubkey.verify(&result_wrapper.data.message, &Signature::from_slice(&result_wrapper.data.signature)?).is_ok() { if ca_pubkey
.verify(
&result_wrapper.data.message,
&Signature::from_slice(&result_wrapper.data.signature)?,
)
.is_ok()
{
valid = true; valid = true;
break; break;
} }
} }
if !valid { if !valid {
return Err("Failed to verify signed API result".into()) return Err("Failed to verify signed API result".into());
} }
let result: DoUpdateResponse = serde_json::from_slice(&result_wrapper.data.message)?; let result: DoUpdateResponse = serde_json::from_slice(&result_wrapper.data.message)?;
if result.nonce != update_keys.nonce { if result.nonce != update_keys.nonce {
error!("nonce mismatch between request {:x?} and response {:x?}", result.nonce, update_keys.nonce); error!(
return Err("nonce mismatch between request and response".into()) "nonce mismatch between request {:x?} and response {:x?}",
result.nonce, update_keys.nonce
);
return Err("nonce mismatch between request and response".into());
} }
let trusted_keys = ed25519_public_keys_from_pem(&result.trusted_keys)?; let trusted_keys = ed25519_public_keys_from_pem(&result.trusted_keys)?;
@ -179,7 +227,14 @@ impl Client {
/// - serialization in any step fails /// - serialization in any step fails
/// - if the `server_url` is invalid /// - if the `server_url` is invalid
/// - if the request could not be sent /// - if the request could not be sent
pub async fn post_dnclient(&self, req_type: &str, value: &[u8], host_id: &str, counter: u32, ed_privkey: &SigningKey) -> Result<Vec<u8>, Box<dyn Error>> { pub async fn post_dnclient(
&self,
req_type: &str,
value: &[u8],
host_id: &str,
counter: u32,
ed_privkey: &SigningKey,
) -> Result<Vec<u8>, Box<dyn Error>> {
let encoded_msg = serde_json::to_string(&RequestWrapper { let encoded_msg = serde_json::to_string(&RequestWrapper {
message_type: req_type.to_string(), message_type: req_type.to_string(),
value: value.to_vec(), value: value.to_vec(),
@ -203,19 +258,23 @@ impl Client {
let post_body = serde_json::to_string(&body)?; let post_body = serde_json::to_string(&body)?;
let resp = self.http_client.post(self.server_url.join(ENDPOINT_V1)?).body(post_body).send().await?; let resp = self
.http_client
.post(self.server_url.join(ENDPOINT_V1)?)
.body(post_body)
.send()
.await?;
match resp.status() { match resp.status() {
StatusCode::OK => { StatusCode::OK => Ok(resp.bytes().await?.to_vec()),
Ok(resp.bytes().await?.to_vec()) StatusCode::FORBIDDEN => Err("Forbidden".into()),
},
StatusCode::FORBIDDEN => {
Err("Forbidden".into())
},
_ => { _ => {
error!("dnclient endpoint returned bad status code {}", resp.status()); error!(
"dnclient endpoint returned bad status code {}",
resp.status()
);
Err("dnclient endpoint returned error".into()) Err("dnclient endpoint returned error".into())
} }
} }
} }
} }

View File

@ -1,17 +1,21 @@
//! Client structs to handle communication with the Defined Networking API. This is the blocking client API - if you want async instead, set no-default-features and enable the async feature instead. //! Client structs to handle communication with the Defined Networking API. This is the blocking client API - if you want async instead, set no-default-features and enable the async feature instead.
use std::error::Error; use crate::credentials::{ed25519_public_keys_from_pem, Credentials};
use crate::crypto::{new_keys, nonce};
use crate::message::{
CheckForUpdateResponseWrapper, DoUpdateRequest, DoUpdateResponse, EnrollRequest,
EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper, CHECK_FOR_UPDATE, DO_UPDATE,
ENDPOINT_V1, ENROLL_ENDPOINT,
};
use base64::Engine; use base64::Engine;
use chrono::Local; use chrono::Local;
use log::{debug, error, trace}; use log::{debug, error, trace};
use reqwest::StatusCode; use reqwest::StatusCode;
use url::Url; use serde::{Deserialize, Serialize};
use std::error::Error;
use trifid_pki::cert::serialize_ed25519_public; use trifid_pki::cert::serialize_ed25519_public;
use trifid_pki::ed25519_dalek::{Signature, Signer, SigningKey, Verifier}; use trifid_pki::ed25519_dalek::{Signature, Signer, SigningKey, Verifier};
use crate::credentials::{Credentials, ed25519_public_keys_from_pem}; use url::Url;
use crate::crypto::{new_keys, nonce};
use crate::message::{CHECK_FOR_UPDATE, CheckForUpdateResponseWrapper, DO_UPDATE, DoUpdateRequest, DoUpdateResponse, ENDPOINT_V1, ENROLL_ENDPOINT, EnrollRequest, EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper};
use serde::{Serialize, Deserialize};
/// A type alias to abstract return types /// A type alias to abstract return types
pub type NebulaConfig = Vec<u8>; pub type NebulaConfig = Vec<u8>;
@ -22,7 +26,7 @@ pub type DHPrivateKeyPEM = Vec<u8>;
/// A combination of persistent data and HTTP client used for communicating with the API. /// A combination of persistent data and HTTP client used for communicating with the API.
pub struct Client { pub struct Client {
http_client: reqwest::blocking::Client, http_client: reqwest::blocking::Client,
server_url: Url server_url: Url,
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
@ -31,7 +35,7 @@ pub struct EnrollMeta {
/// The server organization ID this node is now a member of /// The server organization ID this node is now a member of
pub organization_id: String, pub organization_id: String,
/// The server organization name this node is now a member of /// The server organization name this node is now a member of
pub organization_name: String pub organization_name: String,
} }
impl Client { impl Client {
@ -39,10 +43,12 @@ impl Client {
/// # Errors /// # Errors
/// This function will return an error if the reqwest Client could not be created. /// This function will return an error if the reqwest Client could not be created.
pub fn new(user_agent: String, api_base: Url) -> Result<Self, Box<dyn Error>> { pub fn new(user_agent: String, api_base: Url) -> Result<Self, Box<dyn Error>> {
let client = reqwest::blocking::Client::builder().user_agent(user_agent).build()?; let client = reqwest::blocking::Client::builder()
.user_agent(user_agent)
.build()?;
Ok(Self { Ok(Self {
http_client: client, http_client: client,
server_url: api_base server_url: api_base,
}) })
} }
@ -59,8 +65,14 @@ impl Client {
/// - the server returns an error /// - the server returns an error
/// - the server returns invalid JSON /// - the server returns invalid JSON
/// - the `trusted_keys` field is invalid /// - the `trusted_keys` field is invalid
pub fn enroll(&self, code: &str) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> { pub fn enroll(
debug!("making enrollment request to API {{server: {}, code: {}}}", self.server_url, code); &self,
code: &str,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> {
debug!(
"making enrollment request to API {{server: {}, code: {}}}",
self.server_url, code
);
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys(); let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
@ -71,9 +83,18 @@ impl Client {
timestamp: Local::now().format("%Y-%m-%dT%H:%M:%S.%f%:z").to_string(), timestamp: Local::now().format("%Y-%m-%dT%H:%M:%S.%f%:z").to_string(),
})?; })?;
let resp = self.http_client.post(self.server_url.join(ENROLL_ENDPOINT)?).body(req_json).send()?; let resp = self
.http_client
.post(self.server_url.join(ENROLL_ENDPOINT)?)
.header("Content-Type", "application/json")
.body(req_json)
.send()?;
let req_id = resp.headers().get("X-Request-ID").ok_or("Response missing X-Request-ID")?.to_str()?; let req_id = resp
.headers()
.get("X-Request-ID")
.ok_or("Response missing X-Request-ID")?
.to_str()?;
debug!("enrollment request complete {{req_id: {}}}", req_id); debug!("enrollment request complete {{req_id: {}}}", req_id);
let resp: EnrollResponse = resp.json()?; let resp: EnrollResponse = resp.json()?;
@ -93,7 +114,6 @@ impl Client {
debug!("parsing public keys"); debug!("parsing public keys");
let trusted_keys = ed25519_public_keys_from_pem(&r.trusted_keys)?; let trusted_keys = ed25519_public_keys_from_pem(&r.trusted_keys)?;
let creds = Credentials { let creds = Credentials {
@ -110,7 +130,13 @@ impl Client {
/// # Errors /// # Errors
/// This function returns an error if the dnclient request fails, or the server returns invalid data. /// This function returns an error if the dnclient request fails, or the server returns invalid data.
pub fn check_for_update(&self, creds: &Credentials) -> Result<bool, Box<dyn Error>> { pub fn check_for_update(&self, creds: &Credentials) -> Result<bool, Box<dyn Error>> {
let body = self.post_dnclient(CHECK_FOR_UPDATE, &[], &creds.host_id, creds.counter, &creds.ed_privkey)?; let body = self.post_dnclient(
CHECK_FOR_UPDATE,
&[],
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)?;
let result: CheckForUpdateResponseWrapper = serde_json::from_slice(&body)?; let result: CheckForUpdateResponseWrapper = serde_json::from_slice(&body)?;
@ -128,7 +154,10 @@ impl Client {
/// - if the response could not be deserialized /// - if the response could not be deserialized
/// - if the signature is invalid /// - if the signature is invalid
/// - if the keys are invalid /// - if the keys are invalid
pub fn do_update(&self, creds: &Credentials) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> { pub fn do_update(
&self,
creds: &Credentials,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> {
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys(); let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
let update_keys = DoUpdateRequest { let update_keys = DoUpdateRequest {
@ -139,33 +168,51 @@ impl Client {
let update_keys_blob = serde_json::to_vec(&update_keys)?; let update_keys_blob = serde_json::to_vec(&update_keys)?;
let resp = self.post_dnclient(DO_UPDATE, &update_keys_blob, &creds.host_id, creds.counter, &creds.ed_privkey)?; let resp = self.post_dnclient(
DO_UPDATE,
&update_keys_blob,
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)?;
let result_wrapper: SignedResponseWrapper = serde_json::from_slice(&resp)?; let result_wrapper: SignedResponseWrapper = serde_json::from_slice(&resp)?;
let mut valid = false; let mut valid = false;
for ca_pubkey in &creds.trusted_keys { for ca_pubkey in &creds.trusted_keys {
if ca_pubkey.verify(&result_wrapper.data.message, &Signature::from_slice(&result_wrapper.data.signature)?).is_ok() { if ca_pubkey
.verify(
&result_wrapper.data.message,
&Signature::from_slice(&result_wrapper.data.signature)?,
)
.is_ok()
{
valid = true; valid = true;
break; break;
} }
} }
if !valid { if !valid {
return Err("Failed to verify signed API result".into()) return Err("Failed to verify signed API result".into());
} }
let result: DoUpdateResponse = serde_json::from_slice(&result_wrapper.data.message)?; let result: DoUpdateResponse = serde_json::from_slice(&result_wrapper.data.message)?;
if result.nonce != update_keys.nonce { if result.nonce != update_keys.nonce {
error!("nonce mismatch between request {:x?} and response {:x?}", result.nonce, update_keys.nonce); error!(
return Err("nonce mismatch between request and response".into()) "nonce mismatch between request {:x?} and response {:x?}",
result.nonce, update_keys.nonce
);
return Err("nonce mismatch between request and response".into());
} }
if result.counter <= creds.counter { if result.counter <= creds.counter {
error!("counter in request {} should be less than counter in response {}", creds.counter, result.counter); error!(
return Err("received older config than what we already had".into()) "counter in request {} should be less than counter in response {}",
creds.counter, result.counter
);
return Err("received older config than what we already had".into());
} }
let trusted_keys = ed25519_public_keys_from_pem(&result.trusted_keys)?; let trusted_keys = ed25519_public_keys_from_pem(&result.trusted_keys)?;
@ -187,7 +234,14 @@ impl Client {
/// - serialization in any step fails /// - serialization in any step fails
/// - if the `server_url` is invalid /// - if the `server_url` is invalid
/// - if the request could not be sent /// - if the request could not be sent
pub fn post_dnclient(&self, req_type: &str, value: &[u8], host_id: &str, counter: u32, ed_privkey: &SigningKey) -> Result<Vec<u8>, Box<dyn Error>> { pub fn post_dnclient(
&self,
req_type: &str,
value: &[u8],
host_id: &str,
counter: u32,
ed_privkey: &SigningKey,
) -> Result<Vec<u8>, Box<dyn Error>> {
let encoded_msg = serde_json::to_string(&RequestWrapper { let encoded_msg = serde_json::to_string(&RequestWrapper {
message_type: req_type.to_string(), message_type: req_type.to_string(),
value: value.to_vec(), value: value.to_vec(),
@ -213,19 +267,22 @@ impl Client {
trace!("sending dnclient request {}", post_body); trace!("sending dnclient request {}", post_body);
let resp = self.http_client.post(self.server_url.join(ENDPOINT_V1)?).body(post_body).send()?; let resp = self
.http_client
.post(self.server_url.join(ENDPOINT_V1)?)
.body(post_body)
.send()?;
match resp.status() { match resp.status() {
StatusCode::OK => { StatusCode::OK => Ok(resp.bytes()?.to_vec()),
Ok(resp.bytes()?.to_vec()) StatusCode::FORBIDDEN => Err("Forbidden".into()),
},
StatusCode::FORBIDDEN => {
Err("Forbidden".into())
},
_ => { _ => {
error!("dnclient endpoint returned bad status code {}", resp.status()); error!(
"dnclient endpoint returned bad status code {}",
resp.status()
);
Err("dnclient endpoint returned error".into()) Err("dnclient endpoint returned error".into())
} }
} }
} }
} }

View File

@ -1,9 +1,9 @@
//! Contains the `Credentials` struct, which contains all keys, IDs, organizations and other identity-related and security-related data that is persistent in a `Client` //! Contains the `Credentials` struct, which contains all keys, IDs, organizations and other identity-related and security-related data that is persistent in a `Client`
use serde::{Deserialize, Serialize};
use std::error::Error; use std::error::Error;
use trifid_pki::cert::{deserialize_ed25519_public_many, serialize_ed25519_public}; use trifid_pki::cert::{deserialize_ed25519_public_many, serialize_ed25519_public};
use trifid_pki::ed25519_dalek::{SigningKey, VerifyingKey}; use trifid_pki::ed25519_dalek::{SigningKey, VerifyingKey};
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
/// Contains information necessary to make requests against the `DNClient` API. /// Contains information necessary to make requests against the `DNClient` API.
@ -15,7 +15,7 @@ pub struct Credentials {
/// The counter used in the other API requests. It is unknown what the purpose of this is, but the original client persists it and it is needed for API calls. /// The counter used in the other API requests. It is unknown what the purpose of this is, but the original client persists it and it is needed for API calls.
pub counter: u32, pub counter: u32,
/// The set of trusted ed25519 keys that may be used by the API to sign API responses. /// The set of trusted ed25519 keys that may be used by the API to sign API responses.
pub trusted_keys: Vec<VerifyingKey> pub trusted_keys: Vec<VerifyingKey>,
} }
/// Converts an array of `VerifyingKey`s to a singular bundle of PEM-encoded keys /// Converts an array of `VerifyingKey`s to a singular bundle of PEM-encoded keys
@ -38,8 +38,10 @@ pub fn ed25519_public_keys_from_pem(pem: &[u8]) -> Result<Vec<VerifyingKey>, Box
#[allow(clippy::unwrap_used)] #[allow(clippy::unwrap_used)]
for pem in pems { for pem in pems {
keys.push(VerifyingKey::from_bytes(&pem.try_into().unwrap_or_else(|_| unreachable!()))?); keys.push(VerifyingKey::from_bytes(
&pem.try_into().unwrap_or_else(|_| unreachable!()),
)?);
} }
Ok(keys) Ok(keys)
} }

View File

@ -1,7 +1,7 @@
//! Functions for generating keys and nonces for use in API calls //! Functions for generating keys and nonces for use in API calls
use rand::Rng;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rand::Rng;
use trifid_pki::cert::{serialize_x25519_private, serialize_x25519_public}; use trifid_pki::cert::{serialize_x25519_private, serialize_x25519_public};
use trifid_pki::ed25519_dalek::{SigningKey, VerifyingKey}; use trifid_pki::ed25519_dalek::{SigningKey, VerifyingKey};
use trifid_pki::x25519_dalek::{PublicKey, StaticSecret}; use trifid_pki::x25519_dalek::{PublicKey, StaticSecret};
@ -38,4 +38,4 @@ pub fn new_ed25519_keypair() -> (VerifyingKey, SigningKey) {
/// Generates a 16-byte random nonce for use in API calls /// Generates a 16-byte random nonce for use in API calls
pub fn nonce() -> [u8; 16] { pub fn nonce() -> [u8; 16] {
rand::thread_rng().gen() rand::thread_rng().gen()
} }

View File

@ -17,8 +17,8 @@
pub mod message; pub mod message;
pub mod client_blocking;
pub mod client_async; pub mod client_async;
pub mod client_blocking;
pub mod credentials; pub mod credentials;
pub mod crypto; pub mod crypto;

View File

@ -1,7 +1,7 @@
//! Models for interacting with the Defined Networking API. //! Models for interacting with the Defined Networking API.
use base64_serde::base64_serde_type; use base64_serde::base64_serde_type;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// The version 1 `DNClient` API endpoint /// The version 1 `DNClient` API endpoint
pub const ENDPOINT_V1: &str = "/v1/dnclient"; pub const ENDPOINT_V1: &str = "/v1/dnclient";
@ -27,7 +27,7 @@ pub struct RequestV1 {
pub message: String, pub message: String,
#[serde(with = "Base64Standard")] #[serde(with = "Base64Standard")]
/// An ed25519 signature over the `message`, which can be verified with the host's previously enrolled ed25519 public key /// An ed25519 signature over the `message`, which can be verified with the host's previously enrolled ed25519 public key
pub signature: Vec<u8> pub signature: Vec<u8>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -45,14 +45,14 @@ pub struct RequestWrapper {
/// For example: /// For example:
/// `2023-03-29T09:56:42.380006369-04:00` /// `2023-03-29T09:56:42.380006369-04:00`
/// would represent `29 March 03, 2023, 09:56:42.380006369 UTC-4` /// would represent `29 March 03, 2023, 09:56:42.380006369 UTC-4`
pub timestamp: String pub timestamp: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
/// `SignedResponseWrapper` contains a response message and a signature to validate inside `data`. /// `SignedResponseWrapper` contains a response message and a signature to validate inside `data`.
pub struct SignedResponseWrapper { pub struct SignedResponseWrapper {
/// The response data contained in this message /// The response data contained in this message
pub data: SignedResponse pub data: SignedResponse,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -65,14 +65,14 @@ pub struct SignedResponse {
pub message: Vec<u8>, pub message: Vec<u8>,
#[serde(with = "Base64Standard")] #[serde(with = "Base64Standard")]
/// The ed25519 signature over the `message` /// The ed25519 signature over the `message`
pub signature: Vec<u8> pub signature: Vec<u8>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
/// `CheckForUpdateResponseWrapper` contains a response to `CheckForUpdate` inside "data." /// `CheckForUpdateResponseWrapper` contains a response to `CheckForUpdate` inside "data."
pub struct CheckForUpdateResponseWrapper { pub struct CheckForUpdateResponseWrapper {
/// The response data contained in this message /// The response data contained in this message
pub data: CheckForUpdateResponse pub data: CheckForUpdateResponse,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -80,7 +80,7 @@ pub struct CheckForUpdateResponseWrapper {
pub struct CheckForUpdateResponse { pub struct CheckForUpdateResponse {
#[serde(rename = "updateAvailable")] #[serde(rename = "updateAvailable")]
/// Set to true if a config update is available /// Set to true if a config update is available
pub update_available: bool pub update_available: bool,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -97,7 +97,7 @@ pub struct DoUpdateRequest {
#[serde(with = "Base64Standard")] #[serde(with = "Base64Standard")]
/// A randomized value used to uniquely identify this request. /// A randomized value used to uniquely identify this request.
/// The original client uses a randomized, 16-byte value here, which dnapi-rs replicates /// The original client uses a randomized, 16-byte value here, which dnapi-rs replicates
pub nonce: Vec<u8> pub nonce: Vec<u8>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -114,13 +114,13 @@ pub struct DoUpdateResponse {
#[serde(rename = "trustedKeys")] #[serde(rename = "trustedKeys")]
#[serde(with = "Base64Standard")] #[serde(with = "Base64Standard")]
/// A new set of trusted ed25519 keys that can be used by the server to sign messages. /// A new set of trusted ed25519 keys that can be used by the server to sign messages.
pub trusted_keys: Vec<u8> pub trusted_keys: Vec<u8>,
} }
/// The REST enrollment endpoint /// The REST enrollment endpoint
pub const ENROLL_ENDPOINT: &str = "/v2/enroll"; pub const ENROLL_ENDPOINT: &str = "/v2/enroll";
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Debug)]
/// `EnrollRequest` is issued to the `ENROLL_ENDPOINT` to enroll this `dnclient` with a dnapi organization /// `EnrollRequest` is issued to the `ENROLL_ENDPOINT` to enroll this `dnclient` with a dnapi organization
pub struct EnrollRequest { pub struct EnrollRequest {
/// The enrollment code given by the API server. /// The enrollment code given by the API server.
@ -138,10 +138,9 @@ pub struct EnrollRequest {
/// For example: /// For example:
/// `2023-03-29T09:56:42.380006369-04:00` /// `2023-03-29T09:56:42.380006369-04:00`
/// would represent `29 March 03, 2023, 09:56:42.380006369 UTC-4` /// would represent `29 March 03, 2023, 09:56:42.380006369 UTC-4`
pub timestamp: String pub timestamp: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
/// The response to an `EnrollRequest` /// The response to an `EnrollRequest`
@ -149,13 +148,13 @@ pub enum EnrollResponse {
/// A successful enrollment, with a `data` field pointing to an `EnrollResponseData` /// A successful enrollment, with a `data` field pointing to an `EnrollResponseData`
Success { Success {
/// The response data from this response /// The response data from this response
data: EnrollResponseData data: EnrollResponseData,
}, },
/// An unsuccessful enrollment, with an `errors` field pointing to an array of `APIError`s. /// An unsuccessful enrollment, with an `errors` field pointing to an array of `APIError`s.
Error { Error {
/// A list of `APIError`s that happened while trying to enroll. `APIErrors` is a type alias to `Vec<APIError>` /// A list of `APIError`s that happened while trying to enroll. `APIErrors` is a type alias to `Vec<APIError>`
errors: APIErrors errors: APIErrors,
} },
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -174,7 +173,7 @@ pub struct EnrollResponseData {
/// A new set of trusted ed25519 keys that can be used by the server to sign messages. /// A new set of trusted ed25519 keys that can be used by the server to sign messages.
pub trusted_keys: Vec<u8>, pub trusted_keys: Vec<u8>,
/// The organization data that this node is now a part of /// The organization data that this node is now a part of
pub organization: EnrollResponseDataOrg pub organization: EnrollResponseDataOrg,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -183,7 +182,7 @@ pub struct EnrollResponseDataOrg {
/// The organization ID that this node is now a part of /// The organization ID that this node is now a part of
pub id: String, pub id: String,
/// The name of the organization that this node is now a part of /// The name of the organization that this node is now a part of
pub name: String pub name: String,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -194,8 +193,8 @@ pub struct APIError {
/// The human-readable error message /// The human-readable error message
pub message: String, pub message: String,
/// An optional path to where the error occured /// An optional path to where the error occured
pub path: Option<String> pub path: Option<String>,
} }
/// A type alias to a array of `APIErrors`. Just for parity with dnapi. /// A type alias to a array of `APIErrors`. Just for parity with dnapi.
pub type APIErrors = Vec<APIError>; pub type APIErrors = Vec<APIError>;

View File

@ -1,6 +1,6 @@
[package] [package]
name = "tfclient" name = "tfclient"
version = "0.1.7" version = "0.1.8"
edition = "2021" edition = "2021"
description = "An open-source reimplementation of a Defined Networking-compatible client" description = "An open-source reimplementation of a Defined Networking-compatible client"
license = "GPL-3.0-or-later" license = "GPL-3.0-or-later"

View File

@ -1,19 +1,18 @@
use flate2::read::GzDecoder;
use reqwest::blocking::Response;
use reqwest::header::HeaderMap;
use std::fs; use std::fs;
use std::fs::{File, remove_file}; use std::fs::{remove_file, File};
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
use std::path::Path; use std::path::Path;
use std::process::{Command, Output}; use std::process::{Command, Output};
use flate2::read::GzDecoder;
use reqwest::blocking::Response;
use reqwest::header::HeaderMap;
use tar::Archive; use tar::Archive;
#[derive(serde::Deserialize, Debug)] #[derive(serde::Deserialize, Debug)]
struct GithubRelease { struct GithubRelease {
name: String, name: String,
assets: Vec<GithubReleaseAsset> assets: Vec<GithubReleaseAsset>,
} }
#[derive(serde::Deserialize, Debug)] #[derive(serde::Deserialize, Debug)]
@ -23,11 +22,18 @@ struct GithubUser {}
struct GithubReleaseAsset { struct GithubReleaseAsset {
browser_download_url: String, browser_download_url: String,
name: String, name: String,
size: i64 size: i64,
} }
fn main() { fn main() {
if Path::new(&format!("{}/{}", std::env::var("OUT_DIR").unwrap(), "noredownload")).exists() && std::env::var("TFBUILD_FORCE_REDOWNLOAD").is_err() { if Path::new(&format!(
"{}/{}",
std::env::var("OUT_DIR").unwrap(),
"noredownload"
))
.exists()
&& std::env::var("TFBUILD_FORCE_REDOWNLOAD").is_err()
{
println!("noredownload exists and TFBUILD_FORCE_REDOWNLOAD is not set. Not redoing build process."); println!("noredownload exists and TFBUILD_FORCE_REDOWNLOAD is not set. Not redoing build process.");
return; return;
} }
@ -38,15 +44,32 @@ fn main() {
let mut has_api_key = false; let mut has_api_key = false;
if let Ok(api_key) = std::env::var("GH_API_KEY") { if let Ok(api_key) = std::env::var("GH_API_KEY") {
headers.insert("Authorization", format!("Bearer {}", api_key).parse().unwrap()); headers.insert(
"Authorization",
format!("Bearer {}", api_key).parse().unwrap(),
);
has_api_key = true; has_api_key = true;
} }
let client = reqwest::blocking::Client::builder().user_agent("curl/7.57.1").default_headers(headers).build().unwrap(); let client = reqwest::blocking::Client::builder()
.user_agent("curl/7.57.1")
.default_headers(headers)
.build()
.unwrap();
let resp: Response = client.get("https://api.github.com/repos/slackhq/nebula/releases/latest").send().unwrap(); let resp: Response = client
.get("https://api.github.com/repos/slackhq/nebula/releases/latest")
.send()
.unwrap();
if resp.headers().get("X-Ratelimit-Remaining").unwrap().to_str().unwrap() == "0" { if resp
.headers()
.get("X-Ratelimit-Remaining")
.unwrap()
.to_str()
.unwrap()
== "0"
{
println!("You've been ratelimited from the GitHub API. Wait a while (1 hour)"); println!("You've been ratelimited from the GitHub API. Wait a while (1 hour)");
if !has_api_key { if !has_api_key {
println!("You can also set a GitHub API key with the environment variable GH_API_KEY, which will increase your ratelimit ( a lot )"); println!("You can also set a GitHub API key with the environment variable GH_API_KEY, which will increase your ratelimit ( a lot )");
@ -54,7 +77,6 @@ fn main() {
panic!("Ratelimited"); panic!("Ratelimited");
} }
let release: GithubRelease = resp.json().unwrap(); let release: GithubRelease = resp.json().unwrap();
println!("[*] Fetching target triplet..."); println!("[*] Fetching target triplet...");
@ -84,9 +106,16 @@ fn main() {
println!("[*] Embedding {} {}", target_file, release.name); println!("[*] Embedding {} {}", target_file, release.name);
let download = release.assets.iter().find(|r| r.name == format!("{}.tar.gz", target_file)).expect("That architecture isn't avaliable :("); let download = release
.assets
.iter()
.find(|r| r.name == format!("{}.tar.gz", target_file))
.expect("That architecture isn't avaliable :(");
println!("[*] Downloading {}.tar.gz ({}, {} bytes) from {}", target_file, target, download.size, download.browser_download_url); println!(
"[*] Downloading {}.tar.gz ({}, {} bytes) from {}",
target_file, target, download.size, download.browser_download_url
);
let response = reqwest::blocking::get(&download.browser_download_url).unwrap(); let response = reqwest::blocking::get(&download.browser_download_url).unwrap();
let content = response.bytes().unwrap().to_vec(); let content = response.bytes().unwrap().to_vec();
@ -102,10 +131,14 @@ fn main() {
for entry in entries { for entry in entries {
let mut entry = entry.unwrap(); let mut entry = entry.unwrap();
if entry.path().unwrap() == Path::new("nebula") || entry.path().unwrap() == Path::new("nebula.exe") { if entry.path().unwrap() == Path::new("nebula")
|| entry.path().unwrap() == Path::new("nebula.exe")
{
nebula_bin.reserve(entry.size() as usize); nebula_bin.reserve(entry.size() as usize);
entry.read_to_end(&mut nebula_bin).unwrap(); entry.read_to_end(&mut nebula_bin).unwrap();
} else if entry.path().unwrap() == Path::new("nebula-cert") || entry.path().unwrap() == Path::new("nebula-cert.exe") { } else if entry.path().unwrap() == Path::new("nebula-cert")
|| entry.path().unwrap() == Path::new("nebula-cert.exe")
{
nebula_cert_bin.reserve(entry.size() as usize); nebula_cert_bin.reserve(entry.size() as usize);
entry.read_to_end(&mut nebula_cert_bin).unwrap(); entry.read_to_end(&mut nebula_cert_bin).unwrap();
} else if entry.path().unwrap() == Path::new("SHASUM256.txt") { } else if entry.path().unwrap() == Path::new("SHASUM256.txt") {
@ -121,18 +154,28 @@ fn main() {
panic!("[x] Release did not contain nebula_cert binary"); panic!("[x] Release did not contain nebula_cert binary");
} }
let mut nebula_file = File::create(format!("{}/nebula.bin", std::env::var("OUT_DIR").unwrap())).unwrap(); let mut nebula_file =
File::create(format!("{}/nebula.bin", std::env::var("OUT_DIR").unwrap())).unwrap();
nebula_file.write_all(&nebula_bin).unwrap(); nebula_file.write_all(&nebula_bin).unwrap();
codegen_version(&nebula_bin, "nebula.bin", "NEBULA"); codegen_version(&nebula_bin, "nebula.bin", "NEBULA");
let mut nebula_cert_file = File::create(format!("{}/nebula_cert.bin", std::env::var("OUT_DIR").unwrap())).unwrap(); let mut nebula_cert_file = File::create(format!(
"{}/nebula_cert.bin",
std::env::var("OUT_DIR").unwrap()
))
.unwrap();
nebula_cert_file.write_all(&nebula_cert_bin).unwrap(); nebula_cert_file.write_all(&nebula_cert_bin).unwrap();
codegen_version(&nebula_cert_bin, "nebula_cert.bin", "NEBULA_CERT"); codegen_version(&nebula_cert_bin, "nebula_cert.bin", "NEBULA_CERT");
// Indicate to cargo and ourselves that we have already downloaded and codegenned // Indicate to cargo and ourselves that we have already downloaded and codegenned
File::create(format!("{}/{}", std::env::var("OUT_DIR").unwrap(), "noredownload")).unwrap(); File::create(format!(
"{}/{}",
std::env::var("OUT_DIR").unwrap(),
"noredownload"
))
.unwrap();
println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rerun-if-changed=build.rs");
} }
@ -149,7 +192,8 @@ fn codegen_version(bin: &[u8], fp: &str, name: &str) {
let code = format!("// This code was automatically @generated by build.rs. It should not be modified.\npub const {}_BIN: &[u8] = include_bytes!(concat!(env!(\"OUT_DIR\"), \"/{}\"));\npub const {}_VERSION: &str = \"{}\";", name, fp, name, version); let code = format!("// This code was automatically @generated by build.rs. It should not be modified.\npub const {}_BIN: &[u8] = include_bytes!(concat!(env!(\"OUT_DIR\"), \"/{}\"));\npub const {}_VERSION: &str = \"{}\";", name, fp, name, version);
let mut file = File::create(format!("{}/{}.rs", std::env::var("OUT_DIR").unwrap(), fp)).unwrap(); let mut file =
File::create(format!("{}/{}.rs", std::env::var("OUT_DIR").unwrap(), fp)).unwrap();
file.write_all(code.as_bytes()).unwrap(); file.write_all(code.as_bytes()).unwrap();
} }

View File

@ -1,13 +1,9 @@
use std::fs; use std::fs;
use std::sync::mpsc::{Receiver, RecvError, TryRecvError}; use std::sync::mpsc::Receiver;
use dnapi_rs::client_blocking::Client;
use log::{error, info, warn}; use log::{error, info, warn};
use url::Url; use url::Url;
use dnapi_rs::client_blocking::Client;
use crate::config::{load_cdata, save_cdata, TFClientConfig}; use crate::config::{load_cdata, save_cdata, TFClientConfig};
use crate::daemon::ThreadMessageSender; use crate::daemon::ThreadMessageSender;
@ -18,10 +14,16 @@ pub enum APIWorkerMessage {
Shutdown, Shutdown,
Enroll { code: String }, Enroll { code: String },
Update, Update,
Timer Timer,
} }
pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx: ThreadMessageSender, rx: Receiver<APIWorkerMessage>) { pub fn apiworker_main(
_config: TFClientConfig,
instance: String,
url: String,
tx: ThreadMessageSender,
rx: Receiver<APIWorkerMessage>,
) {
let server = Url::parse(&url).unwrap(); let server = Url::parse(&url).unwrap();
let client = Client::new(format!("tfclient/{}", env!("CARGO_PKG_VERSION")), server).unwrap(); let client = Client::new(format!("tfclient/{}", env!("CARGO_PKG_VERSION")), server).unwrap();
@ -33,7 +35,7 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
APIWorkerMessage::Shutdown => { APIWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping"); info!("recv on command socket: shutdown, stopping");
break; break;
}, }
APIWorkerMessage::Timer | APIWorkerMessage::Update => { APIWorkerMessage::Timer | APIWorkerMessage::Update => {
info!("updating config"); info!("updating config");
let mut cdata = match load_cdata(&instance) { let mut cdata = match load_cdata(&instance) {
@ -108,9 +110,13 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
}; };
cdata.creds = Some(creds); cdata.creds = Some(creds);
cdata.dh_privkey = Some(dh_privkey.try_into().expect("32 != 32")); cdata.dh_privkey = Some(dh_privkey);
match fs::write(get_nebulaconfig_file(&instance).expect("Unable to determine nebula config file location"), config) { match fs::write(
get_nebulaconfig_file(&instance)
.expect("Unable to determine nebula config file location"),
config,
) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("unable to save nebula config: {}", e); error!("unable to save nebula config: {}", e);
@ -146,7 +152,7 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
return; return;
} }
} }
}, }
APIWorkerMessage::Enroll { code } => { APIWorkerMessage::Enroll { code } => {
info!("recv on command socket: enroll {}", code); info!("recv on command socket: enroll {}", code);
let mut cdata = match load_cdata(&instance) { let mut cdata = match load_cdata(&instance) {
@ -170,7 +176,11 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
} }
}; };
match fs::write(get_nebulaconfig_file(&instance).expect("Unable to determine nebula config file location"), config) { match fs::write(
get_nebulaconfig_file(&instance)
.expect("Unable to determine nebula config file location"),
config,
) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("unable to save nebula config: {}", e); error!("unable to save nebula config: {}", e);
@ -179,7 +189,7 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
} }
cdata.creds = Some(creds); cdata.creds = Some(creds);
cdata.dh_privkey = Some(dh_privkey.try_into().expect("32 != 32")); cdata.dh_privkey = Some(dh_privkey);
cdata.meta = Some(meta); cdata.meta = Some(meta);
// Save vardata // Save vardata
@ -204,11 +214,11 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
} }
} }
} }
}, }
Err(e) => { Err(e) => {
error!("error on command socket: {}", e); error!("error on command socket: {}", e);
return; return;
} }
} }
} }
} }

View File

@ -1,33 +1,34 @@
use ipnet::{IpNet, Ipv4Net};
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddrV4};
use ipnet::{IpNet, Ipv4Net};
use log::{debug, info};
use serde::{Deserialize, Serialize};
use dnapi_rs::client_blocking::EnrollMeta; use dnapi_rs::client_blocking::EnrollMeta;
use dnapi_rs::credentials::Credentials; use dnapi_rs::credentials::Credentials;
use log::{debug, info};
use serde::{Deserialize, Serialize};
use crate::dirs::{get_cdata_dir, get_cdata_file, get_config_dir, get_config_file}; use crate::dirs::{get_cdata_dir, get_cdata_file, get_config_dir, get_config_file};
pub const DEFAULT_PORT: u16 = 8157; pub const DEFAULT_PORT: u16 = 8157;
fn default_port() -> u16 { DEFAULT_PORT } fn default_port() -> u16 {
DEFAULT_PORT
}
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
pub struct TFClientConfig { pub struct TFClientConfig {
#[serde(default = "default_port")] #[serde(default = "default_port")]
pub listen_port: u16, pub listen_port: u16,
#[serde(default = "bool_false")] #[serde(default = "bool_false")]
pub disable_automatic_config_updates: bool pub disable_automatic_config_updates: bool,
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
pub struct TFClientData { pub struct TFClientData {
pub dh_privkey: Option<Vec<u8>>, pub dh_privkey: Option<Vec<u8>>,
pub creds: Option<Credentials>, pub creds: Option<Credentials>,
pub meta: Option<EnrollMeta> pub meta: Option<EnrollMeta>,
} }
pub fn create_config(instance: &str) -> Result<(), Box<dyn Error>> { pub fn create_config(instance: &str) -> Result<(), Box<dyn Error>> {
@ -39,7 +40,10 @@ pub fn create_config(instance: &str) -> Result<(), Box<dyn Error>> {
disable_automatic_config_updates: false, disable_automatic_config_updates: false,
}; };
let config_str = toml::to_string(&config)?; let config_str = toml::to_string(&config)?;
fs::write(get_config_file(instance).ok_or("Unable to load config dir")?, config_str)?; fs::write(
get_config_file(instance).ok_or("Unable to load config dir")?,
config_str,
)?;
Ok(()) Ok(())
} }
@ -63,9 +67,16 @@ pub fn create_cdata(instance: &str) -> Result<(), Box<dyn Error>> {
info!("Creating data directory..."); info!("Creating data directory...");
fs::create_dir_all(get_cdata_dir(instance).ok_or("Unable to load data dir")?)?; fs::create_dir_all(get_cdata_dir(instance).ok_or("Unable to load data dir")?)?;
info!("Copying default data file to config directory..."); info!("Copying default data file to config directory...");
let config = TFClientData { dh_privkey: None, creds: None, meta: None }; let config = TFClientData {
dh_privkey: None,
creds: None,
meta: None,
};
let config_str = toml::to_string(&config)?; let config_str = toml::to_string(&config)?;
fs::write(get_cdata_file(instance).ok_or("Unable to load data dir")?, config_str)?; fs::write(
get_cdata_file(instance).ok_or("Unable to load data dir")?,
config_str,
)?;
Ok(()) Ok(())
} }
@ -141,7 +152,7 @@ pub struct NebulaConfig {
#[serde(default = "none")] #[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")] #[serde(skip_serializing_if = "is_none")]
pub local_range: Option<Ipv4Net> pub local_range: Option<Ipv4Net>,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -156,7 +167,7 @@ pub struct NebulaConfigPki {
pub blocklist: Vec<String>, pub blocklist: Vec<String>,
#[serde(default = "bool_false")] #[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")] #[serde(skip_serializing_if = "is_bool_false")]
pub disconnect_invalid: bool pub disconnect_invalid: bool,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -190,7 +201,7 @@ pub struct NebulaConfigLighthouseDns {
pub host: String, pub host: String,
#[serde(default = "u16_53")] #[serde(default = "u16_53")]
#[serde(skip_serializing_if = "is_u16_53")] #[serde(skip_serializing_if = "is_u16_53")]
pub port: u16 pub port: u16,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -207,7 +218,7 @@ pub struct NebulaConfigListen {
#[serde(skip_serializing_if = "is_none")] #[serde(skip_serializing_if = "is_none")]
pub read_buffer: Option<u32>, pub read_buffer: Option<u32>,
#[serde(skip_serializing_if = "is_none")] #[serde(skip_serializing_if = "is_none")]
pub write_buffer: Option<u32> pub write_buffer: Option<u32>,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -220,7 +231,7 @@ pub struct NebulaConfigPunchy {
pub respond: bool, pub respond: bool,
#[serde(default = "string_1s")] #[serde(default = "string_1s")]
#[serde(skip_serializing_if = "is_string_1s")] #[serde(skip_serializing_if = "is_string_1s")]
pub delay: String pub delay: String,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -228,7 +239,7 @@ pub enum NebulaConfigCipher {
#[serde(rename = "aes")] #[serde(rename = "aes")]
Aes, Aes,
#[serde(rename = "chachapoly")] #[serde(rename = "chachapoly")]
ChaChaPoly ChaChaPoly,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -241,7 +252,7 @@ pub struct NebulaConfigRelay {
pub am_relay: bool, pub am_relay: bool,
#[serde(default = "bool_true")] #[serde(default = "bool_true")]
#[serde(skip_serializing_if = "is_bool_true")] #[serde(skip_serializing_if = "is_bool_true")]
pub use_relays: bool pub use_relays: bool,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -268,13 +279,13 @@ pub struct NebulaConfigTun {
pub routes: Vec<NebulaConfigTunRouteOverride>, pub routes: Vec<NebulaConfigTunRouteOverride>,
#[serde(default = "empty_vec")] #[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")] #[serde(skip_serializing_if = "is_empty_vec")]
pub unsafe_routes: Vec<NebulaConfigTunUnsafeRoute> pub unsafe_routes: Vec<NebulaConfigTunUnsafeRoute>,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunRouteOverride { pub struct NebulaConfigTunRouteOverride {
pub mtu: u64, pub mtu: u64,
pub route: Ipv4Net pub route: Ipv4Net,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -286,7 +297,7 @@ pub struct NebulaConfigTunUnsafeRoute {
pub mtu: u64, pub mtu: u64,
#[serde(default = "i64_100")] #[serde(default = "i64_100")]
#[serde(skip_serializing_if = "is_i64_100")] #[serde(skip_serializing_if = "is_i64_100")]
pub metric: i64 pub metric: i64,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -302,7 +313,7 @@ pub struct NebulaConfigLogging {
pub disable_timestamp: bool, pub disable_timestamp: bool,
#[serde(default = "timestamp")] #[serde(default = "timestamp")]
#[serde(skip_serializing_if = "is_timestamp")] #[serde(skip_serializing_if = "is_timestamp")]
pub timestamp_format: String pub timestamp_format: String,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -318,7 +329,7 @@ pub enum NebulaConfigLoggingLevel {
#[serde(rename = "info")] #[serde(rename = "info")]
Info, Info,
#[serde(rename = "debug")] #[serde(rename = "debug")]
Debug Debug,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -326,7 +337,7 @@ pub enum NebulaConfigLoggingFormat {
#[serde(rename = "json")] #[serde(rename = "json")]
Json, Json,
#[serde(rename = "text")] #[serde(rename = "text")]
Text Text,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -338,7 +349,7 @@ pub struct NebulaConfigSshd {
pub host_key: String, pub host_key: String,
#[serde(default = "empty_vec")] #[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")] #[serde(skip_serializing_if = "is_empty_vec")]
pub authorized_users: Vec<NebulaConfigSshdAuthorizedUser> pub authorized_users: Vec<NebulaConfigSshdAuthorizedUser>,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -346,7 +357,7 @@ pub struct NebulaConfigSshdAuthorizedUser {
pub user: String, pub user: String,
#[serde(default = "empty_vec")] #[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")] #[serde(skip_serializing_if = "is_empty_vec")]
pub keys: Vec<String> pub keys: Vec<String>,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -355,7 +366,7 @@ pub enum NebulaConfigStats {
#[serde(rename = "graphite")] #[serde(rename = "graphite")]
Graphite(NebulaConfigStatsGraphite), Graphite(NebulaConfigStatsGraphite),
#[serde(rename = "prometheus")] #[serde(rename = "prometheus")]
Prometheus(NebulaConfigStatsPrometheus) Prometheus(NebulaConfigStatsPrometheus),
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -373,7 +384,7 @@ pub struct NebulaConfigStatsGraphite {
pub message_metrics: bool, pub message_metrics: bool,
#[serde(default = "bool_false")] #[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")] #[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool pub lighthouse_metrics: bool,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -381,7 +392,7 @@ pub enum NebulaConfigStatsGraphiteProtocol {
#[serde(rename = "tcp")] #[serde(rename = "tcp")]
Tcp, Tcp,
#[serde(rename = "udp")] #[serde(rename = "udp")]
Udp Udp,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -400,7 +411,7 @@ pub struct NebulaConfigStatsPrometheus {
pub message_metrics: bool, pub message_metrics: bool,
#[serde(default = "bool_false")] #[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")] #[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool pub lighthouse_metrics: bool,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -428,7 +439,7 @@ pub struct NebulaConfigFirewallConntrack {
pub udp_timeout: String, pub udp_timeout: String,
#[serde(default = "string_10m")] #[serde(default = "string_10m")]
#[serde(skip_serializing_if = "is_string_10m")] #[serde(skip_serializing_if = "is_string_10m")]
pub default_timeout: String pub default_timeout: String,
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
@ -456,82 +467,175 @@ pub struct NebulaConfigFirewallRule {
pub groups: Option<Vec<String>>, pub groups: Option<Vec<String>>,
#[serde(default = "none")] #[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")] #[serde(skip_serializing_if = "is_none")]
pub cidr: Option<String> pub cidr: Option<String>,
} }
// Default values for serde // Default values for serde
fn string_12m() -> String { "12m".to_string() } fn string_12m() -> String {
fn is_string_12m(s: &str) -> bool { s == "12m" } "12m".to_string()
}
fn is_string_12m(s: &str) -> bool {
s == "12m"
}
fn string_3m() -> String { "3m".to_string() } fn string_3m() -> String {
fn is_string_3m(s: &str) -> bool { s == "3m" } "3m".to_string()
}
fn is_string_3m(s: &str) -> bool {
s == "3m"
}
fn string_10m() -> String { "10m".to_string() } fn string_10m() -> String {
fn is_string_10m(s: &str) -> bool { s == "10m" } "10m".to_string()
}
fn is_string_10m(s: &str) -> bool {
s == "10m"
}
fn empty_vec<T>() -> Vec<T> { vec![] } fn empty_vec<T>() -> Vec<T> {
fn is_empty_vec<T>(v: &Vec<T>) -> bool { v.is_empty() } vec![]
}
fn is_empty_vec<T>(v: &Vec<T>) -> bool {
v.is_empty()
}
fn empty_hashmap<A, B>() -> HashMap<A, B> { HashMap::new() } fn empty_hashmap<A, B>() -> HashMap<A, B> {
fn is_empty_hashmap<A, B>(h: &HashMap<A, B>) -> bool { h.is_empty() } HashMap::new()
}
fn is_empty_hashmap<A, B>(h: &HashMap<A, B>) -> bool {
h.is_empty()
}
fn bool_false() -> bool { false } fn bool_false() -> bool {
fn is_bool_false(b: &bool) -> bool { !*b } false
}
fn is_bool_false(b: &bool) -> bool {
!*b
}
fn bool_true() -> bool { true } fn bool_true() -> bool {
fn is_bool_true(b: &bool) -> bool { *b } true
}
fn is_bool_true(b: &bool) -> bool {
*b
}
fn u16_53() -> u16 { 53 } fn u16_53() -> u16 {
fn is_u16_53(u: &u16) -> bool { *u == 53 } 53
}
fn is_u16_53(u: &u16) -> bool {
*u == 53
}
fn u32_10() -> u32 { 10 } fn u32_10() -> u32 {
fn is_u32_10(u: &u32) -> bool { *u == 10 } 10
}
fn is_u32_10(u: &u32) -> bool {
*u == 10
}
fn ipv4_0000() -> Ipv4Addr { Ipv4Addr::new(0, 0, 0, 0) } fn u16_0() -> u16 {
fn is_ipv4_0000(i: &Ipv4Addr) -> bool { *i == ipv4_0000() } 0
}
fn is_u16_0(u: &u16) -> bool {
*u == 0
}
fn u16_0() -> u16 { 0 } fn u32_64() -> u32 {
fn is_u16_0(u: &u16) -> bool { *u == 0 } 64
}
fn is_u32_64(u: &u32) -> bool {
*u == 64
}
fn u32_64() -> u32 { 64 } fn string_1s() -> String {
fn is_u32_64(u: &u32) -> bool { *u == 64 } "1s".to_string()
}
fn is_string_1s(s: &str) -> bool {
s == "1s"
}
fn string_1s() -> String { "1s".to_string() } fn cipher_aes() -> NebulaConfigCipher {
fn is_string_1s(s: &str) -> bool { s == "1s" } NebulaConfigCipher::Aes
}
fn is_cipher_aes(c: &NebulaConfigCipher) -> bool {
matches!(c, NebulaConfigCipher::Aes)
}
fn cipher_aes() -> NebulaConfigCipher { NebulaConfigCipher::Aes } fn u64_500() -> u64 {
fn is_cipher_aes(c: &NebulaConfigCipher) -> bool { matches!(c, NebulaConfigCipher::Aes) } 500
}
fn is_u64_500(u: &u64) -> bool {
*u == 500
}
fn u64_500() -> u64 { 500 } fn u64_1300() -> u64 {
fn is_u64_500(u: &u64) -> bool { *u == 500 } 1300
}
fn is_u64_1300(u: &u64) -> bool {
*u == 1300
}
fn u64_1300() -> u64 { 1300 } fn i64_100() -> i64 {
fn is_u64_1300(u: &u64) -> bool { *u == 1300 } 100
}
fn is_i64_100(i: &i64) -> bool {
*i == 100
}
fn i64_100() -> i64 { 100 } fn loglevel_info() -> NebulaConfigLoggingLevel {
fn is_i64_100(i: &i64) -> bool { *i == 100 } NebulaConfigLoggingLevel::Info
}
fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool {
matches!(l, NebulaConfigLoggingLevel::Info)
}
fn loglevel_info() -> NebulaConfigLoggingLevel { NebulaConfigLoggingLevel::Info } fn format_text() -> NebulaConfigLoggingFormat {
fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool { matches!(l, NebulaConfigLoggingLevel::Info) } NebulaConfigLoggingFormat::Text
}
fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool {
matches!(f, NebulaConfigLoggingFormat::Text)
}
fn format_text() -> NebulaConfigLoggingFormat { NebulaConfigLoggingFormat::Text } fn timestamp() -> String {
fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool { matches!(f, NebulaConfigLoggingFormat::Text) } "2006-01-02T15:04:05Z07:00".to_string()
}
fn is_timestamp(s: &str) -> bool {
s == "2006-01-02T15:04:05Z07:00"
}
fn timestamp() -> String { "2006-01-02T15:04:05Z07:00".to_string() } fn u64_1() -> u64 {
fn is_timestamp(s: &str) -> bool { s == "2006-01-02T15:04:05Z07:00" } 1
}
fn is_u64_1(u: &u64) -> bool {
*u == 1
}
fn u64_1() -> u64 { 1 } fn string_nebula() -> String {
fn is_u64_1(u: &u64) -> bool { *u == 1 } "nebula".to_string()
}
fn is_string_nebula(s: &str) -> bool {
s == "nebula"
}
fn string_nebula() -> String { "nebula".to_string() } fn string_empty() -> String {
fn is_string_nebula(s: &str) -> bool { s == "nebula" } String::new()
}
fn is_string_empty(s: &str) -> bool {
s.is_empty()
}
fn string_empty() -> String { String::new() } fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol {
fn is_string_empty(s: &str) -> bool { s == "" } NebulaConfigStatsGraphiteProtocol::Tcp
}
fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool {
matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp)
}
fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol { NebulaConfigStatsGraphiteProtocol::Tcp } fn none<T>() -> Option<T> {
fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool { matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp) } None
}
fn none<T>() -> Option<T> { None } fn is_none<T>(o: &Option<T>) -> bool {
fn is_none<T>(o: &Option<T>) -> bool { o.is_none() } o.is_none()
}

View File

@ -1,7 +1,7 @@
use log::{error, info};
use std::sync::mpsc; use std::sync::mpsc;
use std::sync::mpsc::Sender; use std::sync::mpsc::Sender;
use std::thread; use std::thread;
use log::{error, info};
use crate::apiworker::{apiworker_main, APIWorkerMessage}; use crate::apiworker::{apiworker_main, APIWorkerMessage};
use crate::config::load_config; use crate::config::load_config;
@ -44,28 +44,49 @@ pub fn daemon_main(name: String, server: String) {
match ctrlc::set_handler(move || { match ctrlc::set_handler(move || {
info!("Ctrl-C detected. Stopping threads..."); info!("Ctrl-C detected. Stopping threads...");
match mainthread_transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) { match mainthread_transmitter
.nebula_thread
.send(NebulaWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to nebula worker thread: {}", e); error!(
"Error sending shutdown message to nebula worker thread: {}",
e
);
} }
} }
match mainthread_transmitter.api_thread.send(APIWorkerMessage::Shutdown) { match mainthread_transmitter
.api_thread
.send(APIWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to api worker thread: {}", e); error!("Error sending shutdown message to api worker thread: {}", e);
} }
} }
match mainthread_transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) { match mainthread_transmitter
.socket_thread
.send(SocketWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to socket worker thread: {}", e); error!(
"Error sending shutdown message to socket worker thread: {}",
e
);
} }
} }
match mainthread_transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) { match mainthread_transmitter
.timer_thread
.send(TimerWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to timer worker thread: {}", e); error!(
"Error sending shutdown message to timer worker thread: {}",
e
);
} }
} }
}) { }) {
@ -81,21 +102,21 @@ pub fn daemon_main(name: String, server: String) {
let config_api = config.clone(); let config_api = config.clone();
let transmitter_api = transmitter.clone(); let transmitter_api = transmitter.clone();
let name_api = name.clone(); let name_api = name.clone();
let server_api = server.clone(); let server_api = server;
let api_thread = thread::spawn(move || { let api_thread = thread::spawn(move || {
apiworker_main(config_api, name_api, server_api,transmitter_api, rx_api); apiworker_main(config_api, name_api, server_api, transmitter_api, rx_api);
}); });
info!("Starting Nebula thread..."); info!("Starting Nebula thread...");
let config_nebula = config.clone(); let config_nebula = config.clone();
let transmitter_nebula = transmitter.clone(); let transmitter_nebula = transmitter.clone();
let name_nebula = name.clone(); let name_nebula = name.clone();
let nebula_thread = thread::spawn(move || { //let nebula_thread = thread::spawn(move || {
nebulaworker_main(config_nebula, name_nebula, transmitter_nebula, rx_nebula); // nebulaworker_main(config_nebula, name_nebula, transmitter_nebula, rx_nebula);
}); //});
info!("Starting socket worker thread..."); info!("Starting socket worker thread...");
let name_socket = name.clone(); let name_socket = name;
let config_socket = config.clone(); let config_socket = config.clone();
let tx_socket = transmitter.clone(); let tx_socket = transmitter.clone();
let socket_thread = thread::spawn(move || { let socket_thread = thread::spawn(move || {
@ -104,7 +125,7 @@ pub fn daemon_main(name: String, server: String) {
info!("Starting timer thread..."); info!("Starting timer thread...");
if !config.disable_automatic_config_updates { if !config.disable_automatic_config_updates {
let timer_transmitter = transmitter.clone(); let timer_transmitter = transmitter;
let timer_thread = thread::spawn(move || { let timer_thread = thread::spawn(move || {
timer_main(timer_transmitter, rx_timer); timer_main(timer_transmitter, rx_timer);
}); });
@ -142,13 +163,13 @@ pub fn daemon_main(name: String, server: String) {
info!("API thread exited"); info!("API thread exited");
info!("Waiting for Nebula thread to exit..."); info!("Waiting for Nebula thread to exit...");
match nebula_thread.join() { //match nebula_thread.join() {
Ok(_) => (), // Ok(_) => (),
Err(_) => { // Err(_) => {
error!("Error waiting for nebula thread to exit."); // error!("Error waiting for nebula thread to exit.");
std::process::exit(1); // std::process::exit(1);
} // }
} //}
info!("Nebula thread exited"); info!("Nebula thread exited");
info!("All threads exited"); info!("All threads exited");
@ -159,5 +180,5 @@ pub struct ThreadMessageSender {
pub socket_thread: Sender<SocketWorkerMessage>, pub socket_thread: Sender<SocketWorkerMessage>,
pub api_thread: Sender<APIWorkerMessage>, pub api_thread: Sender<APIWorkerMessage>,
pub nebula_thread: Sender<NebulaWorkerMessage>, pub nebula_thread: Sender<NebulaWorkerMessage>,
pub timer_thread: Sender<TimerWorkerMessage> pub timer_thread: Sender<TimerWorkerMessage>,
} }

View File

@ -22,4 +22,4 @@ pub fn get_cdata_file(instance: &str) -> Option<PathBuf> {
pub fn get_nebulaconfig_file(instance: &str) -> Option<PathBuf> { pub fn get_nebulaconfig_file(instance: &str) -> Option<PathBuf> {
get_cdata_dir(instance).map(|f| f.join("nebula.sk_embedded.yml")) get_cdata_dir(instance).map(|f| f.join("nebula.sk_embedded.yml"))
} }

View File

@ -1,3 +1,6 @@
use crate::dirs::get_data_dir;
use crate::util::sha256;
use log::debug;
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
use std::fs::File; use std::fs::File;
@ -5,9 +8,6 @@ use std::io::Write;
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::{Child, Command}; use std::process::{Child, Command};
use log::debug;
use crate::dirs::get_data_dir;
use crate::util::sha256;
pub fn extract_embedded_nebula() -> Result<PathBuf, Box<dyn Error>> { pub fn extract_embedded_nebula() -> Result<PathBuf, Box<dyn Error>> {
let data_dir = get_data_dir().ok_or("Unable to get platform-specific data dir")?; let data_dir = get_data_dir().ok_or("Unable to get platform-specific data dir")?;
@ -25,7 +25,11 @@ pub fn extract_embedded_nebula() -> Result<PathBuf, Box<dyn Error>> {
} }
let executable_postfix = if cfg!(windows) { ".exe" } else { "" }; let executable_postfix = if cfg!(windows) { ".exe" } else { "" };
let executable_name = format!("nebula-{}{}", crate::nebula_bin::NEBULA_VERSION, executable_postfix); let executable_name = format!(
"nebula-{}{}",
crate::nebula_bin::NEBULA_VERSION,
executable_postfix
);
let file_path = hash_dir.join(executable_name); let file_path = hash_dir.join(executable_name);
@ -49,7 +53,10 @@ pub fn extract_embedded_nebula_cert() -> Result<PathBuf, Box<dyn Error>> {
} }
let bin_dir = data_dir.join("cache/"); let bin_dir = data_dir.join("cache/");
let hash_dir = bin_dir.join(format!("{}/", sha256(crate::nebula_cert_bin::NEBULA_CERT_BIN))); let hash_dir = bin_dir.join(format!(
"{}/",
sha256(crate::nebula_cert_bin::NEBULA_CERT_BIN)
));
if !hash_dir.exists() { if !hash_dir.exists() {
fs::create_dir_all(&hash_dir)?; fs::create_dir_all(&hash_dir)?;
@ -57,7 +64,11 @@ pub fn extract_embedded_nebula_cert() -> Result<PathBuf, Box<dyn Error>> {
} }
let executable_postfix = if cfg!(windows) { ".exe" } else { "" }; let executable_postfix = if cfg!(windows) { ".exe" } else { "" };
let executable_name = format!("nebula-cert-{}{}", crate::nebula_cert_bin::NEBULA_CERT_VERSION, executable_postfix); let executable_name = format!(
"nebula-cert-{}{}",
crate::nebula_cert_bin::NEBULA_CERT_VERSION,
executable_postfix
);
let file_path = hash_dir.join(executable_name); let file_path = hash_dir.join(executable_name);
@ -101,4 +112,4 @@ pub fn run_embedded_nebula_cert(args: &[String]) -> Result<Child, Box<dyn Error>
debug!("Running {} with args {:?}", path.as_path().display(), args); debug!("Running {} with args {:?}", path.as_path().display(), args);
_setup_permissions(&path)?; _setup_permissions(&path)?;
Ok(Command::new(path).args(args).spawn()?) Ok(Command::new(path).args(args).spawn()?)
} }

View File

@ -14,16 +14,16 @@
// You should have received a copy of the GNU General Public License // You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>. // along with this program. If not, see <https://www.gnu.org/licenses/>.
pub mod embedded_nebula;
pub mod dirs;
pub mod util;
pub mod nebulaworker;
pub mod daemon;
pub mod config;
pub mod apiworker; pub mod apiworker;
pub mod socketworker; pub mod config;
pub mod daemon;
pub mod dirs;
pub mod embedded_nebula;
pub mod nebulaworker;
pub mod socketclient; pub mod socketclient;
pub mod socketworker;
pub mod timerworker; pub mod timerworker;
pub mod util;
pub mod nebula_bin { pub mod nebula_bin {
include!(concat!(env!("OUT_DIR"), "/nebula.bin.rs")); include!(concat!(env!("OUT_DIR"), "/nebula.bin.rs"));
@ -32,15 +32,14 @@ pub mod nebula_cert_bin {
include!(concat!(env!("OUT_DIR"), "/nebula_cert.bin.rs")); include!(concat!(env!("OUT_DIR"), "/nebula_cert.bin.rs"));
} }
use std::fs; use std::fs;
use clap::{Parser, ArgAction, Subcommand};
use log::{error, info};
use simple_logger::SimpleLogger;
use crate::config::load_config; use crate::config::load_config;
use crate::dirs::get_data_dir; use crate::dirs::get_data_dir;
use crate::embedded_nebula::{run_embedded_nebula, run_embedded_nebula_cert}; use crate::embedded_nebula::{run_embedded_nebula, run_embedded_nebula_cert};
use clap::{ArgAction, Parser, Subcommand};
use log::{error, info};
use simple_logger::SimpleLogger;
#[derive(Parser)] #[derive(Parser)]
#[command(author = "c0repwn3r", version, about, long_about = None)] #[command(author = "c0repwn3r", version, about, long_about = None)]
@ -52,7 +51,7 @@ struct Cli {
version: bool, version: bool,
#[command(subcommand)] #[command(subcommand)]
subcommand: Commands subcommand: Commands,
} }
#[derive(Subcommand)] #[derive(Subcommand)]
@ -60,14 +59,14 @@ enum Commands {
/// Run the `nebula` binary. This is useful if you want to do debugging with tfclient's internal nebula. /// Run the `nebula` binary. This is useful if you want to do debugging with tfclient's internal nebula.
RunNebula { RunNebula {
/// Arguments to pass to the `nebula` binary /// Arguments to pass to the `nebula` binary
#[clap(trailing_var_arg=true, allow_hyphen_values=true)] #[clap(trailing_var_arg = true, allow_hyphen_values = true)]
args: Vec<String> args: Vec<String>,
}, },
/// Run the `nebula-cert` binary. This is useful if you want to mess with certificates. Note: tfclient does not actually use nebula-cert for certificate operations, and instead uses trifid-pki internally /// Run the `nebula-cert` binary. This is useful if you want to mess with certificates. Note: tfclient does not actually use nebula-cert for certificate operations, and instead uses trifid-pki internally
RunNebulaCert { RunNebulaCert {
/// Arguments to pass to the `nebula-cert` binary /// Arguments to pass to the `nebula-cert` binary
#[clap(trailing_var_arg=true, allow_hyphen_values=true)] #[clap(trailing_var_arg = true, allow_hyphen_values = true)]
args: Vec<String> args: Vec<String>,
}, },
/// Clear any cached data that tfclient may have added /// Clear any cached data that tfclient may have added
ClearCache {}, ClearCache {},
@ -79,7 +78,7 @@ enum Commands {
name: String, name: String,
#[clap(short, long)] #[clap(short, long)]
/// Server to use for API calls. /// Server to use for API calls.
server: String server: String,
}, },
/// Enroll this host using a trifid-api enrollment code /// Enroll this host using a trifid-api enrollment code
@ -97,7 +96,7 @@ enum Commands {
#[clap(short, long, default_value = "tfclient")] #[clap(short, long, default_value = "tfclient")]
/// Service name specified on install /// Service name specified on install
name: String, name: String,
} },
} }
fn main() { fn main() {
@ -110,34 +109,28 @@ fn main() {
} }
match args.subcommand { match args.subcommand {
Commands::RunNebula { args } => { Commands::RunNebula { args } => match run_embedded_nebula(&args) {
match run_embedded_nebula(&args) { Ok(mut c) => match c.wait() {
Ok(mut c) => { Ok(stat) => match stat.code() {
match c.wait() { Some(code) => {
Ok(stat) => { if code != 0 {
match stat.code() { error!("Nebula process exited with nonzero status code {}", code);
Some(code) => {
if code != 0 {
error!("Nebula process exited with nonzero status code {}", code);
}
std::process::exit(code);
},
None => {
info!("Nebula process terminated by signal");
std::process::exit(0);
}
}
},
Err(e) => {
error!("Unable to wait for child to exit: {}", e);
std::process::exit(1);
} }
std::process::exit(code);
}
None => {
info!("Nebula process terminated by signal");
std::process::exit(0);
} }
}, },
Err(e) => { Err(e) => {
error!("Unable to start nebula binary: {}", e); error!("Unable to wait for child to exit: {}", e);
std::process::exit(1); std::process::exit(1);
} }
},
Err(e) => {
error!("Unable to start nebula binary: {}", e);
std::process::exit(1);
} }
}, },
Commands::ClearCache { .. } => { Commands::ClearCache { .. } => {
@ -159,37 +152,34 @@ fn main() {
info!("Removed all cached data."); info!("Removed all cached data.");
std::process::exit(0); std::process::exit(0);
}, }
Commands::RunNebulaCert { args } => { Commands::RunNebulaCert { args } => match run_embedded_nebula_cert(&args) {
match run_embedded_nebula_cert(&args) { Ok(mut c) => match c.wait() {
Ok(mut c) => { Ok(stat) => match stat.code() {
match c.wait() { Some(code) => {
Ok(stat) => { if code != 0 {
match stat.code() { error!(
Some(code) => { "nebula-cert process exited with nonzero status code {}",
if code != 0 { code
error!("nebula-cert process exited with nonzero status code {}", code); );
}
std::process::exit(code);
},
None => {
info!("nebula-cert process terminated by signal");
std::process::exit(0);
}
}
},
Err(e) => {
error!("Unable to wait for child to exit: {}", e);
std::process::exit(1);
} }
std::process::exit(code);
}
None => {
info!("nebula-cert process terminated by signal");
std::process::exit(0);
} }
}, },
Err(e) => { Err(e) => {
error!("Unable to start nebula-cert binary: {}", e); error!("Unable to wait for child to exit: {}", e);
std::process::exit(1); std::process::exit(1);
} }
},
Err(e) => {
error!("Unable to start nebula-cert binary: {}", e);
std::process::exit(1);
} }
} },
Commands::Run { name, server } => { Commands::Run { name, server } => {
daemon::daemon_main(name, server); daemon::daemon_main(name, server);
} }
@ -209,7 +199,7 @@ fn main() {
std::process::exit(1); std::process::exit(1);
} }
}; };
}, }
Commands::Update { name } => { Commands::Update { name } => {
info!("Loading config..."); info!("Loading config...");
let config = match load_config(&name) { let config = match load_config(&name) {
@ -231,5 +221,11 @@ fn main() {
} }
fn print_version() { fn print_version() {
println!("tfclient v{} linked to trifid-pki v{}, embedding nebula v{} and nebula-cert v{}", env!("CARGO_PKG_VERSION"), trifid_pki::TRIFID_PKI_VERSION, crate::nebula_bin::NEBULA_VERSION, crate::nebula_cert_bin::NEBULA_CERT_VERSION); println!(
} "tfclient v{} linked to trifid-pki v{}, embedding nebula v{} and nebula-cert v{}",
env!("CARGO_PKG_VERSION"),
trifid_pki::TRIFID_PKI_VERSION,
crate::nebula_bin::NEBULA_VERSION,
crate::nebula_cert_bin::NEBULA_CERT_VERSION
);
}

View File

@ -1,29 +1,34 @@
// Code to handle the nebula worker // Code to handle the nebula worker
use std::error::Error;
use std::fs;
use std::sync::mpsc::{Receiver, TryRecvError};
use std::time::{Duration, SystemTime};
use log::{debug, error, info};
use crate::config::{load_cdata, NebulaConfig, TFClientConfig}; use crate::config::{load_cdata, NebulaConfig, TFClientConfig};
use crate::daemon::ThreadMessageSender; use crate::daemon::ThreadMessageSender;
use crate::dirs::get_nebulaconfig_file; use crate::dirs::get_nebulaconfig_file;
use crate::embedded_nebula::run_embedded_nebula; use crate::embedded_nebula::run_embedded_nebula;
use log::{debug, error, info};
use std::error::Error;
use std::fs;
use std::sync::mpsc::Receiver;
use std::time::{Duration, SystemTime};
pub enum NebulaWorkerMessage { pub enum NebulaWorkerMessage {
Shutdown, Shutdown,
ConfigUpdated, ConfigUpdated,
WakeUp WakeUp,
} }
fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> { fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
if !get_nebulaconfig_file(instance).ok_or("Could not get config file location")?.exists() { if !get_nebulaconfig_file(instance)
.ok_or("Could not get config file location")?
.exists()
{
return Ok(()); // cant insert private key into a file that does not exist - BUT. we can gracefully handle nebula crashing - we cannot gracefully handle this fn failing return Ok(()); // cant insert private key into a file that does not exist - BUT. we can gracefully handle nebula crashing - we cannot gracefully handle this fn failing
} }
let cdata = load_cdata(instance)?; let cdata = load_cdata(instance)?;
let key = cdata.dh_privkey.ok_or("Missing private key")?; let key = cdata.dh_privkey.ok_or("Missing private key")?;
let config_str = fs::read_to_string(get_nebulaconfig_file(instance).ok_or("Could not get config file location")?)?; let config_str = fs::read_to_string(
get_nebulaconfig_file(instance).ok_or("Could not get config file location")?,
)?;
let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?; let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?;
config.pki.key = Some(String::from_utf8(key)?); config.pki.key = Some(String::from_utf8(key)?);
@ -31,12 +36,20 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
debug!("inserted private key into config: {:?}", config); debug!("inserted private key into config: {:?}", config);
let config_str = serde_yaml::to_string(&config)?; let config_str = serde_yaml::to_string(&config)?;
fs::write(get_nebulaconfig_file(instance).ok_or("Could not get config file location")?, config_str)?; fs::write(
get_nebulaconfig_file(instance).ok_or("Could not get config file location")?,
config_str,
)?;
Ok(()) Ok(())
} }
pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter: ThreadMessageSender, rx: Receiver<NebulaWorkerMessage>) { pub fn nebulaworker_main(
_config: TFClientConfig,
instance: String,
_transmitter: ThreadMessageSender,
rx: Receiver<NebulaWorkerMessage>,
) {
let _cdata = match load_cdata(&instance) { let _cdata = match load_cdata(&instance) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -50,7 +63,7 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
match insert_private_key(&instance) { match insert_private_key(&instance) {
Ok(_) => { Ok(_) => {
info!("config fixed (private-key embedded)"); info!("config fixed (private-key embedded)");
}, }
Err(e) => { Err(e) => {
error!("unable to fix config: {}", e); error!("unable to fix config: {}", e);
error!("nebula thread exiting with error"); error!("nebula thread exiting with error");
@ -58,7 +71,14 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
} }
} }
info!("starting nebula child..."); info!("starting nebula child...");
let mut child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) { let mut child = match run_embedded_nebula(&[
"-config".to_string(),
get_nebulaconfig_file(&instance)
.unwrap()
.to_str()
.unwrap()
.to_string(),
]) {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
error!("unable to start embedded nebula binary: {}", e); error!("unable to start embedded nebula binary: {}", e);
@ -75,7 +95,14 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
if let Ok(e) = child.try_wait() { if let Ok(e) = child.try_wait() {
if e.is_some() && SystemTime::now() > last_restart_time + Duration::from_secs(5) { if e.is_some() && SystemTime::now() > last_restart_time + Duration::from_secs(5) {
info!("nebula process has exited, restarting"); info!("nebula process has exited, restarting");
child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) { child = match run_embedded_nebula(&[
"-config".to_string(),
get_nebulaconfig_file(&instance)
.unwrap()
.to_str()
.unwrap()
.to_string(),
]) {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
error!("unable to start embedded nebula binary: {}", e); error!("unable to start embedded nebula binary: {}", e);
@ -88,59 +115,64 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
} }
} }
match rx.recv() { match rx.recv() {
Ok(msg) => { Ok(msg) => match msg {
match msg { NebulaWorkerMessage::WakeUp => {
NebulaWorkerMessage::WakeUp => { continue;
continue; }
}, NebulaWorkerMessage::Shutdown => {
NebulaWorkerMessage::Shutdown => { info!("recv on command socket: shutdown, stopping");
info!("recv on command socket: shutdown, stopping"); info!("shutting down nebula binary");
info!("shutting down nebula binary"); match child.kill() {
match child.kill() { Ok(_) => {
Ok(_) => { debug!("nebula process exited");
debug!("nebula process exited");
},
Err(e) => {
error!("nebula process already exited: {}", e);
}
} }
info!("nebula shut down"); Err(e) => {
break; error!("nebula process already exited: {}", e);
},
NebulaWorkerMessage::ConfigUpdated => {
info!("our configuration has been updated - restarting");
debug!("killing existing process");
match child.kill() {
Ok(_) => {
debug!("nebula process exited");
},
Err(e) => {
error!("nebula process already exited: {}", e);
}
} }
debug!("fixing config...");
match insert_private_key(&instance) {
Ok(_) => {
debug!("config fixed (private-key embedded)");
},
Err(e) => {
error!("unable to fix config: {}", e);
error!("nebula thread exiting with error");
return;
}
}
debug!("restarting nebula process");
child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) {
Ok(c) => c,
Err(e) => {
error!("unable to start embedded nebula binary: {}", e);
error!("nebula thread exiting with error");
return;
}
};
last_restart_time = SystemTime::now();
debug!("nebula process restarted");
} }
info!("nebula shut down");
break;
}
NebulaWorkerMessage::ConfigUpdated => {
info!("our configuration has been updated - restarting");
debug!("killing existing process");
match child.kill() {
Ok(_) => {
debug!("nebula process exited");
}
Err(e) => {
error!("nebula process already exited: {}", e);
}
}
debug!("fixing config...");
match insert_private_key(&instance) {
Ok(_) => {
debug!("config fixed (private-key embedded)");
}
Err(e) => {
error!("unable to fix config: {}", e);
error!("nebula thread exiting with error");
return;
}
}
debug!("restarting nebula process");
child = match run_embedded_nebula(&[
"-config".to_string(),
get_nebulaconfig_file(&instance)
.unwrap()
.to_str()
.unwrap()
.to_string(),
]) {
Ok(c) => c,
Err(e) => {
error!("unable to start embedded nebula binary: {}", e);
error!("nebula thread exiting with error");
return;
}
};
last_restart_time = SystemTime::now();
debug!("nebula process restarted");
} }
}, },
Err(e) => { Err(e) => {
@ -149,4 +181,4 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
} }
} }
} }
} }

View File

@ -1,13 +1,16 @@
use crate::config::TFClientConfig;
use crate::socketworker::{ctob, DisconnectReason, JsonMessage, JSON_API_VERSION};
use log::{error, info};
use std::error::Error; use std::error::Error;
use std::io::{BufRead, BufReader, Write}; use std::io::{BufRead, BufReader, Write};
use std::net::{IpAddr, SocketAddr, TcpStream}; use std::net::{IpAddr, SocketAddr, TcpStream};
use log::{error, info};
use crate::config::TFClientConfig;
use crate::socketworker::{ctob, DisconnectReason, JSON_API_VERSION, JsonMessage};
pub fn enroll(code: &str, config: &TFClientConfig) -> Result<(), Box<dyn Error>> { pub fn enroll(code: &str, config: &TFClientConfig) -> Result<(), Box<dyn Error>> {
info!("Connecting to local command socket..."); info!("Connecting to local command socket...");
let mut stream = TcpStream::connect(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?; let mut stream = TcpStream::connect(SocketAddr::new(
IpAddr::from([127, 0, 0, 1]),
config.listen_port,
))?;
let stream2 = stream.try_clone()?; let stream2 = stream.try_clone()?;
let mut reader = BufReader::new(&stream2); let mut reader = BufReader::new(&stream2);
@ -52,7 +55,10 @@ pub fn enroll(code: &str, config: &TFClientConfig) -> Result<(), Box<dyn Error>>
pub fn update(config: &TFClientConfig) -> Result<(), Box<dyn Error>> { pub fn update(config: &TFClientConfig) -> Result<(), Box<dyn Error>> {
info!("Connecting to local command socket..."); info!("Connecting to local command socket...");
let mut stream = TcpStream::connect(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?; let mut stream = TcpStream::connect(SocketAddr::new(
IpAddr::from([127, 0, 0, 1]),
config.listen_port,
))?;
let stream2 = stream.try_clone()?; let stream2 = stream.try_clone()?;
let mut reader = BufReader::new(&stream2); let mut reader = BufReader::new(&stream2);
@ -98,4 +104,4 @@ fn read_msg(reader: &mut BufReader<&TcpStream>) -> Result<JsonMessage, Box<dyn E
reader.read_line(&mut str)?; reader.read_line(&mut str)?;
let msg: JsonMessage = serde_json::from_str(&str)?; let msg: JsonMessage = serde_json::from_str(&str)?;
Ok(msg) Ok(msg)
} }

View File

@ -1,25 +1,30 @@
// Code to handle the nebula worker // Code to handle the nebula worker
use std::error::Error; use std::error::Error;
use std::{io, thread}; use std::io::{BufRead, BufReader, Write};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream}; use std::net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream};
use std::sync::mpsc::{Receiver, TryRecvError}; use std::sync::mpsc::Receiver;
use std::{io, thread};
use log::{debug, error, info, trace, warn};
use serde::{Deserialize, Serialize};
use crate::apiworker::APIWorkerMessage; use crate::apiworker::APIWorkerMessage;
use crate::config::{load_cdata, TFClientConfig}; use crate::config::{load_cdata, TFClientConfig};
use crate::daemon::ThreadMessageSender; use crate::daemon::ThreadMessageSender;
use crate::nebulaworker::NebulaWorkerMessage; use crate::nebulaworker::NebulaWorkerMessage;
use crate::timerworker::TimerWorkerMessage; use crate::timerworker::TimerWorkerMessage;
use log::{debug, error, info, trace, warn};
use serde::{Deserialize, Serialize};
pub enum SocketWorkerMessage { pub enum SocketWorkerMessage {
Shutdown, Shutdown,
WakeUp WakeUp,
} }
pub fn socketworker_main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSender, rx: Receiver<SocketWorkerMessage>) { pub fn socketworker_main(
config: TFClientConfig,
instance: String,
transmitter: ThreadMessageSender,
rx: Receiver<SocketWorkerMessage>,
) {
info!("socketworker_main called, entering realmain"); info!("socketworker_main called, entering realmain");
match _main(config, instance, transmitter, rx) { match _main(config, instance, transmitter, rx) {
Ok(_) => (), Ok(_) => (),
@ -29,8 +34,16 @@ pub fn socketworker_main(config: TFClientConfig, instance: String, transmitter:
}; };
} }
fn _main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSender, rx: Receiver<SocketWorkerMessage>) -> Result<(), Box<dyn Error>> { fn _main(
let listener = TcpListener::bind(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?; config: TFClientConfig,
instance: String,
transmitter: ThreadMessageSender,
rx: Receiver<SocketWorkerMessage>,
) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(SocketAddr::new(
IpAddr::from([127, 0, 0, 1]),
config.listen_port,
))?;
listener.set_nonblocking(true)?; listener.set_nonblocking(true)?;
loop { loop {
@ -47,21 +60,21 @@ fn _main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSen
} }
} }
}); });
}, }
Err(e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
Err(e) => { Err(e)?; } Err(e) => {
Err(e)?;
}
} }
match rx.recv() { match rx.recv() {
Ok(msg) => { Ok(msg) => match msg {
match msg { SocketWorkerMessage::Shutdown => {
SocketWorkerMessage::Shutdown => { info!("recv on command socket: shutdown, stopping");
info!("recv on command socket: shutdown, stopping"); break;
break; }
}, SocketWorkerMessage::WakeUp => {
SocketWorkerMessage::WakeUp => { continue;
continue;
}
} }
}, },
Err(e) => { Err(e) => {
@ -74,22 +87,27 @@ fn _main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSen
Ok(()) Ok(())
} }
fn handle_stream(stream: (TcpStream, SocketAddr), transmitter: ThreadMessageSender, config: TFClientConfig, instance: String) -> Result<(), io::Error> { fn handle_stream(
stream: (TcpStream, SocketAddr),
transmitter: ThreadMessageSender,
config: TFClientConfig,
instance: String,
) -> Result<(), io::Error> {
info!("Incoming client"); info!("Incoming client");
match handle_client(stream.0, transmitter, config, instance) { match handle_client(stream.0, transmitter, config, instance) {
Ok(()) => (), Ok(()) => (),
Err(e) if e.kind() == io::ErrorKind::TimedOut => { Err(e) if e.kind() == io::ErrorKind::TimedOut => {
warn!("Client timed out, connection aborted"); warn!("Client timed out, connection aborted");
}, }
Err(e) if e.kind() == io::ErrorKind::NotConnected => { Err(e) if e.kind() == io::ErrorKind::NotConnected => {
warn!("Client connection severed"); warn!("Client connection severed");
}, }
Err(e) if e.kind() == io::ErrorKind::BrokenPipe => { Err(e) if e.kind() == io::ErrorKind::BrokenPipe => {
warn!("Client connection returned error: broken pipe"); warn!("Client connection returned error: broken pipe");
}, }
Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => { Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => {
warn!("Client aborted connection"); warn!("Client aborted connection");
}, }
Err(e) => { Err(e) => {
error!("Error in client handler: {}", e); error!("Error in client handler: {}", e);
return Err(e); return Err(e);
@ -98,15 +116,18 @@ fn handle_stream(stream: (TcpStream, SocketAddr), transmitter: ThreadMessageSend
Ok(()) Ok(())
} }
fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender, config: TFClientConfig, instance: String) -> Result<(), io::Error> { fn handle_client(
stream: TcpStream,
transmitter: ThreadMessageSender,
_config: TFClientConfig,
instance: String,
) -> Result<(), io::Error> {
info!("Handling connection from {}", stream.peer_addr()?); info!("Handling connection from {}", stream.peer_addr()?);
let mut client = Client { let mut client = Client {
state: ClientState::WaitHello, state: ClientState::WaitHello,
reader: BufReader::new(&stream), reader: BufReader::new(&stream),
writer: BufWriter::new(&stream),
stream: &stream, stream: &stream,
config,
instance, instance,
}; };
@ -118,18 +139,14 @@ fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender, config: TF
trace!("recv {:?} from {}", command, client.stream.peer_addr()?); trace!("recv {:?} from {}", command, client.stream.peer_addr()?);
let should_disconnect; let should_disconnect = match client.state {
ClientState::WaitHello => waithello_handle(&mut client, &transmitter, command)?,
ClientState::SentHello => senthello_handle(&mut client, &transmitter, command)?,
};
match client.state { if should_disconnect {
ClientState::WaitHello => { break;
should_disconnect = waithello_handle(&mut client, &transmitter, command)?;
}
ClientState::SentHello => {
should_disconnect = senthello_handle(&mut client, &transmitter, command)?;
}
} }
if should_disconnect { break; }
} }
// Gracefully close the connection // Gracefully close the connection
@ -141,13 +158,15 @@ fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender, config: TF
struct Client<'a> { struct Client<'a> {
state: ClientState, state: ClientState,
reader: BufReader<&'a TcpStream>, reader: BufReader<&'a TcpStream>,
writer: BufWriter<&'a TcpStream>,
stream: &'a TcpStream, stream: &'a TcpStream,
config: TFClientConfig, instance: String,
instance: String
} }
fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, command: JsonMessage) -> Result<bool, io::Error> { fn waithello_handle(
client: &mut Client,
_transmitter: &ThreadMessageSender,
command: JsonMessage,
) -> Result<bool, io::Error> {
trace!("state: WaitHello, handing with waithello_handle"); trace!("state: WaitHello, handing with waithello_handle");
let mut should_disconnect = false; let mut should_disconnect = false;
@ -158,20 +177,20 @@ fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, com
client.stream.write_all(&ctob(JsonMessage::Goodbye { client.stream.write_all(&ctob(JsonMessage::Goodbye {
reason: DisconnectReason::UnsupportedVersion { reason: DisconnectReason::UnsupportedVersion {
expected: JSON_API_VERSION, expected: JSON_API_VERSION,
got: version got: version,
} },
}))?; }))?;
} }
client.stream.write_all(&ctob(JsonMessage::Hello { client.stream.write_all(&ctob(JsonMessage::Hello {
version: JSON_API_VERSION version: JSON_API_VERSION,
}))?; }))?;
client.state = ClientState::SentHello; client.state = ClientState::SentHello;
trace!("setting state to SentHello"); trace!("setting state to SentHello");
}, }
JsonMessage::Goodbye { reason } => { JsonMessage::Goodbye { reason } => {
info!("Client sent disconnect: {:?}", reason); info!("Client sent disconnect: {:?}", reason);
should_disconnect = true; should_disconnect = true;
}, }
_ => { _ => {
debug!("message type unexpected in WaitHello state"); debug!("message type unexpected in WaitHello state");
should_disconnect = true; should_disconnect = true;
@ -184,7 +203,11 @@ fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, com
Ok(should_disconnect) Ok(should_disconnect)
} }
fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, command: JsonMessage) -> Result<bool, io::Error> { fn senthello_handle(
client: &mut Client,
transmitter: &ThreadMessageSender,
command: JsonMessage,
) -> Result<bool, io::Error> {
trace!("state: SentHello, handing with senthello_handle"); trace!("state: SentHello, handing with senthello_handle");
let mut should_disconnect = false; let mut should_disconnect = false;
@ -192,14 +215,20 @@ fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, comm
JsonMessage::Goodbye { reason } => { JsonMessage::Goodbye { reason } => {
info!("Client sent disconnect: {:?}", reason); info!("Client sent disconnect: {:?}", reason);
should_disconnect = true; should_disconnect = true;
}, }
JsonMessage::Shutdown {} => { JsonMessage::Shutdown {} => {
info!("Requested to shutdown by local control socket. Sending shutdown message to threads"); info!("Requested to shutdown by local control socket. Sending shutdown message to threads");
match transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) { match transmitter
.nebula_thread
.send(NebulaWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to nebula worker thread: {}", e); error!(
"Error sending shutdown message to nebula worker thread: {}",
e
);
} }
} }
match transmitter.api_thread.send(APIWorkerMessage::Shutdown) { match transmitter.api_thread.send(APIWorkerMessage::Shutdown) {
@ -208,19 +237,28 @@ fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, comm
error!("Error sending shutdown message to api worker thread: {}", e); error!("Error sending shutdown message to api worker thread: {}", e);
} }
} }
match transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) { match transmitter
.socket_thread
.send(SocketWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to socket worker thread: {}", e); error!(
"Error sending shutdown message to socket worker thread: {}",
e
);
} }
} }
match transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) { match transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to timer worker thread: {}", e); error!(
"Error sending shutdown message to timer worker thread: {}",
e
);
} }
} }
}, }
JsonMessage::GetHostID {} => { JsonMessage::GetHostID {} => {
let data = match load_cdata(&client.instance) { let data = match load_cdata(&client.instance) {
@ -232,20 +270,26 @@ fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, comm
}; };
client.stream.write_all(&ctob(JsonMessage::HostID { client.stream.write_all(&ctob(JsonMessage::HostID {
has_id: data.creds.is_some(), has_id: data.creds.is_some(),
id: data.creds.map(|c| c.host_id) id: data.creds.map(|c| c.host_id),
}))?; }))?;
}, }
JsonMessage::Enroll { code } => { JsonMessage::Enroll { code } => {
info!("Client sent enroll with code {}", code); info!("Client sent enroll with code {}", code);
info!("Sending enroll request to apiworker"); info!("Sending enroll request to apiworker");
transmitter.api_thread.send(APIWorkerMessage::Enroll { code }).unwrap(); transmitter
}, .api_thread
.send(APIWorkerMessage::Enroll { code })
.unwrap();
}
JsonMessage::Update {} => { JsonMessage::Update {} => {
info!("Client sent update request."); info!("Client sent update request.");
info!("Telling apiworker to update configuration"); info!("Telling apiworker to update configuration");
transmitter.api_thread.send(APIWorkerMessage::Update).unwrap(); transmitter
.api_thread
.send(APIWorkerMessage::Update)
.unwrap();
} }
_ => { _ => {
@ -267,7 +311,7 @@ pub fn ctob(command: JsonMessage) -> Vec<u8> {
enum ClientState { enum ClientState {
WaitHello, WaitHello,
SentHello SentHello,
} }
pub const JSON_API_VERSION: i32 = 1; pub const JSON_API_VERSION: i32 = 1;
@ -276,28 +320,19 @@ pub const JSON_API_VERSION: i32 = 1;
#[serde(tag = "method")] #[serde(tag = "method")]
pub enum JsonMessage { pub enum JsonMessage {
#[serde(rename = "hello")] #[serde(rename = "hello")]
Hello { Hello { version: i32 },
version: i32
},
#[serde(rename = "goodbye")] #[serde(rename = "goodbye")]
Goodbye { Goodbye { reason: DisconnectReason },
reason: DisconnectReason
},
#[serde(rename = "shutdown")] #[serde(rename = "shutdown")]
Shutdown {}, Shutdown {},
#[serde(rename = "get_host_id")] #[serde(rename = "get_host_id")]
GetHostID {}, GetHostID {},
#[serde(rename = "host_id")] #[serde(rename = "host_id")]
HostID { HostID { has_id: bool, id: Option<String> },
has_id: bool,
id: Option<String>
},
#[serde(rename = "enroll")] #[serde(rename = "enroll")]
Enroll { Enroll { code: String },
code: String
},
#[serde(rename = "update")] #[serde(rename = "update")]
Update {} Update {},
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@ -308,5 +343,5 @@ pub enum DisconnectReason {
#[serde(rename = "unexpected_message_type")] #[serde(rename = "unexpected_message_type")]
UnexpectedMessageType, UnexpectedMessageType,
#[serde(rename = "done")] #[serde(rename = "done")]
Done Done,
} }

View File

@ -1,15 +1,15 @@
use std::ops::Add;
use std::sync::mpsc::{Receiver, TryRecvError};
use std::thread;
use std::time::{Duration, SystemTime};
use log::{error, info};
use crate::apiworker::APIWorkerMessage; use crate::apiworker::APIWorkerMessage;
use crate::daemon::ThreadMessageSender; use crate::daemon::ThreadMessageSender;
use crate::nebulaworker::NebulaWorkerMessage; use crate::nebulaworker::NebulaWorkerMessage;
use crate::socketworker::SocketWorkerMessage; use crate::socketworker::SocketWorkerMessage;
use log::{error, info};
use std::ops::Add;
use std::sync::mpsc::{Receiver, TryRecvError};
use std::thread;
use std::time::{Duration, SystemTime};
pub enum TimerWorkerMessage { pub enum TimerWorkerMessage {
Shutdown Shutdown,
} }
pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>) { pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>) {
@ -19,23 +19,19 @@ pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>) {
thread::sleep(Duration::from_secs(10)); thread::sleep(Duration::from_secs(10));
match rx.try_recv() { match rx.try_recv() {
Ok(msg) => { Ok(msg) => match msg {
match msg { TimerWorkerMessage::Shutdown => {
TimerWorkerMessage::Shutdown => { info!("recv on command socket: shutdown, stopping");
info!("recv on command socket: shutdown, stopping"); break;
break;
}
} }
}, },
Err(e) => { Err(e) => match e {
match e { TryRecvError::Empty => {}
TryRecvError::Empty => {} TryRecvError::Disconnected => {
TryRecvError::Disconnected => { error!("timerworker command socket disconnected, shutting down to prevent orphaning");
error!("timerworker command socket disconnected, shutting down to prevent orphaning"); break;
break;
}
} }
} },
} }
if SystemTime::now().gt(&api_reload_timer) { if SystemTime::now().gt(&api_reload_timer) {
@ -52,15 +48,21 @@ pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>) {
match tx.nebula_thread.send(NebulaWorkerMessage::WakeUp) { match tx.nebula_thread.send(NebulaWorkerMessage::WakeUp) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending wakeup message to nebula worker thread: {}", e); error!(
"Error sending wakeup message to nebula worker thread: {}",
e
);
} }
} }
match tx.socket_thread.send(SocketWorkerMessage::WakeUp) { match tx.socket_thread.send(SocketWorkerMessage::WakeUp) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending wakeup message to socket worker thread: {}", e); error!(
"Error sending wakeup message to socket worker thread: {}",
e
);
} }
} }
} }
} }

View File

@ -1,6 +1,6 @@
use log::{error, warn}; use log::{error, warn};
use sha2::Sha256;
use sha2::Digest; use sha2::Digest;
use sha2::Sha256;
use url::Url; use url::Url;
pub fn sha256(bytes: &[u8]) -> String { pub fn sha256(bytes: &[u8]) -> String {
@ -11,7 +11,7 @@ pub fn sha256(bytes: &[u8]) -> String {
} }
pub fn check_server_url(server: &str) { pub fn check_server_url(server: &str) {
let api_base = match Url::parse(&server) { let api_base = match Url::parse(server) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
error!("Invalid server url `{}`: {}", server, e); error!("Invalid server url `{}`: {}", server, e);
@ -19,11 +19,16 @@ pub fn check_server_url(server: &str) {
} }
}; };
match api_base.scheme() { match api_base.scheme() {
"http" => { warn!("HTTP api urls are not reccomended. Please switch to HTTPS if possible.") }, "http" => {
warn!("HTTP api urls are not reccomended. Please switch to HTTPS if possible.")
}
"https" => (), "https" => (),
_ => { _ => {
error!("Unsupported protocol `{}` (expected one of http, https)", api_base.scheme()); error!(
"Unsupported protocol `{}` (expected one of http, https)",
api_base.scheme()
);
std::process::exit(1); std::process::exit(1);
} }
} }
} }

View File

@ -12,8 +12,9 @@ actix-request-identifier = "4" # Web framework
serde = { version = "1", features = ["derive"] } # Serialization and deserialization serde = { version = "1", features = ["derive"] } # Serialization and deserialization
serde_json = "1.0.95" # Serialization and deserialization (cursors) serde_json = "1.0.95" # Serialization and deserialization (cursors)
once_cell = "1" # Config once_cell = "1" # Config
toml = "0.7" # Config / Serialization and deserialization toml = "0.7" # Config / Serialization and deserialization
serde_yaml = "0.9.21" # Config / Serialization and deserialization
log = "0.4" # Logging log = "0.4" # Logging
simple_logger = "4" # Logging simple_logger = "4" # Logging
@ -28,5 +29,9 @@ totp-rs = { version = "5.0.1", features = ["gen_secret", "otpauth"] } # Misc.
base64 = "0.21.0" # Misc. base64 = "0.21.0" # Misc.
chrono = "0.4.24" # Misc. chrono = "0.4.24" # Misc.
trifid-pki = { version = "0.1.9" } # Cryptography trifid-pki = { version = "0.1.9", features = ["serde_derive"] } # Cryptography
aes-gcm = "0.10.1" # Cryptography aes-gcm = "0.10.1" # Cryptography
ed25519-dalek = "2.0.0-rc.2" # Cryptography
dnapi-rs = "0.1.9" # API message types
ipnet = "2.7.2" # API message types

View File

@ -120,4 +120,10 @@ url = "your-database-url-here"
# ------- WARNING ------- # ------- WARNING -------
# Do not change this value in a production instance. It will make existing data inaccessible until changed back. # Do not change this value in a production instance. It will make existing data inaccessible until changed back.
# ------- WARNING ------- # ------- WARNING -------
data-key = "edd600bcebea461381ea23791b6967c8667e12827ac8b94dc022f189a5dc59a2" data-key = "edd600bcebea461381ea23791b6967c8667e12827ac8b94dc022f189a5dc59a2"
# The data directory used for storing keys, configuration, signing keys, etc. Must be writable by this instance.
# This directory will be used to store very sensitive data - protect it like a password! It should be writable by
# this instance and ONLY this instance.
# Do not modify any files in this directory manually unless directed to do so by trifid.
local_keystore_directory = "./trifid_data"

View File

@ -0,0 +1,338 @@
use std::collections::HashMap;
use std::error::Error;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use std::time::{Duration, SystemTime};
use actix_web::web::Data;
use crate::config::{NebulaConfig, NebulaConfigCipher, NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy, NebulaConfigRelay, NebulaConfigTun, CONFIG, NebulaConfigFirewall, NebulaConfigFirewallRule};
use crate::crypto::{decrypt_with_nonce, encrypt_with_nonce, get_cipher_from_config};
use crate::AppState;
use ed25519_dalek::SigningKey;
use ipnet::Ipv4Net;
use log::{debug, error};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use trifid_api_entities::entity::{firewall_rule, host, host_config_override, host_static_address, network, organization, signing_ca};
use trifid_pki::cert::{
deserialize_ed25519_private, deserialize_nebula_certificate_from_pem, NebulaCertificate,
NebulaCertificateDetails,
};
pub struct CodegenRequiredInfo {
pub host: host::Model,
pub host_static_addresses: HashMap<String, Vec<SocketAddrV4>>,
pub host_config_overrides: Vec<host_config_override::Model>,
pub network: network::Model,
pub organization: organization::Model,
pub dh_pubkey: Vec<u8>,
pub ca: signing_ca::Model,
pub other_cas: Vec<signing_ca::Model>,
pub relay_ips: Vec<Ipv4Addr>,
pub lighthouse_ips: Vec<Ipv4Addr>,
pub blocked_hosts: Vec<String>,
pub firewall_rules: Vec<NebulaConfigFirewallRule>
}
pub async fn generate_config(
data: &Data<AppState>,
info: &CodegenRequiredInfo,
) -> Result<(NebulaConfig, NebulaCertificate), Box<dyn Error>> {
debug!("chk: deserialize CA cert {:x?}", hex::decode(&info.ca.cert)?);
// decode the CA data
let ca_cert = deserialize_nebula_certificate_from_pem(&hex::decode(&info.ca.cert)?)?;
// generate the client's new cert
let mut cert = NebulaCertificate {
details: NebulaCertificateDetails {
name: info.host.name.clone(),
ips: vec![Ipv4Net::new(
Ipv4Addr::from_str(&info.host.ip).unwrap(),
Ipv4Net::from_str(&info.network.cidr).unwrap().prefix_len(),
)
.unwrap()],
subnets: vec![],
groups: vec![
format!("role:{}", info.host.role)
],
not_before: SystemTime::now(),
not_after: SystemTime::now() + Duration::from_secs(CONFIG.crypto.certs_expiry_time),
public_key: info.dh_pubkey.clone().try_into().unwrap(),
is_ca: false,
issuer: ca_cert.sha256sum()?,
},
signature: vec![],
};
// decrypt the private key
let private_pem = decrypt_with_nonce(
&hex::decode(&info.ca.key)?,
hex::decode(&info.ca.nonce)?.try_into().unwrap(),
&get_cipher_from_config(&CONFIG)?,
)
.map_err(|_| "Encryption error")?;
let private_key = deserialize_ed25519_private(&private_pem)?;
let signing_key = SigningKey::from_keypair_bytes(&private_key.try_into().unwrap()).unwrap();
cert.sign(&signing_key)?;
// cas
let mut cas = String::new();
for ca in &info.other_cas {
cas += &String::from_utf8(hex::decode(&ca.cert)?)?;
}
// blocked hosts
let mut blocked_hosts_fingerprints = vec![];
for host in &info.blocked_hosts {
if let Some(host) = data.keystore.hosts.iter().find(|u| &u.id == host) {
for cert in &host.certs {
blocked_hosts_fingerprints.push(cert.cert.sha256sum()?);
}
}
}
let nebula_config = NebulaConfig {
pki: NebulaConfigPki {
ca: cas,
cert: String::from_utf8(cert.serialize_to_pem()?)?,
key: None,
blocklist: blocked_hosts_fingerprints,
disconnect_invalid: true,
},
static_host_map: info
.host_static_addresses
.iter()
.map(|(u, addrs)| (Ipv4Addr::from_str(u).unwrap(), addrs.clone()))
.collect(),
lighthouse: match info.host.is_lighthouse {
true => Some(NebulaConfigLighthouse {
am_lighthouse: true,
serve_dns: false,
dns: None,
interval: 60,
hosts: vec![],
remote_allow_list: HashMap::new(),
local_allow_list: HashMap::new(),
}),
false => Some(NebulaConfigLighthouse {
am_lighthouse: false,
serve_dns: false,
dns: None,
interval: 60,
hosts: info.lighthouse_ips.to_vec(),
remote_allow_list: HashMap::new(),
local_allow_list: HashMap::new(),
}),
},
listen: match info.host.is_lighthouse || info.host.is_relay {
true => Some(NebulaConfigListen {
host: "[::]".to_string(),
port: info.host.listen_port as u16,
batch: 64,
read_buffer: Some(10485760),
write_buffer: Some(10485760),
}),
false => Some(NebulaConfigListen {
host: "[::]".to_string(),
port: 0u16,
batch: 64,
read_buffer: Some(10485760),
write_buffer: Some(10485760),
}),
},
punchy: Some(NebulaConfigPunchy {
punch: true,
respond: true,
delay: "".to_string(),
}),
cipher: NebulaConfigCipher::Aes,
preferred_ranges: vec![],
relay: Some(NebulaConfigRelay {
relays: info.relay_ips.to_vec(),
am_relay: info.host.is_relay,
use_relays: true,
}),
tun: Some(NebulaConfigTun {
disabled: false,
dev: Some("trifid1".to_string()),
drop_local_broadcast: true,
drop_multicast: true,
tx_queue: 500,
mtu: 1300,
routes: vec![],
unsafe_routes: vec![],
}),
logging: None,
sshd: None,
firewall: Some(NebulaConfigFirewall {
conntrack: None,
inbound: Some(info.firewall_rules.clone()),
outbound: Some(vec![
NebulaConfigFirewallRule {
port: Some("any".to_string()),
proto: Some("any".to_string()),
ca_name: None,
ca_sha: None,
host: Some("any".to_string()),
group: None,
groups: None,
cidr: None,
}
]),
}),
routines: 0,
stats: None,
local_range: None,
};
Ok((nebula_config, cert))
}
pub async fn collect_info<'a>(
db: &'a Data<AppState>,
host: &'a str,
dh_pubkey: &'a [u8],
) -> Result<CodegenRequiredInfo, Box<dyn Error>> {
// load host info
let host = trifid_api_entities::entity::host::Entity::find()
.filter(host::Column::Id.eq(host))
.one(&db.conn)
.await?;
let host = match host {
Some(host) => host,
None => return Err("Host does not exist".into()),
};
let host_config_overrides = trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Id.eq(&host.id))
.all(&db.conn)
.await?;
let _host_static_addresses = trifid_api_entities::entity::host_static_address::Entity::find()
.filter(host_static_address::Column::Id.eq(&host.id))
.all(&db.conn)
.await?;
// load network info
let network = trifid_api_entities::entity::network::Entity::find()
.filter(network::Column::Id.eq(&host.network))
.one(&db.conn)
.await?;
let network = match network {
Some(network) => network,
None => {
return Err("Network does not exist".into());
}
};
// get all lighthouses and relays and get all of their static addresses, and get internal addresses of relays
let mut host_x_static_addresses = HashMap::new();
let mut relays = vec![];
let mut lighthouses = vec![];
let mut blocked_hosts = vec![];
let hosts = trifid_api_entities::entity::host::Entity::find()
.filter(host::Column::Network.eq(&network.id))
.filter(host::Column::IsRelay.eq(true))
.filter(host::Column::IsLighthouse.eq(true))
.all(&db.conn)
.await?;
for host in hosts {
if host.is_relay {
relays.push(Ipv4Addr::from_str(&host.ip).unwrap());
} else if host.is_lighthouse {
lighthouses.push(Ipv4Addr::from_str(&host.ip).unwrap());
}
if host.is_blocked {
blocked_hosts.push(host.id.clone());
}
let static_addresses = trifid_api_entities::entity::host_static_address::Entity::find()
.filter(host_static_address::Column::Host.eq(host.id))
.all(&db.conn)
.await?;
let static_addresses: Vec<SocketAddrV4> = static_addresses
.iter()
.map(|u| SocketAddrV4::from_str(&u.address).unwrap())
.collect();
host_x_static_addresses.insert(host.ip.clone(), static_addresses);
}
// load org info
let org = trifid_api_entities::entity::organization::Entity::find()
.filter(organization::Column::Id.eq(&network.organization))
.one(&db.conn)
.await?;
let org = match org {
Some(org) => org,
None => {
return Err("Organization does not exist".into());
}
};
// get the CA that is closest to expiry, but *not* expired
let available_cas = trifid_api_entities::entity::signing_ca::Entity::find()
.filter(signing_ca::Column::Organization.eq(&org.id))
.all(&db.conn)
.await?;
let mut best_ca: Option<signing_ca::Model> = None;
let mut all_cas = vec![];
for ca in available_cas {
if let Some(existing_best) = &best_ca {
if ca.expires < existing_best.expires {
best_ca = Some(ca.clone());
}
} else {
best_ca = Some(ca.clone());
}
all_cas.push(ca);
}
if best_ca.is_none() {
error!(
"!!! NO AVAILABLE CAS !!! while trying to sign cert for {}",
org.id
);
return Err("No signing CAs available".into());
}
let best_ca = best_ca.unwrap();
// pull our role's firewall rules
let firewall_rules = trifid_api_entities::entity::firewall_rule::Entity::find().filter(firewall_rule::Column::Role.eq(&host.id)).all(&db.conn).await?;
let firewall_rules = firewall_rules.iter().map(|u| {
NebulaConfigFirewallRule {
port: Some(if u.port_range_from == 0 && u.port_range_to == 65535 { "any".to_string() } else { format!("{}-{}", u.port_range_from, u.port_range_to) }),
proto: Some(u.protocol.clone()),
ca_name: None,
ca_sha: None,
host: if u.allowed_role_id.is_some() { None } else { Some("any".to_string()) },
groups: if u.allowed_role_id.is_some() { Some(vec![format!("role:{}", u.allowed_role_id.clone().unwrap())])} else { None },
group: None,
cidr: None,
}
}).collect();
Ok(CodegenRequiredInfo {
host,
host_static_addresses: host_x_static_addresses,
host_config_overrides,
network,
organization: org,
dh_pubkey: dh_pubkey.to_vec(),
ca: best_ca,
other_cas: all_cas,
relay_ips: relays,
lighthouse_ips: lighthouses,
blocked_hosts,
firewall_rules
})
}

View File

@ -14,11 +14,14 @@
// You should have received a copy of the GNU General Public License // You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>. // along with this program. If not, see <https://www.gnu.org/licenses/>.
use ipnet::{IpNet, Ipv4Net};
use log::error; use log::error;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs; use std::fs;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::path::PathBuf;
pub static CONFIG: Lazy<TrifidConfig> = Lazy::new(|| { pub static CONFIG: Lazy<TrifidConfig> = Lazy::new(|| {
let config_str = match fs::read_to_string("/etc/trifid/config.toml") { let config_str = match fs::read_to_string("/etc/trifid/config.toml") {
@ -88,6 +91,9 @@ pub struct TrifidConfigTokens {
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct TrifidConfigCryptography { pub struct TrifidConfigCryptography {
pub data_encryption_key: String, pub data_encryption_key: String,
pub local_keystore_directory: PathBuf,
#[serde(default = "certs_expiry_time")]
pub certs_expiry_time: u64,
} }
fn max_connections_default() -> u32 { fn max_connections_default() -> u32 {
@ -120,3 +126,534 @@ fn mfa_tokens_expiry_time() -> u64 {
fn enrollment_tokens_expiry_time() -> u64 { fn enrollment_tokens_expiry_time() -> u64 {
600 600
} // 10 minutes } // 10 minutes
fn certs_expiry_time() -> u64 {
3600 * 24 * 31 * 12 // 1 year
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfig {
pub pki: NebulaConfigPki,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub static_host_map: HashMap<Ipv4Addr, Vec<SocketAddrV4>>,
#[serde(skip_serializing_if = "is_none")]
pub lighthouse: Option<NebulaConfigLighthouse>,
#[serde(skip_serializing_if = "is_none")]
pub listen: Option<NebulaConfigListen>,
#[serde(skip_serializing_if = "is_none")]
pub punchy: Option<NebulaConfigPunchy>,
#[serde(default = "cipher_aes")]
#[serde(skip_serializing_if = "is_cipher_aes")]
pub cipher: NebulaConfigCipher,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub preferred_ranges: Vec<IpNet>,
#[serde(skip_serializing_if = "is_none")]
pub relay: Option<NebulaConfigRelay>,
#[serde(skip_serializing_if = "is_none")]
pub tun: Option<NebulaConfigTun>,
#[serde(skip_serializing_if = "is_none")]
pub logging: Option<NebulaConfigLogging>,
#[serde(skip_serializing_if = "is_none")]
pub sshd: Option<NebulaConfigSshd>,
#[serde(skip_serializing_if = "is_none")]
pub firewall: Option<NebulaConfigFirewall>,
#[serde(default = "u64_1")]
#[serde(skip_serializing_if = "is_u64_1")]
pub routines: u64,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub stats: Option<NebulaConfigStats>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub local_range: Option<Ipv4Net>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigPki {
pub ca: String,
pub cert: String,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub key: Option<String>,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub blocklist: Vec<String>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disconnect_invalid: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLighthouse {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub am_lighthouse: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub serve_dns: bool,
#[serde(skip_serializing_if = "is_none")]
pub dns: Option<NebulaConfigLighthouseDns>,
#[serde(default = "u32_10")]
#[serde(skip_serializing_if = "is_u32_10")]
pub interval: u32,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub hosts: Vec<Ipv4Addr>,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub remote_allow_list: HashMap<Ipv4Net, bool>,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub local_allow_list: HashMap<Ipv4Net, bool>, // `interfaces` is not supported
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLighthouseDns {
#[serde(default = "string_empty")]
#[serde(skip_serializing_if = "is_string_empty")]
pub host: String,
#[serde(default = "u16_53")]
#[serde(skip_serializing_if = "is_u16_53")]
pub port: u16,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigListen {
#[serde(default = "string_empty")]
#[serde(skip_serializing_if = "is_string_empty")]
pub host: String,
#[serde(default = "u16_0")]
#[serde(skip_serializing_if = "is_u16_0")]
pub port: u16,
#[serde(default = "u32_64")]
#[serde(skip_serializing_if = "is_u32_64")]
pub batch: u32,
#[serde(skip_serializing_if = "is_none")]
pub read_buffer: Option<u32>,
#[serde(skip_serializing_if = "is_none")]
pub write_buffer: Option<u32>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigPunchy {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub punch: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub respond: bool,
#[serde(default = "string_1s")]
#[serde(skip_serializing_if = "is_string_1s")]
pub delay: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigCipher {
#[serde(rename = "aes")]
Aes,
#[serde(rename = "chachapoly")]
ChaChaPoly,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigRelay {
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub relays: Vec<Ipv4Addr>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub am_relay: bool,
#[serde(default = "bool_true")]
#[serde(skip_serializing_if = "is_bool_true")]
pub use_relays: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTun {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disabled: bool,
#[serde(skip_serializing_if = "is_none")]
pub dev: Option<String>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub drop_local_broadcast: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub drop_multicast: bool,
#[serde(default = "u64_500")]
#[serde(skip_serializing_if = "is_u64_500")]
pub tx_queue: u64,
#[serde(default = "u64_1300")]
#[serde(skip_serializing_if = "is_u64_1300")]
pub mtu: u64,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub routes: Vec<NebulaConfigTunRouteOverride>,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub unsafe_routes: Vec<NebulaConfigTunUnsafeRoute>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunRouteOverride {
pub mtu: u64,
pub route: Ipv4Net,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunUnsafeRoute {
pub route: Ipv4Net,
pub via: Ipv4Addr,
#[serde(default = "u64_1300")]
#[serde(skip_serializing_if = "is_u64_1300")]
pub mtu: u64,
#[serde(default = "i64_100")]
#[serde(skip_serializing_if = "is_i64_100")]
pub metric: i64,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLogging {
#[serde(default = "loglevel_info")]
#[serde(skip_serializing_if = "is_loglevel_info")]
pub level: NebulaConfigLoggingLevel,
#[serde(default = "format_text")]
#[serde(skip_serializing_if = "is_format_text")]
pub format: NebulaConfigLoggingFormat,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disable_timestamp: bool,
#[serde(default = "timestamp")]
#[serde(skip_serializing_if = "is_timestamp")]
pub timestamp_format: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigLoggingLevel {
#[serde(rename = "panic")]
Panic,
#[serde(rename = "fatal")]
Fatal,
#[serde(rename = "error")]
Error,
#[serde(rename = "warning")]
Warning,
#[serde(rename = "info")]
Info,
#[serde(rename = "debug")]
Debug,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigLoggingFormat {
#[serde(rename = "json")]
Json,
#[serde(rename = "text")]
Text,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigSshd {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub enabled: bool,
pub listen: SocketAddrV4,
pub host_key: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub authorized_users: Vec<NebulaConfigSshdAuthorizedUser>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigSshdAuthorizedUser {
pub user: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub keys: Vec<String>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "type")]
pub enum NebulaConfigStats {
#[serde(rename = "graphite")]
Graphite(NebulaConfigStatsGraphite),
#[serde(rename = "prometheus")]
Prometheus(NebulaConfigStatsPrometheus),
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigStatsGraphite {
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub prefix: String,
#[serde(default = "protocol_tcp")]
#[serde(skip_serializing_if = "is_protocol_tcp")]
pub protocol: NebulaConfigStatsGraphiteProtocol,
pub host: SocketAddrV4,
pub interval: String,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigStatsGraphiteProtocol {
#[serde(rename = "tcp")]
Tcp,
#[serde(rename = "udp")]
Udp,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigStatsPrometheus {
pub listen: String,
pub path: String,
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub namespace: String,
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub subsystem: String,
pub interval: String,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewall {
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub conntrack: Option<NebulaConfigFirewallConntrack>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub inbound: Option<Vec<NebulaConfigFirewallRule>>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub outbound: Option<Vec<NebulaConfigFirewallRule>>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewallConntrack {
#[serde(default = "string_12m")]
#[serde(skip_serializing_if = "is_string_12m")]
pub tcp_timeout: String,
#[serde(default = "string_3m")]
#[serde(skip_serializing_if = "is_string_3m")]
pub udp_timeout: String,
#[serde(default = "string_10m")]
#[serde(skip_serializing_if = "is_string_10m")]
pub default_timeout: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewallRule {
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub port: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub proto: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub ca_name: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub ca_sha: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub host: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub group: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub groups: Option<Vec<String>>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub cidr: Option<String>,
}
// Default values for serde
fn string_12m() -> String {
"12m".to_string()
}
fn is_string_12m(s: &str) -> bool {
s == "12m"
}
fn string_3m() -> String {
"3m".to_string()
}
fn is_string_3m(s: &str) -> bool {
s == "3m"
}
fn string_10m() -> String {
"10m".to_string()
}
fn is_string_10m(s: &str) -> bool {
s == "10m"
}
fn empty_vec<T>() -> Vec<T> {
vec![]
}
fn is_empty_vec<T>(v: &Vec<T>) -> bool {
v.is_empty()
}
fn empty_hashmap<A, B>() -> HashMap<A, B> {
HashMap::new()
}
fn is_empty_hashmap<A, B>(h: &HashMap<A, B>) -> bool {
h.is_empty()
}
fn bool_false() -> bool {
false
}
fn is_bool_false(b: &bool) -> bool {
!*b
}
fn bool_true() -> bool {
true
}
fn is_bool_true(b: &bool) -> bool {
*b
}
fn u16_53() -> u16 {
53
}
fn is_u16_53(u: &u16) -> bool {
*u == 53
}
fn u32_10() -> u32 {
10
}
fn is_u32_10(u: &u32) -> bool {
*u == 10
}
fn u16_0() -> u16 {
0
}
fn is_u16_0(u: &u16) -> bool {
*u == 0
}
fn u32_64() -> u32 {
64
}
fn is_u32_64(u: &u32) -> bool {
*u == 64
}
fn string_1s() -> String {
"1s".to_string()
}
fn is_string_1s(s: &str) -> bool {
s == "1s"
}
fn cipher_aes() -> NebulaConfigCipher {
NebulaConfigCipher::Aes
}
fn is_cipher_aes(c: &NebulaConfigCipher) -> bool {
matches!(c, NebulaConfigCipher::Aes)
}
fn u64_500() -> u64 {
500
}
fn is_u64_500(u: &u64) -> bool {
*u == 500
}
fn u64_1300() -> u64 {
1300
}
fn is_u64_1300(u: &u64) -> bool {
*u == 1300
}
fn i64_100() -> i64 {
100
}
fn is_i64_100(i: &i64) -> bool {
*i == 100
}
fn loglevel_info() -> NebulaConfigLoggingLevel {
NebulaConfigLoggingLevel::Info
}
fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool {
matches!(l, NebulaConfigLoggingLevel::Info)
}
fn format_text() -> NebulaConfigLoggingFormat {
NebulaConfigLoggingFormat::Text
}
fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool {
matches!(f, NebulaConfigLoggingFormat::Text)
}
fn timestamp() -> String {
"2006-01-02T15:04:05Z07:00".to_string()
}
fn is_timestamp(s: &str) -> bool {
s == "2006-01-02T15:04:05Z07:00"
}
fn u64_1() -> u64 {
1
}
fn is_u64_1(u: &u64) -> bool {
*u == 1
}
fn string_nebula() -> String {
"nebula".to_string()
}
fn is_string_nebula(s: &str) -> bool {
s == "nebula"
}
fn string_empty() -> String {
String::new()
}
fn is_string_empty(s: &str) -> bool {
s.is_empty()
}
fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol {
NebulaConfigStatsGraphiteProtocol::Tcp
}
fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool {
matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp)
}
fn none<T>() -> Option<T> {
None
}
fn is_none<T>(o: &Option<T>) -> bool {
o.is_none()
}

View File

@ -0,0 +1,83 @@
use crate::config::{NebulaConfig, CONFIG};
use ed25519_dalek::{SigningKey, VerifyingKey};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use trifid_pki::cert::NebulaCertificate;
use trifid_pki::x25519_dalek::PublicKey;
#[derive(Serialize, Deserialize)]
pub struct Keystore {
#[serde(default = "default_vec")]
pub hosts: Vec<KeystoreHostInformation>,
}
fn default_vec<T>() -> Vec<T> {
vec![]
}
pub fn keystore_init() -> Result<Keystore, Box<dyn Error>> {
let mut ks_fp = CONFIG.crypto.local_keystore_directory.clone();
ks_fp.push("/tfks.toml");
if !ks_fp.exists() {
return Ok(Keystore {
hosts: vec![]
})
}
let f_str = fs::read_to_string(ks_fp)?;
let keystore: Keystore = toml::from_str(&f_str)?;
Ok(keystore)
}
pub fn keystore_flush(ks: &Keystore) -> Result<(), Box<dyn Error>> {
let mut ks_fp = CONFIG.crypto.local_keystore_directory.clone();
ks_fp.push("/tfks.toml");
fs::write(ks_fp, toml::to_string(ks)?)?;
Ok(())
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KeystoreHostInformation {
pub id: String,
pub current_signing_key: u64,
pub current_client_key: u64,
pub current_config: u64,
pub current_cert: u64,
pub certs: Vec<KSCert>,
pub config: Vec<KSConfig>,
pub signing_keys: Vec<KSSigningKey>,
pub client_keys: Vec<KSClientKey>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSCert {
pub id: u64,
pub cert: NebulaCertificate,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSConfig {
pub id: u64,
pub config: NebulaConfig,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSSigningKey {
pub id: u64,
pub key: SigningKey,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSClientKey {
pub id: u64,
pub dh_pub: PublicKey,
pub ed_pub: VerifyingKey,
}

View File

@ -26,14 +26,17 @@ use std::time::Duration;
use crate::config::CONFIG; use crate::config::CONFIG;
use crate::error::{APIError, APIErrorsResponse}; use crate::error::{APIError, APIErrorsResponse};
use crate::keystore::{keystore_init, Keystore};
use crate::tokens::random_id_no_id; use crate::tokens::random_id_no_id;
use trifid_api_migration::{Migrator, MigratorTrait}; use trifid_api_migration::{Migrator, MigratorTrait};
pub mod auth_tokens; pub mod auth_tokens;
pub mod codegen;
pub mod config; pub mod config;
pub mod crypto; pub mod crypto;
pub mod cursor; pub mod cursor;
pub mod error; pub mod error;
pub mod keystore;
pub mod magic_link; pub mod magic_link;
pub mod routes; pub mod routes;
pub mod timers; pub mod timers;
@ -41,12 +44,17 @@ pub mod tokens;
pub struct AppState { pub struct AppState {
pub conn: DatabaseConnection, pub conn: DatabaseConnection,
pub keystore: Keystore,
} }
#[actix_web::main] #[actix_web::main]
async fn main() -> Result<(), Box<dyn Error>> { async fn main() -> Result<(), Box<dyn Error>> {
simple_logger::init_with_level(Level::Debug).unwrap(); simple_logger::init_with_level(Level::Debug).unwrap();
info!("Creating keystore...");
let keystore = keystore_init()?;
info!("Connecting to database at {}...", CONFIG.database.url); info!("Connecting to database at {}...", CONFIG.database.url);
let mut opt = ConnectOptions::new(CONFIG.database.url.clone()); let mut opt = ConnectOptions::new(CONFIG.database.url.clone());
@ -64,7 +72,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
info!("Performing database migration..."); info!("Performing database migration...");
Migrator::up(&db, None).await?; Migrator::up(&db, None).await?;
let data = Data::new(AppState { conn: db }); let data = Data::new(AppState { conn: db, keystore });
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()
@ -103,6 +111,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
.service(routes::v1::hosts::block_host) .service(routes::v1::hosts::block_host)
.service(routes::v1::hosts::enroll_host) .service(routes::v1::hosts::enroll_host)
.service(routes::v1::hosts::create_host_and_enrollment_code) .service(routes::v1::hosts::create_host_and_enrollment_code)
.service(routes::v2::enroll::enroll)
}) })
.bind(CONFIG.server.bind)? .bind(CONFIG.server.bind)?
.run() .run()

View File

@ -1 +1,2 @@
pub mod v1; pub mod v1;
pub mod v2;

View File

@ -33,7 +33,7 @@ use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, Query
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use trifid_api_entities::entity::{network, organization, signing_ca}; use trifid_api_entities::entity::{network, organization, signing_ca};
use trifid_pki::cert::{serialize_x25519_private, NebulaCertificate, NebulaCertificateDetails}; use trifid_pki::cert::{serialize_x25519_private, NebulaCertificate, NebulaCertificateDetails, serialize_ed25519_private};
use trifid_pki::ed25519_dalek::SigningKey; use trifid_pki::ed25519_dalek::SigningKey;
use trifid_pki::rand_core::OsRng; use trifid_pki::rand_core::OsRng;
@ -146,7 +146,7 @@ pub async fn create_org_request(
} }
// PEM-encode the CA key // PEM-encode the CA key
let ca_key_pem = serialize_x25519_private(&private_key.to_keypair_bytes()); let ca_key_pem = serialize_ed25519_private(&private_key.to_keypair_bytes());
// PEM-encode the CA cert // PEM-encode the CA cert
let ca_cert_pem = match cert.serialize_to_pem() { let ca_cert_pem = match cert.serialize_to_pem() {
Ok(pem) => pem, Ok(pem) => pem,
@ -204,8 +204,8 @@ pub async fn create_org_request(
let signing_ca = signing_ca::Model { let signing_ca = signing_ca::Model {
id: random_id("ca"), id: random_id("ca"),
organization: org.id.clone(), organization: org.id.clone(),
cert: ca_key_encrypted, cert: ca_crt,
key: ca_crt, key: ca_key_encrypted,
expires: cert expires: cert
.details .details
.not_after .not_after

View File

@ -0,0 +1,233 @@
use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest, HttpResponse};
use dnapi_rs::message::{
APIError, EnrollRequest, EnrollResponse, EnrollResponseData, EnrollResponseDataOrg,
};
use ed25519_dalek::{SigningKey, VerifyingKey};
use log::{debug, error, trace};
use rand::rngs::OsRng;
use sea_orm::{ColumnTrait, EntityTrait, ModelTrait, QueryFilter};
use crate::codegen::{collect_info, generate_config};
use crate::keystore::{KSCert, KSClientKey, KSConfig, KSSigningKey, KeystoreHostInformation};
use crate::AppState;
use trifid_api_entities::entity::host_enrollment_code;
use trifid_pki::cert::{
deserialize_ed25519_public, deserialize_x25519_public, serialize_ed25519_public,
};
use trifid_pki::x25519_dalek::PublicKey;
use crate::timers::expired;
#[post("/v2/enroll")]
pub async fn enroll(
req: Json<EnrollRequest>,
_req_info: HttpRequest,
db: Data<AppState>,
) -> HttpResponse {
debug!("{:x?} {:x?}", req.dh_pubkey, req.ed_pubkey);
// pull enroll information from the db
let code_info = match host_enrollment_code::Entity::find()
.filter(host_enrollment_code::Column::Id.eq(&req.code))
.one(&db.conn)
.await
{
Ok(ci) => ci,
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message:
"There was an error with the database request. Please try again later."
.to_string(),
path: None,
}],
});
}
};
let enroll_info = match code_info {
Some(ei) => ei,
None => {
return HttpResponse::Unauthorized().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_UNAUTHORIZED".to_string(),
message: "That code is invalid or has expired.".to_string(),
path: None,
}],
})
}
};
if expired(enroll_info.expires_on as u64) {
return HttpResponse::Unauthorized().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_UNAUTHORIZED".to_string(),
message: "That code is invalid or has expired.".to_string(),
path: None,
}],
});
}
// deserialize
let dh_pubkey = match deserialize_x25519_public(&req.dh_pubkey) {
Ok(k) => k,
Err(e) => {
error!("public key deserialization error: {}", e);
return HttpResponse::BadRequest().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_BAD_DH_PUB".to_string(),
message: "Unable to deserialize the DH public key.".to_string(),
path: None,
}],
});
}
};
let ed_pubkey = match deserialize_ed25519_public(&req.ed_pubkey) {
Ok(k) => k,
Err(e) => {
error!("public key deserialization error: {}", e);
return HttpResponse::BadRequest().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_BAD_ED_PUB".to_string(),
message: "Unable to deserialize the ED25519 public key.".to_string(),
path: None,
}],
});
}
};
// destroy the enrollment code before doing anything else
match enroll_info.clone().delete(&db.conn).await {
Ok(_) => (),
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message:
"There was an error with the database request. Please try again later."
.to_string(),
path: None,
}],
});
}
}
let info = match collect_info(&db, &enroll_info.host, &dh_pubkey).await {
Ok(i) => i,
Err(e) => {
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: e.to_string(),
path: None,
}],
});
}
};
// codegen: handoff to dedicated codegen module, we have collected all information
let (cfg, cert) = match generate_config(&db, &info).await {
Ok(cfg) => cfg,
Err(e) => {
error!("error generating configuration: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: "There was an error generating the host configuration.".to_string(),
path: None,
}],
});
}
};
let host_in_ks = db.keystore.hosts.iter().find(|u| u.id == enroll_info.id);
let host_in_ks = match host_in_ks {
Some(ksinfo) => {
let mut ks = ksinfo.clone();
ks.certs.push(KSCert {
id: ks.current_cert + 1,
cert,
});
ks.current_cert += 1;
ks.config.push(KSConfig {
id: ks.current_config + 1,
config: cfg.clone(),
});
ks.current_config += 1;
ks.signing_keys.push(KSSigningKey {
id: ks.current_signing_key,
key: SigningKey::generate(&mut OsRng),
});
ks.current_signing_key += 1;
let dh_pubkey_typed: [u8; 32] = dh_pubkey.clone().try_into().unwrap();
ks.client_keys.push(KSClientKey {
id: ks.current_client_key + 1,
dh_pub: PublicKey::from(dh_pubkey_typed),
ed_pub: VerifyingKey::from_bytes(&ed_pubkey.try_into().unwrap()).unwrap(),
});
ks.current_client_key += 1;
ks
}
None => {
let dh_pubkey_typed: [u8; 32] = dh_pubkey.clone().try_into().unwrap();
KeystoreHostInformation {
id: enroll_info.id.clone(),
current_signing_key: 1,
current_client_key: 1,
current_config: 1,
current_cert: 1,
certs: vec![KSCert { id: 1, cert }],
config: vec![KSConfig {
id: 1,
config: cfg.clone(),
}],
signing_keys: vec![KSSigningKey {
id: 1,
key: SigningKey::generate(&mut OsRng),
}],
client_keys: vec![KSClientKey {
id: 1,
dh_pub: PublicKey::from(dh_pubkey_typed),
ed_pub: VerifyingKey::from_bytes(&ed_pubkey.try_into().unwrap()).unwrap(),
}],
}
}
};
HttpResponse::Ok().json(EnrollResponse::Success {
data: EnrollResponseData {
config: match serde_yaml::to_string(&cfg) {
Ok(cfg) => cfg.as_bytes().to_vec(),
Err(e) => {
error!("serialization error: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![
APIError {
code: "ERR_CFG_SERIALIZATION_ERROR".to_string(),
message: "There was an error serializing the node's configuration. Please try again later.".to_string(),
path: None,
}
],
});
}
},
host_id: enroll_info.host.clone(),
counter: host_in_ks.current_config as u32,
trusted_keys: serialize_ed25519_public(host_in_ks.signing_keys.iter().find(|u| u.id == host_in_ks.current_signing_key).unwrap().key.verifying_key().as_bytes().as_slice()).to_vec(),
organization: EnrollResponseDataOrg { id: info.organization.id.clone(), name: info.organization.name.clone() },
},
})
}

View File

@ -0,0 +1 @@
pub mod enroll;

View File

@ -1,6 +1,6 @@
[package] [package]
name = "trifid-pki" name = "trifid-pki"
version = "0.1.10" version = "0.1.11"
edition = "2021" edition = "2021"
description = "A rust implementation of the Nebula PKI system" description = "A rust implementation of the Nebula PKI system"
license = "AGPL-3.0-or-later" license = "AGPL-3.0-or-later"

View File

@ -1,14 +1,14 @@
//! Structs to represent a pool of CA's and blacklisted certificates //! Structs to represent a pool of CA's and blacklisted certificates
use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate};
use ed25519_dalek::VerifyingKey;
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::time::SystemTime; use std::time::SystemTime;
use ed25519_dalek::VerifyingKey;
use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate};
#[cfg(feature = "serde_derive")] #[cfg(feature = "serde_derive")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// A pool of trusted CA certificates, and certificates that should be blocked. /// A pool of trusted CA certificates, and certificates that should be blocked.
/// This is equivalent to the `pki` section in a typical Nebula config.yml. /// This is equivalent to the `pki` section in a typical Nebula config.yml.
@ -20,7 +20,7 @@ pub struct NebulaCAPool {
/// The list of blocklisted certificate fingerprints /// The list of blocklisted certificate fingerprints
pub cert_blocklist: Vec<String>, pub cert_blocklist: Vec<String>,
/// True if any of the member CAs certificates are expired. Must be handled. /// True if any of the member CAs certificates are expired. Must be handled.
pub expired: bool pub expired: bool,
} }
impl NebulaCAPool { impl NebulaCAPool {
@ -41,8 +41,12 @@ impl NebulaCAPool {
for cert in pems { for cert in pems {
match pool.add_ca_certificate(pem::encode(&cert).as_bytes()) { match pool.add_ca_certificate(pem::encode(&cert).as_bytes()) {
Ok(did_expire) => if did_expire { pool.expired = true }, Ok(did_expire) => {
Err(e) => return Err(e) if did_expire {
pool.expired = true
}
}
Err(e) => return Err(e),
} }
} }
@ -56,21 +60,23 @@ impl NebulaCAPool {
let cert = deserialize_nebula_certificate_from_pem(bytes)?; let cert = deserialize_nebula_certificate_from_pem(bytes)?;
if !cert.details.is_ca { if !cert.details.is_ca {
return Err(CaPoolError::NotACA.into()) return Err(CaPoolError::NotACA.into());
} }
if !cert.check_signature(&VerifyingKey::from_bytes(&cert.details.public_key)?)? { if !cert.check_signature(&VerifyingKey::from_bytes(&cert.details.public_key)?)? {
return Err(CaPoolError::NotSelfSigned.into()) return Err(CaPoolError::NotSelfSigned.into());
} }
let fingerprint = cert.sha256sum()?; let fingerprint = cert.sha256sum()?;
let expired = cert.expired(SystemTime::now()); let expired = cert.expired(SystemTime::now());
if expired { self.expired = true } if expired {
self.expired = true
}
self.cas.insert(fingerprint, cert); self.cas.insert(fingerprint, cert);
Ok(expired) Ok(expired)
} }
/// Blocklist the given certificate in the CA pool /// Blocklist the given certificate in the CA pool
@ -92,9 +98,12 @@ impl NebulaCAPool {
/// Gets the CA certificate used to sign the given certificate /// Gets the CA certificate used to sign the given certificate
/// # Errors /// # Errors
/// This function will return an error if the certificate does not have an issuer attached (it is self-signed) /// This function will return an error if the certificate does not have an issuer attached (it is self-signed)
pub fn get_ca_for_cert(&self, cert: &NebulaCertificate) -> Result<Option<&NebulaCertificate>, Box<dyn Error>> { pub fn get_ca_for_cert(
&self,
cert: &NebulaCertificate,
) -> Result<Option<&NebulaCertificate>, Box<dyn Error>> {
if cert.details.issuer == String::new() { if cert.details.issuer == String::new() {
return Err(CaPoolError::NoIssuer.into()) return Err(CaPoolError::NoIssuer.into());
} }
Ok(self.cas.get(&cert.details.issuer)) Ok(self.cas.get(&cert.details.issuer))
@ -115,7 +124,7 @@ pub enum CaPoolError {
/// Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates) /// Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates)
NotSelfSigned, NotSelfSigned,
/// Tried to look up a certificate that does not have an issuer field /// Tried to look up a certificate that does not have an issuer field
NoIssuer NoIssuer,
} }
impl Error for CaPoolError {} impl Error for CaPoolError {}
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
@ -127,4 +136,4 @@ impl Display for CaPoolError {
Self::NoIssuer => write!(f, "Tried to look up a certificate with a null issuer field") Self::NoIssuer => write!(f, "Tried to look up a certificate with a null issuer field")
} }
} }
} }

View File

@ -1,22 +1,22 @@
//! Manage Nebula PKI Certificates //! Manage Nebula PKI Certificates
//! This is pretty much a direct port of nebula/cert/cert.go //! This is pretty much a direct port of nebula/cert/cert.go
use crate::ca::NebulaCAPool;
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use ipnet::Ipv4Net;
use pem::Pem;
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use sha2::Digest;
use sha2::Sha256;
use std::error::Error; use std::error::Error;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use std::ops::Add; use std::ops::Add;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use ipnet::{Ipv4Net};
use pem::Pem;
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use sha2::Sha256;
use crate::ca::NebulaCAPool;
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails};
use sha2::Digest;
#[cfg(feature = "serde_derive")] #[cfg(feature = "serde_derive")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// The length, in bytes, of public keys /// The length, in bytes, of public keys
pub const PUBLIC_KEY_LENGTH: i32 = 32; pub const PUBLIC_KEY_LENGTH: i32 = 32;
@ -39,7 +39,7 @@ pub struct NebulaCertificate {
/// The signed data of this certificate /// The signed data of this certificate
pub details: NebulaCertificateDetails, pub details: NebulaCertificateDetails,
/// The Ed25519 signature of this certificate /// The Ed25519 signature of this certificate
pub signature: Vec<u8> pub signature: Vec<u8>,
} }
/// The signed details contained in a Nebula PKI certificate /// The signed details contained in a Nebula PKI certificate
@ -63,7 +63,7 @@ pub struct NebulaCertificateDetails {
/// Is this node a CA? /// Is this node a CA?
pub is_ca: bool, pub is_ca: bool,
/// SHA256 of issuer certificate. If blank, this cert is self-signed. /// SHA256 of issuer certificate. If blank, this cert is self-signed.
pub issuer: String pub issuer: String,
} }
/// A list of errors that can occur parsing certificates /// A list of errors that can occur parsing certificates
@ -87,7 +87,7 @@ pub enum CertificateError {
/// This certificate either is not yet valid or has already expired /// This certificate either is not yet valid or has already expired
Expired, Expired,
/// The public key does not match the expected value /// The public key does not match the expected value
KeyMismatch KeyMismatch,
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
impl Display for CertificateError { impl Display for CertificateError {
@ -95,13 +95,29 @@ impl Display for CertificateError {
match self { match self {
Self::EmptyByteArray => write!(f, "Certificate bytearray is empty"), Self::EmptyByteArray => write!(f, "Certificate bytearray is empty"),
Self::NilDetails => write!(f, "The encoded Details field is null"), Self::NilDetails => write!(f, "The encoded Details field is null"),
Self::IpsNotPairs => write!(f, "encoded IPs should be in pairs, an odd number was found"), Self::IpsNotPairs => {
Self::SubnetsNotPairs => write!(f, "encoded subnets should be in pairs, an odd number was found"), write!(f, "encoded IPs should be in pairs, an odd number was found")
Self::WrongSigLength => write!(f, "Signature should be 64 bytes but is a different size"), }
Self::WrongKeyLength => write!(f, "Public keys are expected to be 32 bytes but the public key on this cert is not"), Self::SubnetsNotPairs => write!(
Self::WrongPemTag => write!(f, "Certificates should have the PEM tag `NEBULA CERTIFICATE`, but this block did not"), f,
Self::Expired => write!(f, "This certificate either is not yet valid or has already expired"), "encoded subnets should be in pairs, an odd number was found"
Self::KeyMismatch => write!(f, "Key does not match expected value") ),
Self::WrongSigLength => {
write!(f, "Signature should be 64 bytes but is a different size")
}
Self::WrongKeyLength => write!(
f,
"Public keys are expected to be 32 bytes but the public key on this cert is not"
),
Self::WrongPemTag => write!(
f,
"Certificates should have the PEM tag `NEBULA CERTIFICATE`, but this block did not"
),
Self::Expired => write!(
f,
"This certificate either is not yet valid or has already expired"
),
Self::KeyMismatch => write!(f, "Key does not match expected value"),
} }
} }
} }
@ -110,7 +126,10 @@ impl Error for CertificateError {}
fn map_cidr_pairs(pairs: &[u32]) -> Result<Vec<Ipv4Net>, Box<dyn Error>> { fn map_cidr_pairs(pairs: &[u32]) -> Result<Vec<Ipv4Net>, Box<dyn Error>> {
let mut res_vec = vec![]; let mut res_vec = vec![];
for pair in pairs.chunks(2) { for pair in pairs.chunks(2) {
res_vec.push(Ipv4Net::with_netmask(Ipv4Addr::from(pair[0]), Ipv4Addr::from(pair[1]))?); res_vec.push(Ipv4Net::with_netmask(
Ipv4Addr::from(pair[0]),
Ipv4Addr::from(pair[1]),
)?);
} }
Ok(res_vec) Ok(res_vec)
} }
@ -129,7 +148,11 @@ impl Display for NebulaCertificate {
writeln!(f, " Not after: {:?}", self.details.not_after)?; writeln!(f, " Not after: {:?}", self.details.not_after)?;
writeln!(f, " Is CA: {}", self.details.is_ca)?; writeln!(f, " Is CA: {}", self.details.is_ca)?;
writeln!(f, " Issuer: {}", self.details.issuer)?; writeln!(f, " Issuer: {}", self.details.issuer)?;
writeln!(f, " Public key: {}", hex::encode(self.details.public_key))?; writeln!(
f,
" Public key: {}",
hex::encode(self.details.public_key)
)?;
writeln!(f, " }}")?; writeln!(f, " }}")?;
writeln!(f, " Fingerprint: {}", self.sha256sum().unwrap())?; writeln!(f, " Fingerprint: {}", self.sha256sum().unwrap())?;
writeln!(f, " Signature: {}", hex::encode(self.signature.clone()))?; writeln!(f, " Signature: {}", hex::encode(self.signature.clone()))?;
@ -143,7 +166,7 @@ impl Display for NebulaCertificate {
/// # Panics /// # Panics
pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate, Box<dyn Error>> { pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate, Box<dyn Error>> {
if bytes.is_empty() { if bytes.is_empty() {
return Err(CertificateError::EmptyByteArray.into()) return Err(CertificateError::EmptyByteArray.into());
} }
let mut reader = BytesReader::from_bytes(bytes); let mut reader = BytesReader::from_bytes(bytes);
@ -153,11 +176,11 @@ pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate,
let details = raw_cert.Details.ok_or(CertificateError::NilDetails)?; let details = raw_cert.Details.ok_or(CertificateError::NilDetails)?;
if details.Ips.len() % 2 != 0 { if details.Ips.len() % 2 != 0 {
return Err(CertificateError::IpsNotPairs.into()) return Err(CertificateError::IpsNotPairs.into());
} }
if details.Subnets.len() % 2 != 0 { if details.Subnets.len() % 2 != 0 {
return Err(CertificateError::SubnetsNotPairs.into()) return Err(CertificateError::SubnetsNotPairs.into());
} }
let mut nebula_cert; let mut nebula_cert;
@ -168,8 +191,13 @@ pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate,
name: details.Name.to_string(), name: details.Name.to_string(),
ips: map_cidr_pairs(&details.Ips)?, ips: map_cidr_pairs(&details.Ips)?,
subnets: map_cidr_pairs(&details.Subnets)?, subnets: map_cidr_pairs(&details.Subnets)?,
groups: details.Groups.iter().map(std::string::ToString::to_string).collect(), groups: details
not_before: SystemTime::UNIX_EPOCH.add(Duration::from_secs(details.NotBefore as u64)), .Groups
.iter()
.map(std::string::ToString::to_string)
.collect(),
not_before: SystemTime::UNIX_EPOCH
.add(Duration::from_secs(details.NotBefore as u64)),
not_after: SystemTime::UNIX_EPOCH.add(Duration::from_secs(details.NotAfter as u64)), not_after: SystemTime::UNIX_EPOCH.add(Duration::from_secs(details.NotAfter as u64)),
public_key: [0u8; 32], public_key: [0u8; 32],
is_ca: details.IsCA, is_ca: details.IsCA,
@ -182,10 +210,13 @@ pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate,
nebula_cert.signature = raw_cert.Signature; nebula_cert.signature = raw_cert.Signature;
if details.PublicKey.len() != 32 { if details.PublicKey.len() != 32 {
return Err(CertificateError::WrongKeyLength.into()) return Err(CertificateError::WrongKeyLength.into());
} }
#[allow(clippy::unwrap_used)] { nebula_cert.details.public_key = details.PublicKey.try_into().unwrap(); } #[allow(clippy::unwrap_used)]
{
nebula_cert.details.public_key = details.PublicKey.try_into().unwrap();
}
Ok(nebula_cert) Ok(nebula_cert)
} }
@ -199,28 +230,32 @@ pub enum KeyError {
/// Ed25519 private keys are 64 bytes /// Ed25519 private keys are 64 bytes
Not64Bytes, Not64Bytes,
/// X25519 private keys are 32 bytes /// X25519 private keys are 32 bytes
Not32Bytes Not32Bytes,
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
impl Display for KeyError { impl Display for KeyError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::WrongPemTag => write!(f, "Keys should have their associated PEM tags but this had the wrong one"), Self::WrongPemTag => write!(
f,
"Keys should have their associated PEM tags but this had the wrong one"
),
Self::Not64Bytes => write!(f, "Ed25519 private keys are 64 bytes"), Self::Not64Bytes => write!(f, "Ed25519 private keys are 64 bytes"),
Self::Not32Bytes => write!(f, "X25519 private keys are 32 bytes") Self::Not32Bytes => write!(f, "X25519 private keys are 32 bytes"),
} }
} }
} }
impl Error for KeyError {} impl Error for KeyError {}
/// Deserialize the first PEM block in the given byte array into a `NebulaCertificate` /// Deserialize the first PEM block in the given byte array into a `NebulaCertificate`
/// # Errors /// # Errors
/// This function will return an error if the PEM data is invalid, or if there is an error parsing the certificate (see `deserialize_nebula_certificate`) /// This function will return an error if the PEM data is invalid, or if there is an error parsing the certificate (see `deserialize_nebula_certificate`)
pub fn deserialize_nebula_certificate_from_pem(bytes: &[u8]) -> Result<NebulaCertificate, Box<dyn Error>> { pub fn deserialize_nebula_certificate_from_pem(
bytes: &[u8],
) -> Result<NebulaCertificate, Box<dyn Error>> {
let pem = pem::parse(bytes)?; let pem = pem::parse(bytes)?;
if pem.tag != CERT_BANNER { if pem.tag != CERT_BANNER {
return Err(CertificateError::WrongPemTag.into()) return Err(CertificateError::WrongPemTag.into());
} }
deserialize_nebula_certificate(&pem.contents) deserialize_nebula_certificate(&pem.contents)
} }
@ -230,7 +265,9 @@ pub fn serialize_x25519_private(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem { pem::encode(&Pem {
tag: X25519_PRIVATE_KEY_BANNER.to_string(), tag: X25519_PRIVATE_KEY_BANNER.to_string(),
contents: bytes.to_vec(), contents: bytes.to_vec(),
}).as_bytes().to_vec() })
.as_bytes()
.to_vec()
} }
/// Simple helper to PEM encode an X25519 public key /// Simple helper to PEM encode an X25519 public key
@ -238,7 +275,9 @@ pub fn serialize_x25519_public(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem { pem::encode(&Pem {
tag: X25519_PUBLIC_KEY_BANNER.to_string(), tag: X25519_PUBLIC_KEY_BANNER.to_string(),
contents: bytes.to_vec(), contents: bytes.to_vec(),
}).as_bytes().to_vec() })
.as_bytes()
.to_vec()
} }
/// Attempt to deserialize a PEM encoded X25519 private key /// Attempt to deserialize a PEM encoded X25519 private key
@ -247,10 +286,10 @@ pub fn serialize_x25519_public(bytes: &[u8]) -> Vec<u8> {
pub fn deserialize_x25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> { pub fn deserialize_x25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?; let pem = pem::parse(bytes)?;
if pem.tag != X25519_PRIVATE_KEY_BANNER { if pem.tag != X25519_PRIVATE_KEY_BANNER {
return Err(KeyError::WrongPemTag.into()) return Err(KeyError::WrongPemTag.into());
} }
if pem.contents.len() != 32 { if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into()) return Err(KeyError::Not32Bytes.into());
} }
Ok(pem.contents) Ok(pem.contents)
} }
@ -261,10 +300,10 @@ pub fn deserialize_x25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error
pub fn deserialize_x25519_public(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> { pub fn deserialize_x25519_public(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?; let pem = pem::parse(bytes)?;
if pem.tag != X25519_PUBLIC_KEY_BANNER { if pem.tag != X25519_PUBLIC_KEY_BANNER {
return Err(KeyError::WrongPemTag.into()) return Err(KeyError::WrongPemTag.into());
} }
if pem.contents.len() != 32 { if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into()) return Err(KeyError::Not32Bytes.into());
} }
Ok(pem.contents) Ok(pem.contents)
} }
@ -274,7 +313,9 @@ pub fn serialize_ed25519_private(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem { pem::encode(&Pem {
tag: ED25519_PRIVATE_KEY_BANNER.to_string(), tag: ED25519_PRIVATE_KEY_BANNER.to_string(),
contents: bytes.to_vec(), contents: bytes.to_vec(),
}).as_bytes().to_vec() })
.as_bytes()
.to_vec()
} }
/// Simple helper to PEM encode an Ed25519 public key /// Simple helper to PEM encode an Ed25519 public key
@ -282,7 +323,9 @@ pub fn serialize_ed25519_public(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem { pem::encode(&Pem {
tag: ED25519_PUBLIC_KEY_BANNER.to_string(), tag: ED25519_PUBLIC_KEY_BANNER.to_string(),
contents: bytes.to_vec(), contents: bytes.to_vec(),
}).as_bytes().to_vec() })
.as_bytes()
.to_vec()
} }
/// Attempt to deserialize a PEM encoded Ed25519 private key /// Attempt to deserialize a PEM encoded Ed25519 private key
@ -291,10 +334,10 @@ pub fn serialize_ed25519_public(bytes: &[u8]) -> Vec<u8> {
pub fn deserialize_ed25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> { pub fn deserialize_ed25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?; let pem = pem::parse(bytes)?;
if pem.tag != ED25519_PRIVATE_KEY_BANNER { if pem.tag != ED25519_PRIVATE_KEY_BANNER {
return Err(KeyError::WrongPemTag.into()) return Err(KeyError::WrongPemTag.into());
} }
if pem.contents.len() != 64 { if pem.contents.len() != 64 {
return Err(KeyError::Not64Bytes.into()) return Err(KeyError::Not64Bytes.into());
} }
Ok(pem.contents) Ok(pem.contents)
} }
@ -305,10 +348,10 @@ pub fn deserialize_ed25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Erro
pub fn deserialize_ed25519_public(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> { pub fn deserialize_ed25519_public(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?; let pem = pem::parse(bytes)?;
if pem.tag != ED25519_PUBLIC_KEY_BANNER { if pem.tag != ED25519_PUBLIC_KEY_BANNER {
return Err(KeyError::WrongPemTag.into()) return Err(KeyError::WrongPemTag.into());
} }
if pem.contents.len() != 32 { if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into()) return Err(KeyError::Not32Bytes.into());
} }
Ok(pem.contents) Ok(pem.contents)
} }
@ -322,10 +365,10 @@ pub fn deserialize_ed25519_public_many(bytes: &[u8]) -> Result<Vec<Vec<u8>>, Box
for pem in pems { for pem in pems {
if pem.tag != ED25519_PUBLIC_KEY_BANNER { if pem.tag != ED25519_PUBLIC_KEY_BANNER {
return Err(KeyError::WrongPemTag.into()) return Err(KeyError::WrongPemTag.into());
} }
if pem.contents.len() != 32 { if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into()) return Err(KeyError::Not32Bytes.into());
} }
keys.push(pem.contents); keys.push(pem.contents);
} }
@ -367,7 +410,11 @@ impl NebulaCertificate {
/// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) /// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
/// # Errors /// # Errors
/// This function will return an error if there is an error parsing the cert or the CA pool. /// This function will return an error if there is an error parsing the cert or the CA pool.
pub fn verify(&self, time: SystemTime, ca_pool: &NebulaCAPool) -> Result<CertificateValidity, Box<dyn Error>> { pub fn verify(
&self,
time: SystemTime,
ca_pool: &NebulaCAPool,
) -> Result<CertificateValidity, Box<dyn Error>> {
if ca_pool.is_blocklisted(self) { if ca_pool.is_blocklisted(self) {
return Ok(CertificateValidity::Blocklisted); return Ok(CertificateValidity::Blocklisted);
} }
@ -375,15 +422,15 @@ impl NebulaCertificate {
let Some(signer) = ca_pool.get_ca_for_cert(self)? else { return Ok(CertificateValidity::NotSignedByThisCAPool) }; let Some(signer) = ca_pool.get_ca_for_cert(self)? else { return Ok(CertificateValidity::NotSignedByThisCAPool) };
if signer.expired(time) { if signer.expired(time) {
return Ok(CertificateValidity::RootCertExpired) return Ok(CertificateValidity::RootCertExpired);
} }
if self.expired(time) { if self.expired(time) {
return Ok(CertificateValidity::CertExpired) return Ok(CertificateValidity::CertExpired);
} }
if !self.check_signature(&VerifyingKey::from_bytes(&signer.details.public_key)?)? { if !self.check_signature(&VerifyingKey::from_bytes(&signer.details.public_key)?)? {
return Ok(CertificateValidity::BadSignature) return Ok(CertificateValidity::BadSignature);
} }
Ok(self.check_root_constraints(signer)) Ok(self.check_root_constraints(signer))
@ -392,7 +439,10 @@ impl NebulaCertificate {
/// Make sure that this certificate does not break any of the constraints set by the signing certificate /// Make sure that this certificate does not break any of the constraints set by the signing certificate
pub fn check_root_constraints(&self, signer: &Self) -> CertificateValidity { pub fn check_root_constraints(&self, signer: &Self) -> CertificateValidity {
// Make sure this cert doesn't expire after the signer // Make sure this cert doesn't expire after the signer
println!("{:?} {:?}", signer.details.not_before, self.details.not_before); println!(
"{:?} {:?}",
signer.details.not_before, self.details.not_before
);
if signer.details.not_before < self.details.not_before { if signer.details.not_before < self.details.not_before {
return CertificateValidity::CertExpiresAfterSigner; return CertificateValidity::CertExpiresAfterSigner;
} }
@ -404,7 +454,10 @@ impl NebulaCertificate {
// If the signer contains a limited set of groups, make sure this cert only has a subset of them // If the signer contains a limited set of groups, make sure this cert only has a subset of them
if !signer.details.groups.is_empty() { if !signer.details.groups.is_empty() {
println!("root groups: {:?}, child groups: {:?}", signer.details.groups, self.details.groups); println!(
"root groups: {:?}, child groups: {:?}",
signer.details.groups, self.details.groups
);
for group in &self.details.groups { for group in &self.details.groups {
if !signer.details.groups.contains(group) { if !signer.details.groups.contains(group) {
return CertificateValidity::GroupNotPresentOnSigner; return CertificateValidity::GroupNotPresentOnSigner;
@ -443,10 +496,9 @@ impl NebulaCertificate {
if self.details.is_ca { if self.details.is_ca {
// convert the keys // convert the keys
if key.len() != 64 { if key.len() != 64 {
return Err("key not 64-bytes long".into()) return Err("key not 64-bytes long".into());
} }
let secret = SigningKey::from_keypair_bytes(key.try_into().unwrap())?; let secret = SigningKey::from_keypair_bytes(key.try_into().unwrap())?;
let pub_key = secret.verifying_key().to_bytes(); let pub_key = secret.verifying_key().to_bytes();
if pub_key != self.details.public_key { if pub_key != self.details.public_key {
@ -457,13 +509,17 @@ impl NebulaCertificate {
} }
if key.len() != 32 { if key.len() != 32 {
return Err("key not 32-bytes long".into()) return Err("key not 32-bytes long".into());
} }
let pubkey_raw = SigningKey::from_bytes(key.try_into()?).verifying_key(); let pubkey_raw = SigningKey::from_bytes(key.try_into()?).verifying_key();
let pubkey = pubkey_raw.as_bytes(); let pubkey = pubkey_raw.as_bytes();
println!("{} {}", hex::encode(pubkey), hex::encode(self.details.public_key)); println!(
"{} {}",
hex::encode(pubkey),
hex::encode(self.details.public_key)
);
if *pubkey != self.details.public_key { if *pubkey != self.details.public_key {
return Err(CertificateError::KeyMismatch.into()); return Err(CertificateError::KeyMismatch.into());
} }
@ -471,21 +527,34 @@ impl NebulaCertificate {
Ok(()) Ok(())
} }
/// Get a protobuf-ready raw struct, ready for serialization /// Get a protobuf-ready raw struct, ready for serialization
#[allow(clippy::expect_used)] #[allow(clippy::expect_used)]
#[allow(clippy::cast_possible_wrap)] #[allow(clippy::cast_possible_wrap)]
/// # Panics /// # Panics
/// This function will panic if time went backwards, or if the certificate contains extremely invalid data. /// This function will panic if time went backwards, or if the certificate contains extremely invalid data.
pub fn get_raw_details(&self) -> RawNebulaCertificateDetails { pub fn get_raw_details(&self) -> RawNebulaCertificateDetails {
let mut raw = RawNebulaCertificateDetails { let mut raw = RawNebulaCertificateDetails {
Name: self.details.name.clone(), Name: self.details.name.clone(),
Ips: vec![], Ips: vec![],
Subnets: vec![], Subnets: vec![],
Groups: self.details.groups.iter().map(std::convert::Into::into).collect(), Groups: self
NotBefore: self.details.not_before.duration_since(UNIX_EPOCH).expect("Time went backwards").as_secs() as i64, .details
NotAfter: self.details.not_after.duration_since(UNIX_EPOCH).expect("Time went backwards").as_secs() as i64, .groups
.iter()
.map(std::convert::Into::into)
.collect(),
NotBefore: self
.details
.not_before
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs() as i64,
NotAfter: self
.details
.not_after
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs() as i64,
PublicKey: self.details.public_key.into(), PublicKey: self.details.public_key.into(),
IsCA: self.details.is_ca, IsCA: self.details.is_ca,
Issuer: hex::decode(&self.details.issuer).expect("Issuer was not a hex-encoded value"), Issuer: hex::decode(&self.details.issuer).expect("Issuer was not a hex-encoded value"),
@ -529,7 +598,9 @@ impl NebulaCertificate {
Ok(pem::encode(&Pem { Ok(pem::encode(&Pem {
tag: CERT_BANNER.to_string(), tag: CERT_BANNER.to_string(),
contents: pbuf_bytes, contents: pbuf_bytes,
}).as_bytes().to_vec()) })
.as_bytes()
.to_vec())
} }
/// Get the fingerprint of this certificate /// Get the fingerprint of this certificate
@ -570,7 +641,7 @@ pub enum CertificateValidity {
/// An IP present on this certificate is not present on the signer's certificate /// An IP present on this certificate is not present on the signer's certificate
IPNotPresentOnSigner, IPNotPresentOnSigner,
/// A subnet on this certificate is not present on the signer's certificate /// A subnet on this certificate is not present on the signer's certificate
SubnetNotPresentOnSigner SubnetNotPresentOnSigner,
} }
fn net_match(cert_ip: Ipv4Net, root_ips: &Vec<Ipv4Net>) -> bool { fn net_match(cert_ip: Ipv4Net, root_ips: &Vec<Ipv4Net>) -> bool {
@ -580,4 +651,4 @@ fn net_match(cert_ip: Ipv4Net, root_ips: &Vec<Ipv4Net>) -> bool {
} }
} }
false false
} }

View File

@ -32,7 +32,6 @@
//! // } //! // }
//! ``` //! ```
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![warn(clippy::nursery)] #![warn(clippy::nursery)]
#![deny(clippy::unwrap_used)] #![deny(clippy::unwrap_used)]
@ -46,8 +45,8 @@
#![allow(clippy::module_name_repetitions)] #![allow(clippy::module_name_repetitions)]
pub use ed25519_dalek; pub use ed25519_dalek;
pub use x25519_dalek;
pub use rand_core; pub use rand_core;
pub use x25519_dalek;
extern crate core; extern crate core;
@ -60,4 +59,4 @@ pub(crate) mod cert_codec;
pub mod test; pub mod test;
/// Get the compiled version of trifid-pki. /// Get the compiled version of trifid-pki.
pub const TRIFID_PKI_VERSION: &str = env!("CARGO_PKG_VERSION"); pub const TRIFID_PKI_VERSION: &str = env!("CARGO_PKG_VERSION");

View File

@ -1,18 +1,24 @@
#![allow(clippy::unwrap_used)] #![allow(clippy::unwrap_used)]
#![allow(clippy::expect_used)] #![allow(clippy::expect_used)]
use crate::ca::NebulaCAPool;
use crate::cert::{
deserialize_ed25519_private, deserialize_ed25519_public, deserialize_ed25519_public_many,
deserialize_nebula_certificate, deserialize_nebula_certificate_from_pem,
deserialize_x25519_private, deserialize_x25519_public, serialize_ed25519_private,
serialize_ed25519_public, serialize_x25519_private, serialize_x25519_public,
CertificateValidity, NebulaCertificate, NebulaCertificateDetails,
};
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails};
use crate::netmask; use crate::netmask;
use std::net::Ipv4Addr;
use std::ops::{Add, Sub};
use std::time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH};
use ipnet::Ipv4Net;
use crate::cert::{CertificateValidity, deserialize_ed25519_private, deserialize_ed25519_public, deserialize_ed25519_public_many, deserialize_nebula_certificate, deserialize_nebula_certificate_from_pem, deserialize_x25519_private, deserialize_x25519_public, NebulaCertificate, NebulaCertificateDetails, serialize_ed25519_private, serialize_ed25519_public, serialize_x25519_private, serialize_x25519_public};
use std::str::FromStr;
use ed25519_dalek::{SigningKey, VerifyingKey}; use ed25519_dalek::{SigningKey, VerifyingKey};
use ipnet::Ipv4Net;
use quick_protobuf::{MessageWrite, Writer}; use quick_protobuf::{MessageWrite, Writer};
use rand::rngs::OsRng; use rand::rngs::OsRng;
use crate::ca::{NebulaCAPool}; use std::net::Ipv4Addr;
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails}; use std::ops::{Add, Sub};
use std::str::FromStr;
use std::time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH};
/// This is a cert that we (e3team) actually use in production, and it's a known-good certificate. /// This is a cert that we (e3team) actually use in production, and it's a known-good certificate.
pub const KNOWN_GOOD_CERT: &[u8; 258] = b"-----BEGIN NEBULA CERTIFICATE-----\nCkkKF2UzdGVhbSBJbnRlcm5hbCBOZXR3b3JrKJWev5wGMJWFxKsGOiCvpwoHyKY5\n8Q5+2XxDjtoCf/zlNY/EUdB8bwXQSwEo50ABEkB0Dx76lkMqc3IyH5+ml2dKjTyv\nB4Jiw6x3abf5YZcf8rDuVEgQpvFdJmo3xJyIb3C9vKZ6kXsUxjw6s1JdWgkA\n-----END NEBULA CERTIFICATE-----"; pub const KNOWN_GOOD_CERT: &[u8; 258] = b"-----BEGIN NEBULA CERTIFICATE-----\nCkkKF2UzdGVhbSBJbnRlcm5hbCBOZXR3b3JrKJWev5wGMJWFxKsGOiCvpwoHyKY5\n8Q5+2XxDjtoCf/zlNY/EUdB8bwXQSwEo50ABEkB0Dx76lkMqc3IyH5+ml2dKjTyv\nB4Jiw6x3abf5YZcf8rDuVEgQpvFdJmo3xJyIb3C9vKZ6kXsUxjw6s1JdWgkA\n-----END NEBULA CERTIFICATE-----";
@ -29,14 +35,18 @@ fn certificate_serialization() {
ips: vec![ ips: vec![
netmask!("10.1.1.1", "255.255.255.0"), netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"), netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0") netmask!("10.1.1.3", "255.0.0.0"),
], ],
subnets: vec![ subnets: vec![
netmask!("9.1.1.1", "255.255.255.128"), netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"), netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0") netmask!("9.1.1.3", "255.255.0.0"),
],
groups: vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
], ],
groups: vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()],
not_before: before, not_before: before,
not_after: after, not_after: after,
public_key: *pub_key, public_key: *pub_key,
@ -59,17 +69,29 @@ fn certificate_serialization() {
assert_eq!(cert.details.ips.len(), deserialized.details.ips.len()); assert_eq!(cert.details.ips.len(), deserialized.details.ips.len());
for item in &cert.details.ips { for item in &cert.details.ips {
assert!(deserialized.details.ips.contains(item), "deserialized does not contain from source"); assert!(
deserialized.details.ips.contains(item),
"deserialized does not contain from source"
);
} }
assert_eq!(cert.details.subnets.len(), deserialized.details.subnets.len()); assert_eq!(
cert.details.subnets.len(),
deserialized.details.subnets.len()
);
for item in &cert.details.subnets { for item in &cert.details.subnets {
assert!(deserialized.details.subnets.contains(item), "deserialized does not contain from source"); assert!(
deserialized.details.subnets.contains(item),
"deserialized does not contain from source"
);
} }
assert_eq!(cert.details.groups.len(), deserialized.details.groups.len()); assert_eq!(cert.details.groups.len(), deserialized.details.groups.len());
for item in &cert.details.groups { for item in &cert.details.groups {
assert!(deserialized.details.groups.contains(item), "deserialized does not contain from source"); assert!(
deserialized.details.groups.contains(item),
"deserialized does not contain from source"
);
} }
} }
@ -85,14 +107,18 @@ fn certificate_serialization_pem() {
ips: vec![ ips: vec![
netmask!("10.1.1.1", "255.255.255.0"), netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"), netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0") netmask!("10.1.1.3", "255.0.0.0"),
], ],
subnets: vec![ subnets: vec![
netmask!("9.1.1.1", "255.255.255.128"), netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"), netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0") netmask!("9.1.1.3", "255.255.0.0"),
],
groups: vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
], ],
groups: vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()],
not_before: before, not_before: before,
not_after: after, not_after: after,
public_key: *pub_key, public_key: *pub_key,
@ -115,17 +141,29 @@ fn certificate_serialization_pem() {
assert_eq!(cert.details.ips.len(), deserialized.details.ips.len()); assert_eq!(cert.details.ips.len(), deserialized.details.ips.len());
for item in &cert.details.ips { for item in &cert.details.ips {
assert!(deserialized.details.ips.contains(item), "deserialized does not contain from source"); assert!(
deserialized.details.ips.contains(item),
"deserialized does not contain from source"
);
} }
assert_eq!(cert.details.subnets.len(), deserialized.details.subnets.len()); assert_eq!(
cert.details.subnets.len(),
deserialized.details.subnets.len()
);
for item in &cert.details.subnets { for item in &cert.details.subnets {
assert!(deserialized.details.subnets.contains(item), "deserialized does not contain from source"); assert!(
deserialized.details.subnets.contains(item),
"deserialized does not contain from source"
);
} }
assert_eq!(cert.details.groups.len(), deserialized.details.groups.len()); assert_eq!(cert.details.groups.len(), deserialized.details.groups.len());
for item in &cert.details.groups { for item in &cert.details.groups {
assert!(deserialized.details.groups.contains(item), "deserialized does not contain from source"); assert!(
deserialized.details.groups.contains(item),
"deserialized does not contain from source"
);
} }
} }
@ -141,14 +179,18 @@ fn cert_signing() {
ips: vec![ ips: vec![
netmask!("10.1.1.1", "255.255.255.0"), netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"), netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0") netmask!("10.1.1.3", "255.0.0.0"),
], ],
subnets: vec![ subnets: vec![
netmask!("9.1.1.1", "255.255.255.128"), netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"), netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0") netmask!("9.1.1.3", "255.255.0.0"),
],
groups: vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
], ],
groups: vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()],
not_before: before, not_before: before,
not_after: after, not_after: after,
public_key: *pub_key, public_key: *pub_key,
@ -287,9 +329,15 @@ fn cert_deserialize_wrong_pubkey_len() {
#[test] #[test]
fn x25519_serialization() { fn x25519_serialization() {
let bytes = [0u8; 32]; let bytes = [0u8; 32];
assert_eq!(deserialize_x25519_private(&serialize_x25519_private(&bytes)).unwrap(), bytes); assert_eq!(
deserialize_x25519_private(&serialize_x25519_private(&bytes)).unwrap(),
bytes
);
assert!(deserialize_x25519_private(&[0u8; 32]).is_err()); assert!(deserialize_x25519_private(&[0u8; 32]).is_err());
assert_eq!(deserialize_x25519_public(&serialize_x25519_public(&bytes)).unwrap(), bytes); assert_eq!(
deserialize_x25519_public(&serialize_x25519_public(&bytes)).unwrap(),
bytes
);
assert!(deserialize_x25519_public(&[0u8; 32]).is_err()); assert!(deserialize_x25519_public(&[0u8; 32]).is_err());
} }
@ -297,9 +345,15 @@ fn x25519_serialization() {
fn ed25519_serialization() { fn ed25519_serialization() {
let bytes = [0u8; 64]; let bytes = [0u8; 64];
let bytes2 = [0u8; 32]; let bytes2 = [0u8; 32];
assert_eq!(deserialize_ed25519_private(&serialize_ed25519_private(&bytes)).unwrap(), bytes); assert_eq!(
deserialize_ed25519_private(&serialize_ed25519_private(&bytes)).unwrap(),
bytes
);
assert!(deserialize_ed25519_private(&[0u8; 32]).is_err()); assert!(deserialize_ed25519_private(&[0u8; 32]).is_err());
assert_eq!(deserialize_ed25519_public(&serialize_ed25519_public(&bytes2)).unwrap(), bytes2); assert_eq!(
deserialize_ed25519_public(&serialize_ed25519_public(&bytes2)).unwrap(),
bytes2
);
assert!(deserialize_ed25519_public(&[0u8; 64]).is_err()); assert!(deserialize_ed25519_public(&[0u8; 64]).is_err());
let mut bytes = vec![]; let mut bytes = vec![];
@ -315,29 +369,87 @@ fn ed25519_serialization() {
#[test] #[test]
fn cert_verify() { fn cert_verify() {
let (ca_cert, ca_key, _ca_pub) = test_ca_cert(round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(), vec![], vec![], vec!["groupa".to_string()]); let (ca_cert, ca_key, _ca_pub) = test_ca_cert(
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(),
vec![],
vec![],
vec!["groupa".to_string()],
);
let (cert, _, _) = test_cert(&ca_cert, &ca_key, SystemTime::now(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![], vec![]); let (cert, _, _) = test_cert(
&ca_cert,
&ca_key,
SystemTime::now(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![],
vec![],
);
let mut ca_pool = NebulaCAPool::new(); let mut ca_pool = NebulaCAPool::new();
ca_pool.add_ca_certificate(&ca_cert.serialize_to_pem().unwrap()).unwrap(); ca_pool
.add_ca_certificate(&ca_cert.serialize_to_pem().unwrap())
.unwrap();
let fingerprint = cert.sha256sum().unwrap(); let fingerprint = cert.sha256sum().unwrap();
ca_pool.blocklist_fingerprint(&fingerprint); ca_pool.blocklist_fingerprint(&fingerprint);
assert!(matches!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Blocklisted)); assert!(matches!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Blocklisted
));
ca_pool.reset_blocklist(); ca_pool.reset_blocklist();
assert!(matches!(cert.verify(SystemTime::now() + Duration::from_secs(60 * 60 * 60), &ca_pool).unwrap(), CertificateValidity::RootCertExpired)); assert!(matches!(
assert!(matches!(cert.verify(SystemTime::now() + Duration::from_secs(60 * 60 * 6), &ca_pool).unwrap(), CertificateValidity::CertExpired)); cert.verify(
SystemTime::now() + Duration::from_secs(60 * 60 * 60),
&ca_pool
)
.unwrap(),
CertificateValidity::RootCertExpired
));
assert!(matches!(
cert.verify(
SystemTime::now() + Duration::from_secs(60 * 60 * 6),
&ca_pool
)
.unwrap(),
CertificateValidity::CertExpired
));
let (cert_with_bad_group, _, _) = test_cert(&ca_cert, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), in_a_minute(), vec![], vec![], vec!["group-not-present on parent".to_string()]); let (cert_with_bad_group, _, _) = test_cert(
assert_eq!(cert_with_bad_group.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::GroupNotPresentOnSigner); &ca_cert,
&ca_key,
let (cert_with_good_group, _, _) = test_cert(&ca_cert, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), in_a_minute(), vec![], vec![], vec!["groupa".to_string()]); round_systime_to_secs(SystemTime::now()).unwrap(),
assert_eq!(cert_with_good_group.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); in_a_minute(),
vec![],
vec![],
vec!["group-not-present on parent".to_string()],
);
assert_eq!(
cert_with_bad_group
.verify(SystemTime::now(), &ca_pool)
.unwrap(),
CertificateValidity::GroupNotPresentOnSigner
);
let (cert_with_good_group, _, _) = test_cert(
&ca_cert,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
in_a_minute(),
vec![],
vec![],
vec!["groupa".to_string()],
);
assert_eq!(
cert_with_good_group
.verify(SystemTime::now(), &ca_pool)
.unwrap(),
CertificateValidity::Ok
);
} }
#[test] #[test]
@ -345,7 +457,13 @@ fn cert_verify_ip() {
let ca_ip_1 = Ipv4Net::from_str("10.0.0.0/16").unwrap(); let ca_ip_1 = Ipv4Net::from_str("10.0.0.0/16").unwrap();
let ca_ip_2 = Ipv4Net::from_str("192.168.0.0/24").unwrap(); let ca_ip_2 = Ipv4Net::from_str("192.168.0.0/24").unwrap();
let (ca, ca_key, _ca_pub) = test_ca_cert(round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(), vec![ca_ip_1, ca_ip_2], vec![], vec![]); let (ca, ca_key, _ca_pub) = test_ca_cert(
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(),
vec![ca_ip_1, ca_ip_2],
vec![],
vec![],
);
let ca_pem = ca.serialize_to_pem().unwrap(); let ca_pem = ca.serialize_to_pem().unwrap();
@ -355,49 +473,137 @@ fn cert_verify_ip() {
// ip is outside the network // ip is outside the network
let cip1 = netmask!("10.1.0.0", "255.255.255.0"); let cip1 = netmask!("10.1.0.0", "255.255.255.0");
let cip2 = netmask!("192.198.0.1", "255.255.0.0"); let cip2 = netmask!("192.198.0.1", "255.255.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip is outside the network - reversed order from above // ip is outside the network - reversed order from above
let cip1 = netmask!("192.198.0.1", "255.255.255.0"); let cip1 = netmask!("192.198.0.1", "255.255.255.0");
let cip2 = netmask!("10.1.0.0", "255.255.255.0"); let cip2 = netmask!("10.1.0.0", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip is within the network but mask is outside // ip is within the network but mask is outside
let cip1 = netmask!("10.0.1.0", "255.254.0.0"); let cip1 = netmask!("10.0.1.0", "255.254.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.0"); let cip2 = netmask!("192.168.0.1", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip is within the network but mask is outside - reversed order from above // ip is within the network but mask is outside - reversed order from above
let cip1 = netmask!("192.168.0.1", "255.255.255.0"); let cip1 = netmask!("192.168.0.1", "255.255.255.0");
let cip2 = netmask!("10.0.1.0", "255.254.0.0"); let cip2 = netmask!("10.0.1.0", "255.254.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip and mask are within the network // ip and mask are within the network
let cip1 = netmask!("10.0.1.0", "255.255.0.0"); let cip1 = netmask!("10.0.1.0", "255.255.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.128"); let cip2 = netmask!("192.168.0.1", "255.255.255.128");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches // Exact matches
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![ca_ip_1, ca_ip_2], vec![], vec![]); let (cert, _, _) = test_cert(
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); &ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![ca_ip_1, ca_ip_2],
vec![],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed // Exact matches reversed
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![ca_ip_2, ca_ip_1], vec![], vec![]); let (cert, _, _) = test_cert(
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); &ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![ca_ip_2, ca_ip_1],
vec![],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed with just one // Exact matches reversed with just one
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![ca_ip_2], vec![], vec![]); let (cert, _, _) = test_cert(
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); &ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![ca_ip_2],
vec![],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
} }
#[test] #[test]
@ -405,7 +611,13 @@ fn cert_verify_subnet() {
let ca_ip_1 = Ipv4Net::from_str("10.0.0.0/16").unwrap(); let ca_ip_1 = Ipv4Net::from_str("10.0.0.0/16").unwrap();
let ca_ip_2 = Ipv4Net::from_str("192.168.0.0/24").unwrap(); let ca_ip_2 = Ipv4Net::from_str("192.168.0.0/24").unwrap();
let (ca, ca_key, _ca_pub) = test_ca_cert(round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(), vec![],vec![ca_ip_1, ca_ip_2], vec![]); let (ca, ca_key, _ca_pub) = test_ca_cert(
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(),
vec![],
vec![ca_ip_1, ca_ip_2],
vec![],
);
let ca_pem = ca.serialize_to_pem().unwrap(); let ca_pem = ca.serialize_to_pem().unwrap();
@ -415,63 +627,170 @@ fn cert_verify_subnet() {
// ip is outside the network // ip is outside the network
let cip1 = netmask!("10.1.0.0", "255.255.255.0"); let cip1 = netmask!("10.1.0.0", "255.255.255.0");
let cip2 = netmask!("192.198.0.1", "255.255.0.0"); let cip2 = netmask!("192.198.0.1", "255.255.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![cip1, cip2], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip is outside the network - reversed order from above // ip is outside the network - reversed order from above
let cip1 = netmask!("192.198.0.1", "255.255.255.0"); let cip1 = netmask!("192.198.0.1", "255.255.255.0");
let cip2 = netmask!("10.1.0.0", "255.255.255.0"); let cip2 = netmask!("10.1.0.0", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![cip1, cip2], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip is within the network but mask is outside // ip is within the network but mask is outside
let cip1 = netmask!("10.0.1.0", "255.254.0.0"); let cip1 = netmask!("10.0.1.0", "255.254.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.0"); let cip2 = netmask!("192.168.0.1", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![cip1, cip2], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip is within the network but mask is outside - reversed order from above // ip is within the network but mask is outside - reversed order from above
let cip1 = netmask!("192.168.0.1", "255.255.255.0"); let cip1 = netmask!("192.168.0.1", "255.255.255.0");
let cip2 = netmask!("10.0.1.0", "255.254.0.0"); let cip2 = netmask!("10.0.1.0", "255.254.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip and mask are within the network // ip and mask are within the network
let cip1 = netmask!("10.0.1.0", "255.255.0.0"); let cip1 = netmask!("10.0.1.0", "255.255.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.128"); let cip2 = netmask!("192.168.0.1", "255.255.255.128");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![cip1, cip2], vec![]); let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches // Exact matches
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![ca_ip_1, ca_ip_2], vec![]); let (cert, _, _) = test_cert(
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); &ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![ca_ip_1, ca_ip_2],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed // Exact matches reversed
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![ca_ip_2, ca_ip_1], vec![]); let (cert, _, _) = test_cert(
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); &ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![ca_ip_2, ca_ip_1],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed with just one // Exact matches reversed with just one
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![ca_ip_2], vec![]); let (cert, _, _) = test_cert(
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok); &ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![ca_ip_2],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
} }
#[test] #[test]
fn cert_private_key() { fn cert_private_key() {
let (ca, ca_key, _) = test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]); let (ca, ca_key, _) =
test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
ca.verify_private_key(&ca_key.to_keypair_bytes()).unwrap(); ca.verify_private_key(&ca_key.to_keypair_bytes()).unwrap();
let (_, ca_key2, _) = test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]); let (_, ca_key2, _) =
ca.verify_private_key(&ca_key2.to_keypair_bytes()).unwrap_err(); test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
ca.verify_private_key(&ca_key2.to_keypair_bytes())
.unwrap_err();
let (cert, priv_key, _) = test_cert(&ca, &ca_key, SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]); let (cert, priv_key, _) = test_cert(
&ca,
&ca_key,
SystemTime::now(),
SystemTime::now(),
vec![],
vec![],
vec![],
);
cert.verify_private_key(&priv_key.to_bytes()).unwrap(); cert.verify_private_key(&priv_key.to_bytes()).unwrap();
let (cert2, _, _) = test_cert(&ca, &ca_key, SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]); let (cert2, _, _) = test_cert(
&ca,
&ca_key,
SystemTime::now(),
SystemTime::now(),
vec![],
vec![],
vec![],
);
cert2.verify_private_key(&priv_key.to_bytes()).unwrap_err(); cert2.verify_private_key(&priv_key.to_bytes()).unwrap_err();
} }
@ -489,7 +808,8 @@ CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF 8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
-----END NEBULA CERTIFICATE-----"; -----END NEBULA CERTIFICATE-----";
let with_newlines = b"# Current provisional, Remove once everything moves over to the real root. let with_newlines =
b"# Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE----- -----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
@ -511,24 +831,64 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
-----END NEBULA CERTIFICATE-----"; -----END NEBULA CERTIFICATE-----";
let pool_a = NebulaCAPool::new_from_pem(no_newlines).unwrap(); let pool_a = NebulaCAPool::new_from_pem(no_newlines).unwrap();
assert_eq!(pool_a.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"].details.name, "nebula root ca".to_string()); assert_eq!(
assert_eq!(pool_a.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"].details.name, "nebula root ca 01".to_string()); pool_a.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"]
.details
.name,
"nebula root ca".to_string()
);
assert_eq!(
pool_a.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"]
.details
.name,
"nebula root ca 01".to_string()
);
assert!(!pool_a.expired); assert!(!pool_a.expired);
let pool_b = NebulaCAPool::new_from_pem(with_newlines).unwrap(); let pool_b = NebulaCAPool::new_from_pem(with_newlines).unwrap();
assert_eq!(pool_b.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"].details.name, "nebula root ca".to_string()); assert_eq!(
assert_eq!(pool_b.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"].details.name, "nebula root ca 01".to_string()); pool_b.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"]
.details
.name,
"nebula root ca".to_string()
);
assert_eq!(
pool_b.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"]
.details
.name,
"nebula root ca 01".to_string()
);
assert!(!pool_b.expired); assert!(!pool_b.expired);
let pool_c = NebulaCAPool::new_from_pem(expired).unwrap(); let pool_c = NebulaCAPool::new_from_pem(expired).unwrap();
assert!(pool_c.expired); assert!(pool_c.expired);
assert_eq!(pool_c.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"].details.name, "expired"); assert_eq!(
pool_c.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"]
.details
.name,
"expired"
);
let mut pool_d = NebulaCAPool::new_from_pem(with_newlines).unwrap(); let mut pool_d = NebulaCAPool::new_from_pem(with_newlines).unwrap();
pool_d.add_ca_certificate(expired).unwrap(); pool_d.add_ca_certificate(expired).unwrap();
assert_eq!(pool_d.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"].details.name, "nebula root ca".to_string()); assert_eq!(
assert_eq!(pool_d.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"].details.name, "nebula root ca 01".to_string()); pool_d.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"]
assert_eq!(pool_d.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"].details.name, "expired"); .details
.name,
"nebula root ca".to_string()
);
assert_eq!(
pool_d.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"]
.details
.name,
"nebula root ca 01".to_string()
);
assert_eq!(
pool_d.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"]
.details
.name,
"expired"
);
assert!(pool_d.expired); assert!(pool_d.expired);
assert_eq!(pool_d.get_fingerprints().len(), 3); assert_eq!(pool_d.get_fingerprints().len(), 3);
} }
@ -536,7 +896,11 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
#[macro_export] #[macro_export]
macro_rules! netmask { macro_rules! netmask {
($ip:expr,$mask:expr) => { ($ip:expr,$mask:expr) => {
Ipv4Net::with_netmask(Ipv4Addr::from_str($ip).unwrap(), Ipv4Addr::from_str($mask).unwrap()).unwrap() Ipv4Net::with_netmask(
Ipv4Addr::from_str($ip).unwrap(),
Ipv4Addr::from_str($mask).unwrap(),
)
.unwrap()
}; };
} }
@ -545,7 +909,13 @@ fn round_systime_to_secs(time: SystemTime) -> Result<SystemTime, SystemTimeError
Ok(SystemTime::UNIX_EPOCH.add(Duration::from_secs(secs))) Ok(SystemTime::UNIX_EPOCH.add(Duration::from_secs(secs)))
} }
fn test_ca_cert(before: SystemTime, after: SystemTime, ips: Vec<Ipv4Net>, subnets: Vec<Ipv4Net>, groups: Vec<String>) -> (NebulaCertificate, SigningKey, VerifyingKey) { fn test_ca_cert(
before: SystemTime,
after: SystemTime,
ips: Vec<Ipv4Net>,
subnets: Vec<Ipv4Net>,
groups: Vec<String>,
) -> (NebulaCertificate, SigningKey, VerifyingKey) {
let mut csprng = OsRng; let mut csprng = OsRng;
let key = SigningKey::generate(&mut csprng); let key = SigningKey::generate(&mut csprng);
let pub_key = key.verifying_key(); let pub_key = key.verifying_key();
@ -710,17 +1080,45 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
fn ca_pool_add_non_ca() { fn ca_pool_add_non_ca() {
let mut ca_pool = NebulaCAPool::new(); let mut ca_pool = NebulaCAPool::new();
let (ca, ca_key, _) = test_ca_cert(SystemTime::now(), SystemTime::now() + Duration::from_secs(3600), vec![], vec![], vec![]); let (ca, ca_key, _) = test_ca_cert(
let (cert, _, _) = test_cert(&ca, &ca_key, SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]); SystemTime::now(),
SystemTime::now() + Duration::from_secs(3600),
vec![],
vec![],
vec![],
);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
SystemTime::now(),
SystemTime::now(),
vec![],
vec![],
vec![],
);
ca_pool.add_ca_certificate(&cert.serialize_to_pem().unwrap()).unwrap_err(); ca_pool
.add_ca_certificate(&cert.serialize_to_pem().unwrap())
.unwrap_err();
} }
fn test_cert(ca: &NebulaCertificate, key: &SigningKey, before: SystemTime, after: SystemTime, ips: Vec<Ipv4Net>, subnets: Vec<Ipv4Net>, groups: Vec<String>) -> (NebulaCertificate, SigningKey, VerifyingKey) { fn test_cert(
ca: &NebulaCertificate,
key: &SigningKey,
before: SystemTime,
after: SystemTime,
ips: Vec<Ipv4Net>,
subnets: Vec<Ipv4Net>,
groups: Vec<String>,
) -> (NebulaCertificate, SigningKey, VerifyingKey) {
let issuer = ca.sha256sum().unwrap(); let issuer = ca.sha256sum().unwrap();
let real_groups = if groups.is_empty() { let real_groups = if groups.is_empty() {
vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()] vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
]
} else { } else {
groups groups
}; };
@ -729,7 +1127,7 @@ fn test_cert(ca: &NebulaCertificate, key: &SigningKey, before: SystemTime, after
vec![ vec![
netmask!("10.1.1.1", "255.255.255.0"), netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"), netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0") netmask!("10.1.1.3", "255.0.0.0"),
] ]
} else { } else {
ips ips
@ -739,7 +1137,7 @@ fn test_cert(ca: &NebulaCertificate, key: &SigningKey, before: SystemTime, after
vec![ vec![
netmask!("9.1.1.1", "255.255.255.128"), netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"), netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0") netmask!("9.1.1.3", "255.255.0.0"),
] ]
} else { } else {
subnets subnets
@ -774,4 +1172,4 @@ fn in_a_minute() -> SystemTime {
#[allow(dead_code)] #[allow(dead_code)]
fn a_minute_ago() -> SystemTime { fn a_minute_ago() -> SystemTime {
round_systime_to_secs(SystemTime::now().sub(Duration::from_secs(60))).unwrap() round_systime_to_secs(SystemTime::now().sub(Duration::from_secs(60))).unwrap()
} }