enrollment

This commit is contained in:
c0repwn3r 2023-05-14 13:47:49 -04:00
parent 224e3680e0
commit b8c6ddd123
Signed by: core
GPG key ID: FDBF740DADDCEECF
36 changed files with 3365 additions and 1201 deletions

1135
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[package]
name = "dnapi-rs"
version = "0.1.9"
version = "0.1.11"
edition = "2021"
description = "A rust client for the Defined Networking API"
license = "AGPL-3.0-or-later"

View file

@ -1,17 +1,21 @@
//! Client structs to handle communication with the Defined Networking API. This is the async client API - if you want blocking instead, enable the blocking (or default) feature instead.
use std::error::Error;
use crate::credentials::{ed25519_public_keys_from_pem, Credentials};
use crate::crypto::{new_keys, nonce};
use crate::message::{
CheckForUpdateResponseWrapper, DoUpdateRequest, DoUpdateResponse, EnrollRequest,
EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper, CHECK_FOR_UPDATE, DO_UPDATE,
ENDPOINT_V1, ENROLL_ENDPOINT,
};
use base64::Engine;
use chrono::Local;
use log::{debug, error};
use reqwest::StatusCode;
use url::Url;
use serde::{Deserialize, Serialize};
use std::error::Error;
use trifid_pki::cert::serialize_ed25519_public;
use trifid_pki::ed25519_dalek::{Signature, Signer, SigningKey, Verifier};
use crate::credentials::{Credentials, ed25519_public_keys_from_pem};
use crate::crypto::{new_keys, nonce};
use crate::message::{CHECK_FOR_UPDATE, CheckForUpdateResponseWrapper, DO_UPDATE, DoUpdateRequest, DoUpdateResponse, ENDPOINT_V1, ENROLL_ENDPOINT, EnrollRequest, EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper};
use serde::{Serialize, Deserialize};
use base64::Engine;
use url::Url;
/// A type alias to abstract return types
pub type NebulaConfig = Vec<u8>;
@ -22,7 +26,7 @@ pub type DHPrivateKeyPEM = Vec<u8>;
/// A combination of persistent data and HTTP client used for communicating with the API.
pub struct Client {
http_client: reqwest::Client,
server_url: Url
server_url: Url,
}
#[derive(Serialize, Deserialize, Clone)]
@ -31,7 +35,7 @@ pub struct EnrollMeta {
/// The server organization ID this node is now a member of
pub organization_id: String,
/// The server organization name this node is now a member of
pub organization_name: String
pub organization_name: String,
}
impl Client {
@ -42,7 +46,7 @@ impl Client {
let client = reqwest::Client::builder().user_agent(user_agent).build()?;
Ok(Self {
http_client: client,
server_url: api_base
server_url: api_base,
})
}
@ -59,8 +63,14 @@ impl Client {
/// - the server returns an error
/// - the server returns invalid JSON
/// - the `trusted_keys` field is invalid
pub async fn enroll(&self, code: &str) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> {
debug!("making enrollment request to API {{server: {}, code: {}}}", self.server_url, code);
pub async fn enroll(
&self,
code: &str,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> {
debug!(
"making enrollment request to API {{server: {}, code: {}}}",
self.server_url, code
);
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
@ -71,9 +81,19 @@ impl Client {
timestamp: Local::now().format("%Y-%m-%dT%H:%M:%S.%f%:z").to_string(),
})?;
let resp = self.http_client.post(self.server_url.join(ENROLL_ENDPOINT)?).body(req_json).send().await?;
let resp = self
.http_client
.post(self.server_url.join(ENROLL_ENDPOINT)?)
.body(req_json)
.header("Content-Type", "application/json")
.send()
.await?;
let req_id = resp.headers().get("X-Request-ID").ok_or("Response missing X-Request-ID")?.to_str()?;
let req_id = resp
.headers()
.get("X-Request-ID")
.ok_or("Response missing X-Request-ID")?
.to_str()?;
debug!("enrollment request complete {{req_id: {}}}", req_id);
let resp: EnrollResponse = resp.json().await?;
@ -107,7 +127,15 @@ impl Client {
/// # Errors
/// This function returns an error if the dnclient request fails, or the server returns invalid data.
pub async fn check_for_update(&self, creds: &Credentials) -> Result<bool, Box<dyn Error>> {
let body = self.post_dnclient(CHECK_FOR_UPDATE, &[], &creds.host_id, creds.counter, &creds.ed_privkey).await?;
let body = self
.post_dnclient(
CHECK_FOR_UPDATE,
&[],
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)
.await?;
let result: CheckForUpdateResponseWrapper = serde_json::from_slice(&body)?;
@ -125,7 +153,10 @@ impl Client {
/// - if the response could not be deserialized
/// - if the signature is invalid
/// - if the keys are invalid
pub async fn do_update(&self, creds: &Credentials) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> {
pub async fn do_update(
&self,
creds: &Credentials,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> {
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
let update_keys = DoUpdateRequest {
@ -136,28 +167,45 @@ impl Client {
let update_keys_blob = serde_json::to_vec(&update_keys)?;
let resp = self.post_dnclient(DO_UPDATE, &update_keys_blob, &creds.host_id, creds.counter, &creds.ed_privkey).await?;
let resp = self
.post_dnclient(
DO_UPDATE,
&update_keys_blob,
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)
.await?;
let result_wrapper: SignedResponseWrapper = serde_json::from_slice(&resp)?;
let mut valid = false;
for ca_pubkey in &creds.trusted_keys {
if ca_pubkey.verify(&result_wrapper.data.message, &Signature::from_slice(&result_wrapper.data.signature)?).is_ok() {
if ca_pubkey
.verify(
&result_wrapper.data.message,
&Signature::from_slice(&result_wrapper.data.signature)?,
)
.is_ok()
{
valid = true;
break;
}
}
if !valid {
return Err("Failed to verify signed API result".into())
return Err("Failed to verify signed API result".into());
}
let result: DoUpdateResponse = serde_json::from_slice(&result_wrapper.data.message)?;
if result.nonce != update_keys.nonce {
error!("nonce mismatch between request {:x?} and response {:x?}", result.nonce, update_keys.nonce);
return Err("nonce mismatch between request and response".into())
error!(
"nonce mismatch between request {:x?} and response {:x?}",
result.nonce, update_keys.nonce
);
return Err("nonce mismatch between request and response".into());
}
let trusted_keys = ed25519_public_keys_from_pem(&result.trusted_keys)?;
@ -179,7 +227,14 @@ impl Client {
/// - serialization in any step fails
/// - if the `server_url` is invalid
/// - if the request could not be sent
pub async fn post_dnclient(&self, req_type: &str, value: &[u8], host_id: &str, counter: u32, ed_privkey: &SigningKey) -> Result<Vec<u8>, Box<dyn Error>> {
pub async fn post_dnclient(
&self,
req_type: &str,
value: &[u8],
host_id: &str,
counter: u32,
ed_privkey: &SigningKey,
) -> Result<Vec<u8>, Box<dyn Error>> {
let encoded_msg = serde_json::to_string(&RequestWrapper {
message_type: req_type.to_string(),
value: value.to_vec(),
@ -203,19 +258,23 @@ impl Client {
let post_body = serde_json::to_string(&body)?;
let resp = self.http_client.post(self.server_url.join(ENDPOINT_V1)?).body(post_body).send().await?;
let resp = self
.http_client
.post(self.server_url.join(ENDPOINT_V1)?)
.body(post_body)
.send()
.await?;
match resp.status() {
StatusCode::OK => {
Ok(resp.bytes().await?.to_vec())
},
StatusCode::FORBIDDEN => {
Err("Forbidden".into())
},
StatusCode::OK => Ok(resp.bytes().await?.to_vec()),
StatusCode::FORBIDDEN => Err("Forbidden".into()),
_ => {
error!("dnclient endpoint returned bad status code {}", resp.status());
error!(
"dnclient endpoint returned bad status code {}",
resp.status()
);
Err("dnclient endpoint returned error".into())
}
}
}
}
}

View file

@ -1,17 +1,21 @@
//! Client structs to handle communication with the Defined Networking API. This is the blocking client API - if you want async instead, set no-default-features and enable the async feature instead.
use std::error::Error;
use crate::credentials::{ed25519_public_keys_from_pem, Credentials};
use crate::crypto::{new_keys, nonce};
use crate::message::{
CheckForUpdateResponseWrapper, DoUpdateRequest, DoUpdateResponse, EnrollRequest,
EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper, CHECK_FOR_UPDATE, DO_UPDATE,
ENDPOINT_V1, ENROLL_ENDPOINT,
};
use base64::Engine;
use chrono::Local;
use log::{debug, error, trace};
use reqwest::StatusCode;
use url::Url;
use serde::{Deserialize, Serialize};
use std::error::Error;
use trifid_pki::cert::serialize_ed25519_public;
use trifid_pki::ed25519_dalek::{Signature, Signer, SigningKey, Verifier};
use crate::credentials::{Credentials, ed25519_public_keys_from_pem};
use crate::crypto::{new_keys, nonce};
use crate::message::{CHECK_FOR_UPDATE, CheckForUpdateResponseWrapper, DO_UPDATE, DoUpdateRequest, DoUpdateResponse, ENDPOINT_V1, ENROLL_ENDPOINT, EnrollRequest, EnrollResponse, RequestV1, RequestWrapper, SignedResponseWrapper};
use serde::{Serialize, Deserialize};
use url::Url;
/// A type alias to abstract return types
pub type NebulaConfig = Vec<u8>;
@ -22,7 +26,7 @@ pub type DHPrivateKeyPEM = Vec<u8>;
/// A combination of persistent data and HTTP client used for communicating with the API.
pub struct Client {
http_client: reqwest::blocking::Client,
server_url: Url
server_url: Url,
}
#[derive(Serialize, Deserialize, Clone)]
@ -31,7 +35,7 @@ pub struct EnrollMeta {
/// The server organization ID this node is now a member of
pub organization_id: String,
/// The server organization name this node is now a member of
pub organization_name: String
pub organization_name: String,
}
impl Client {
@ -39,10 +43,12 @@ impl Client {
/// # Errors
/// This function will return an error if the reqwest Client could not be created.
pub fn new(user_agent: String, api_base: Url) -> Result<Self, Box<dyn Error>> {
let client = reqwest::blocking::Client::builder().user_agent(user_agent).build()?;
let client = reqwest::blocking::Client::builder()
.user_agent(user_agent)
.build()?;
Ok(Self {
http_client: client,
server_url: api_base
server_url: api_base,
})
}
@ -59,8 +65,14 @@ impl Client {
/// - the server returns an error
/// - the server returns invalid JSON
/// - the `trusted_keys` field is invalid
pub fn enroll(&self, code: &str) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> {
debug!("making enrollment request to API {{server: {}, code: {}}}", self.server_url, code);
pub fn enroll(
&self,
code: &str,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials, EnrollMeta), Box<dyn Error>> {
debug!(
"making enrollment request to API {{server: {}, code: {}}}",
self.server_url, code
);
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
@ -71,9 +83,18 @@ impl Client {
timestamp: Local::now().format("%Y-%m-%dT%H:%M:%S.%f%:z").to_string(),
})?;
let resp = self.http_client.post(self.server_url.join(ENROLL_ENDPOINT)?).body(req_json).send()?;
let resp = self
.http_client
.post(self.server_url.join(ENROLL_ENDPOINT)?)
.header("Content-Type", "application/json")
.body(req_json)
.send()?;
let req_id = resp.headers().get("X-Request-ID").ok_or("Response missing X-Request-ID")?.to_str()?;
let req_id = resp
.headers()
.get("X-Request-ID")
.ok_or("Response missing X-Request-ID")?
.to_str()?;
debug!("enrollment request complete {{req_id: {}}}", req_id);
let resp: EnrollResponse = resp.json()?;
@ -93,7 +114,6 @@ impl Client {
debug!("parsing public keys");
let trusted_keys = ed25519_public_keys_from_pem(&r.trusted_keys)?;
let creds = Credentials {
@ -110,7 +130,13 @@ impl Client {
/// # Errors
/// This function returns an error if the dnclient request fails, or the server returns invalid data.
pub fn check_for_update(&self, creds: &Credentials) -> Result<bool, Box<dyn Error>> {
let body = self.post_dnclient(CHECK_FOR_UPDATE, &[], &creds.host_id, creds.counter, &creds.ed_privkey)?;
let body = self.post_dnclient(
CHECK_FOR_UPDATE,
&[],
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)?;
let result: CheckForUpdateResponseWrapper = serde_json::from_slice(&body)?;
@ -128,7 +154,10 @@ impl Client {
/// - if the response could not be deserialized
/// - if the signature is invalid
/// - if the keys are invalid
pub fn do_update(&self, creds: &Credentials) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> {
pub fn do_update(
&self,
creds: &Credentials,
) -> Result<(NebulaConfig, DHPrivateKeyPEM, Credentials), Box<dyn Error>> {
let (dh_pubkey_pem, dh_privkey_pem, ed_pubkey, ed_privkey) = new_keys();
let update_keys = DoUpdateRequest {
@ -139,33 +168,51 @@ impl Client {
let update_keys_blob = serde_json::to_vec(&update_keys)?;
let resp = self.post_dnclient(DO_UPDATE, &update_keys_blob, &creds.host_id, creds.counter, &creds.ed_privkey)?;
let resp = self.post_dnclient(
DO_UPDATE,
&update_keys_blob,
&creds.host_id,
creds.counter,
&creds.ed_privkey,
)?;
let result_wrapper: SignedResponseWrapper = serde_json::from_slice(&resp)?;
let mut valid = false;
for ca_pubkey in &creds.trusted_keys {
if ca_pubkey.verify(&result_wrapper.data.message, &Signature::from_slice(&result_wrapper.data.signature)?).is_ok() {
if ca_pubkey
.verify(
&result_wrapper.data.message,
&Signature::from_slice(&result_wrapper.data.signature)?,
)
.is_ok()
{
valid = true;
break;
}
}
if !valid {
return Err("Failed to verify signed API result".into())
return Err("Failed to verify signed API result".into());
}
let result: DoUpdateResponse = serde_json::from_slice(&result_wrapper.data.message)?;
if result.nonce != update_keys.nonce {
error!("nonce mismatch between request {:x?} and response {:x?}", result.nonce, update_keys.nonce);
return Err("nonce mismatch between request and response".into())
error!(
"nonce mismatch between request {:x?} and response {:x?}",
result.nonce, update_keys.nonce
);
return Err("nonce mismatch between request and response".into());
}
if result.counter <= creds.counter {
error!("counter in request {} should be less than counter in response {}", creds.counter, result.counter);
return Err("received older config than what we already had".into())
error!(
"counter in request {} should be less than counter in response {}",
creds.counter, result.counter
);
return Err("received older config than what we already had".into());
}
let trusted_keys = ed25519_public_keys_from_pem(&result.trusted_keys)?;
@ -187,7 +234,14 @@ impl Client {
/// - serialization in any step fails
/// - if the `server_url` is invalid
/// - if the request could not be sent
pub fn post_dnclient(&self, req_type: &str, value: &[u8], host_id: &str, counter: u32, ed_privkey: &SigningKey) -> Result<Vec<u8>, Box<dyn Error>> {
pub fn post_dnclient(
&self,
req_type: &str,
value: &[u8],
host_id: &str,
counter: u32,
ed_privkey: &SigningKey,
) -> Result<Vec<u8>, Box<dyn Error>> {
let encoded_msg = serde_json::to_string(&RequestWrapper {
message_type: req_type.to_string(),
value: value.to_vec(),
@ -213,19 +267,22 @@ impl Client {
trace!("sending dnclient request {}", post_body);
let resp = self.http_client.post(self.server_url.join(ENDPOINT_V1)?).body(post_body).send()?;
let resp = self
.http_client
.post(self.server_url.join(ENDPOINT_V1)?)
.body(post_body)
.send()?;
match resp.status() {
StatusCode::OK => {
Ok(resp.bytes()?.to_vec())
},
StatusCode::FORBIDDEN => {
Err("Forbidden".into())
},
StatusCode::OK => Ok(resp.bytes()?.to_vec()),
StatusCode::FORBIDDEN => Err("Forbidden".into()),
_ => {
error!("dnclient endpoint returned bad status code {}", resp.status());
error!(
"dnclient endpoint returned bad status code {}",
resp.status()
);
Err("dnclient endpoint returned error".into())
}
}
}
}
}

View file

@ -1,9 +1,9 @@
//! Contains the `Credentials` struct, which contains all keys, IDs, organizations and other identity-related and security-related data that is persistent in a `Client`
use serde::{Deserialize, Serialize};
use std::error::Error;
use trifid_pki::cert::{deserialize_ed25519_public_many, serialize_ed25519_public};
use trifid_pki::ed25519_dalek::{SigningKey, VerifyingKey};
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize, Clone)]
/// Contains information necessary to make requests against the `DNClient` API.
@ -15,7 +15,7 @@ pub struct Credentials {
/// The counter used in the other API requests. It is unknown what the purpose of this is, but the original client persists it and it is needed for API calls.
pub counter: u32,
/// The set of trusted ed25519 keys that may be used by the API to sign API responses.
pub trusted_keys: Vec<VerifyingKey>
pub trusted_keys: Vec<VerifyingKey>,
}
/// Converts an array of `VerifyingKey`s to a singular bundle of PEM-encoded keys
@ -38,8 +38,10 @@ pub fn ed25519_public_keys_from_pem(pem: &[u8]) -> Result<Vec<VerifyingKey>, Box
#[allow(clippy::unwrap_used)]
for pem in pems {
keys.push(VerifyingKey::from_bytes(&pem.try_into().unwrap_or_else(|_| unreachable!()))?);
keys.push(VerifyingKey::from_bytes(
&pem.try_into().unwrap_or_else(|_| unreachable!()),
)?);
}
Ok(keys)
}
}

View file

@ -1,7 +1,7 @@
//! Functions for generating keys and nonces for use in API calls
use rand::Rng;
use rand::rngs::OsRng;
use rand::Rng;
use trifid_pki::cert::{serialize_x25519_private, serialize_x25519_public};
use trifid_pki::ed25519_dalek::{SigningKey, VerifyingKey};
use trifid_pki::x25519_dalek::{PublicKey, StaticSecret};
@ -38,4 +38,4 @@ pub fn new_ed25519_keypair() -> (VerifyingKey, SigningKey) {
/// Generates a 16-byte random nonce for use in API calls
pub fn nonce() -> [u8; 16] {
rand::thread_rng().gen()
}
}

View file

@ -17,8 +17,8 @@
pub mod message;
pub mod client_blocking;
pub mod client_async;
pub mod client_blocking;
pub mod credentials;
pub mod crypto;
pub mod crypto;

View file

@ -1,7 +1,7 @@
//! Models for interacting with the Defined Networking API.
use base64_serde::base64_serde_type;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// The version 1 `DNClient` API endpoint
pub const ENDPOINT_V1: &str = "/v1/dnclient";
@ -27,7 +27,7 @@ pub struct RequestV1 {
pub message: String,
#[serde(with = "Base64Standard")]
/// An ed25519 signature over the `message`, which can be verified with the host's previously enrolled ed25519 public key
pub signature: Vec<u8>
pub signature: Vec<u8>,
}
#[derive(Serialize, Deserialize)]
@ -45,14 +45,14 @@ pub struct RequestWrapper {
/// For example:
/// `2023-03-29T09:56:42.380006369-04:00`
/// would represent `29 March 03, 2023, 09:56:42.380006369 UTC-4`
pub timestamp: String
pub timestamp: String,
}
#[derive(Serialize, Deserialize)]
/// `SignedResponseWrapper` contains a response message and a signature to validate inside `data`.
pub struct SignedResponseWrapper {
/// The response data contained in this message
pub data: SignedResponse
pub data: SignedResponse,
}
#[derive(Serialize, Deserialize)]
@ -65,14 +65,14 @@ pub struct SignedResponse {
pub message: Vec<u8>,
#[serde(with = "Base64Standard")]
/// The ed25519 signature over the `message`
pub signature: Vec<u8>
pub signature: Vec<u8>,
}
#[derive(Serialize, Deserialize)]
/// `CheckForUpdateResponseWrapper` contains a response to `CheckForUpdate` inside "data."
pub struct CheckForUpdateResponseWrapper {
/// The response data contained in this message
pub data: CheckForUpdateResponse
pub data: CheckForUpdateResponse,
}
#[derive(Serialize, Deserialize)]
@ -80,7 +80,7 @@ pub struct CheckForUpdateResponseWrapper {
pub struct CheckForUpdateResponse {
#[serde(rename = "updateAvailable")]
/// Set to true if a config update is available
pub update_available: bool
pub update_available: bool,
}
#[derive(Serialize, Deserialize)]
@ -97,7 +97,7 @@ pub struct DoUpdateRequest {
#[serde(with = "Base64Standard")]
/// A randomized value used to uniquely identify this request.
/// The original client uses a randomized, 16-byte value here, which dnapi-rs replicates
pub nonce: Vec<u8>
pub nonce: Vec<u8>,
}
#[derive(Serialize, Deserialize)]
@ -114,13 +114,13 @@ pub struct DoUpdateResponse {
#[serde(rename = "trustedKeys")]
#[serde(with = "Base64Standard")]
/// A new set of trusted ed25519 keys that can be used by the server to sign messages.
pub trusted_keys: Vec<u8>
pub trusted_keys: Vec<u8>,
}
/// The REST enrollment endpoint
pub const ENROLL_ENDPOINT: &str = "/v2/enroll";
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug)]
/// `EnrollRequest` is issued to the `ENROLL_ENDPOINT` to enroll this `dnclient` with a dnapi organization
pub struct EnrollRequest {
/// The enrollment code given by the API server.
@ -138,10 +138,9 @@ pub struct EnrollRequest {
/// For example:
/// `2023-03-29T09:56:42.380006369-04:00`
/// would represent `29 March 03, 2023, 09:56:42.380006369 UTC-4`
pub timestamp: String
pub timestamp: String,
}
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
/// The response to an `EnrollRequest`
@ -149,13 +148,13 @@ pub enum EnrollResponse {
/// A successful enrollment, with a `data` field pointing to an `EnrollResponseData`
Success {
/// The response data from this response
data: EnrollResponseData
data: EnrollResponseData,
},
/// An unsuccessful enrollment, with an `errors` field pointing to an array of `APIError`s.
Error {
/// A list of `APIError`s that happened while trying to enroll. `APIErrors` is a type alias to `Vec<APIError>`
errors: APIErrors
}
errors: APIErrors,
},
}
#[derive(Serialize, Deserialize)]
@ -174,7 +173,7 @@ pub struct EnrollResponseData {
/// A new set of trusted ed25519 keys that can be used by the server to sign messages.
pub trusted_keys: Vec<u8>,
/// The organization data that this node is now a part of
pub organization: EnrollResponseDataOrg
pub organization: EnrollResponseDataOrg,
}
#[derive(Serialize, Deserialize)]
@ -183,7 +182,7 @@ pub struct EnrollResponseDataOrg {
/// The organization ID that this node is now a part of
pub id: String,
/// The name of the organization that this node is now a part of
pub name: String
pub name: String,
}
#[derive(Serialize, Deserialize)]
@ -194,8 +193,8 @@ pub struct APIError {
/// The human-readable error message
pub message: String,
/// An optional path to where the error occured
pub path: Option<String>
pub path: Option<String>,
}
/// A type alias to a array of `APIErrors`. Just for parity with dnapi.
pub type APIErrors = Vec<APIError>;
pub type APIErrors = Vec<APIError>;

View file

@ -1,6 +1,6 @@
[package]
name = "tfclient"
version = "0.1.7"
version = "0.1.8"
edition = "2021"
description = "An open-source reimplementation of a Defined Networking-compatible client"
license = "GPL-3.0-or-later"

View file

@ -1,19 +1,18 @@
use flate2::read::GzDecoder;
use reqwest::blocking::Response;
use reqwest::header::HeaderMap;
use std::fs;
use std::fs::{File, remove_file};
use std::fs::{remove_file, File};
use std::io::{Read, Write};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::{Command, Output};
use flate2::read::GzDecoder;
use reqwest::blocking::Response;
use reqwest::header::HeaderMap;
use tar::Archive;
#[derive(serde::Deserialize, Debug)]
struct GithubRelease {
name: String,
assets: Vec<GithubReleaseAsset>
assets: Vec<GithubReleaseAsset>,
}
#[derive(serde::Deserialize, Debug)]
@ -23,11 +22,18 @@ struct GithubUser {}
struct GithubReleaseAsset {
browser_download_url: String,
name: String,
size: i64
size: i64,
}
fn main() {
if Path::new(&format!("{}/{}", std::env::var("OUT_DIR").unwrap(), "noredownload")).exists() && std::env::var("TFBUILD_FORCE_REDOWNLOAD").is_err() {
if Path::new(&format!(
"{}/{}",
std::env::var("OUT_DIR").unwrap(),
"noredownload"
))
.exists()
&& std::env::var("TFBUILD_FORCE_REDOWNLOAD").is_err()
{
println!("noredownload exists and TFBUILD_FORCE_REDOWNLOAD is not set. Not redoing build process.");
return;
}
@ -38,15 +44,32 @@ fn main() {
let mut has_api_key = false;
if let Ok(api_key) = std::env::var("GH_API_KEY") {
headers.insert("Authorization", format!("Bearer {}", api_key).parse().unwrap());
headers.insert(
"Authorization",
format!("Bearer {}", api_key).parse().unwrap(),
);
has_api_key = true;
}
let client = reqwest::blocking::Client::builder().user_agent("curl/7.57.1").default_headers(headers).build().unwrap();
let client = reqwest::blocking::Client::builder()
.user_agent("curl/7.57.1")
.default_headers(headers)
.build()
.unwrap();
let resp: Response = client.get("https://api.github.com/repos/slackhq/nebula/releases/latest").send().unwrap();
let resp: Response = client
.get("https://api.github.com/repos/slackhq/nebula/releases/latest")
.send()
.unwrap();
if resp.headers().get("X-Ratelimit-Remaining").unwrap().to_str().unwrap() == "0" {
if resp
.headers()
.get("X-Ratelimit-Remaining")
.unwrap()
.to_str()
.unwrap()
== "0"
{
println!("You've been ratelimited from the GitHub API. Wait a while (1 hour)");
if !has_api_key {
println!("You can also set a GitHub API key with the environment variable GH_API_KEY, which will increase your ratelimit ( a lot )");
@ -54,7 +77,6 @@ fn main() {
panic!("Ratelimited");
}
let release: GithubRelease = resp.json().unwrap();
println!("[*] Fetching target triplet...");
@ -84,9 +106,16 @@ fn main() {
println!("[*] Embedding {} {}", target_file, release.name);
let download = release.assets.iter().find(|r| r.name == format!("{}.tar.gz", target_file)).expect("That architecture isn't avaliable :(");
let download = release
.assets
.iter()
.find(|r| r.name == format!("{}.tar.gz", target_file))
.expect("That architecture isn't avaliable :(");
println!("[*] Downloading {}.tar.gz ({}, {} bytes) from {}", target_file, target, download.size, download.browser_download_url);
println!(
"[*] Downloading {}.tar.gz ({}, {} bytes) from {}",
target_file, target, download.size, download.browser_download_url
);
let response = reqwest::blocking::get(&download.browser_download_url).unwrap();
let content = response.bytes().unwrap().to_vec();
@ -102,10 +131,14 @@ fn main() {
for entry in entries {
let mut entry = entry.unwrap();
if entry.path().unwrap() == Path::new("nebula") || entry.path().unwrap() == Path::new("nebula.exe") {
if entry.path().unwrap() == Path::new("nebula")
|| entry.path().unwrap() == Path::new("nebula.exe")
{
nebula_bin.reserve(entry.size() as usize);
entry.read_to_end(&mut nebula_bin).unwrap();
} else if entry.path().unwrap() == Path::new("nebula-cert") || entry.path().unwrap() == Path::new("nebula-cert.exe") {
} else if entry.path().unwrap() == Path::new("nebula-cert")
|| entry.path().unwrap() == Path::new("nebula-cert.exe")
{
nebula_cert_bin.reserve(entry.size() as usize);
entry.read_to_end(&mut nebula_cert_bin).unwrap();
} else if entry.path().unwrap() == Path::new("SHASUM256.txt") {
@ -121,18 +154,28 @@ fn main() {
panic!("[x] Release did not contain nebula_cert binary");
}
let mut nebula_file = File::create(format!("{}/nebula.bin", std::env::var("OUT_DIR").unwrap())).unwrap();
let mut nebula_file =
File::create(format!("{}/nebula.bin", std::env::var("OUT_DIR").unwrap())).unwrap();
nebula_file.write_all(&nebula_bin).unwrap();
codegen_version(&nebula_bin, "nebula.bin", "NEBULA");
let mut nebula_cert_file = File::create(format!("{}/nebula_cert.bin", std::env::var("OUT_DIR").unwrap())).unwrap();
let mut nebula_cert_file = File::create(format!(
"{}/nebula_cert.bin",
std::env::var("OUT_DIR").unwrap()
))
.unwrap();
nebula_cert_file.write_all(&nebula_cert_bin).unwrap();
codegen_version(&nebula_cert_bin, "nebula_cert.bin", "NEBULA_CERT");
// Indicate to cargo and ourselves that we have already downloaded and codegenned
File::create(format!("{}/{}", std::env::var("OUT_DIR").unwrap(), "noredownload")).unwrap();
File::create(format!(
"{}/{}",
std::env::var("OUT_DIR").unwrap(),
"noredownload"
))
.unwrap();
println!("cargo:rerun-if-changed=build.rs");
}
@ -149,7 +192,8 @@ fn codegen_version(bin: &[u8], fp: &str, name: &str) {
let code = format!("// This code was automatically @generated by build.rs. It should not be modified.\npub const {}_BIN: &[u8] = include_bytes!(concat!(env!(\"OUT_DIR\"), \"/{}\"));\npub const {}_VERSION: &str = \"{}\";", name, fp, name, version);
let mut file = File::create(format!("{}/{}.rs", std::env::var("OUT_DIR").unwrap(), fp)).unwrap();
let mut file =
File::create(format!("{}/{}.rs", std::env::var("OUT_DIR").unwrap(), fp)).unwrap();
file.write_all(code.as_bytes()).unwrap();
}

View file

@ -1,13 +1,9 @@
use std::fs;
use std::sync::mpsc::{Receiver, RecvError, TryRecvError};
use std::sync::mpsc::Receiver;
use dnapi_rs::client_blocking::Client;
use log::{error, info, warn};
use url::Url;
use dnapi_rs::client_blocking::Client;
use crate::config::{load_cdata, save_cdata, TFClientConfig};
use crate::daemon::ThreadMessageSender;
@ -18,10 +14,16 @@ pub enum APIWorkerMessage {
Shutdown,
Enroll { code: String },
Update,
Timer
Timer,
}
pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx: ThreadMessageSender, rx: Receiver<APIWorkerMessage>) {
pub fn apiworker_main(
_config: TFClientConfig,
instance: String,
url: String,
tx: ThreadMessageSender,
rx: Receiver<APIWorkerMessage>,
) {
let server = Url::parse(&url).unwrap();
let client = Client::new(format!("tfclient/{}", env!("CARGO_PKG_VERSION")), server).unwrap();
@ -33,7 +35,7 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
APIWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
break;
},
}
APIWorkerMessage::Timer | APIWorkerMessage::Update => {
info!("updating config");
let mut cdata = match load_cdata(&instance) {
@ -108,9 +110,13 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
};
cdata.creds = Some(creds);
cdata.dh_privkey = Some(dh_privkey.try_into().expect("32 != 32"));
cdata.dh_privkey = Some(dh_privkey);
match fs::write(get_nebulaconfig_file(&instance).expect("Unable to determine nebula config file location"), config) {
match fs::write(
get_nebulaconfig_file(&instance)
.expect("Unable to determine nebula config file location"),
config,
) {
Ok(_) => (),
Err(e) => {
error!("unable to save nebula config: {}", e);
@ -146,7 +152,7 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
return;
}
}
},
}
APIWorkerMessage::Enroll { code } => {
info!("recv on command socket: enroll {}", code);
let mut cdata = match load_cdata(&instance) {
@ -170,7 +176,11 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
}
};
match fs::write(get_nebulaconfig_file(&instance).expect("Unable to determine nebula config file location"), config) {
match fs::write(
get_nebulaconfig_file(&instance)
.expect("Unable to determine nebula config file location"),
config,
) {
Ok(_) => (),
Err(e) => {
error!("unable to save nebula config: {}", e);
@ -179,7 +189,7 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
}
cdata.creds = Some(creds);
cdata.dh_privkey = Some(dh_privkey.try_into().expect("32 != 32"));
cdata.dh_privkey = Some(dh_privkey);
cdata.meta = Some(meta);
// Save vardata
@ -204,11 +214,11 @@ pub fn apiworker_main(_config: TFClientConfig, instance: String, url: String, tx
}
}
}
},
}
Err(e) => {
error!("error on command socket: {}", e);
return;
}
}
}
}
}

View file

@ -1,33 +1,34 @@
use ipnet::{IpNet, Ipv4Net};
use std::collections::HashMap;
use std::error::Error;
use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4};
use ipnet::{IpNet, Ipv4Net};
use log::{debug, info};
use serde::{Deserialize, Serialize};
use dnapi_rs::client_blocking::EnrollMeta;
use dnapi_rs::credentials::Credentials;
use log::{debug, info};
use serde::{Deserialize, Serialize};
use crate::dirs::{get_cdata_dir, get_cdata_file, get_config_dir, get_config_file};
pub const DEFAULT_PORT: u16 = 8157;
fn default_port() -> u16 { DEFAULT_PORT }
fn default_port() -> u16 {
DEFAULT_PORT
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct TFClientConfig {
#[serde(default = "default_port")]
pub listen_port: u16,
#[serde(default = "bool_false")]
pub disable_automatic_config_updates: bool
pub disable_automatic_config_updates: bool,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct TFClientData {
pub dh_privkey: Option<Vec<u8>>,
pub creds: Option<Credentials>,
pub meta: Option<EnrollMeta>
pub meta: Option<EnrollMeta>,
}
pub fn create_config(instance: &str) -> Result<(), Box<dyn Error>> {
@ -39,7 +40,10 @@ pub fn create_config(instance: &str) -> Result<(), Box<dyn Error>> {
disable_automatic_config_updates: false,
};
let config_str = toml::to_string(&config)?;
fs::write(get_config_file(instance).ok_or("Unable to load config dir")?, config_str)?;
fs::write(
get_config_file(instance).ok_or("Unable to load config dir")?,
config_str,
)?;
Ok(())
}
@ -63,9 +67,16 @@ pub fn create_cdata(instance: &str) -> Result<(), Box<dyn Error>> {
info!("Creating data directory...");
fs::create_dir_all(get_cdata_dir(instance).ok_or("Unable to load data dir")?)?;
info!("Copying default data file to config directory...");
let config = TFClientData { dh_privkey: None, creds: None, meta: None };
let config = TFClientData {
dh_privkey: None,
creds: None,
meta: None,
};
let config_str = toml::to_string(&config)?;
fs::write(get_cdata_file(instance).ok_or("Unable to load data dir")?, config_str)?;
fs::write(
get_cdata_file(instance).ok_or("Unable to load data dir")?,
config_str,
)?;
Ok(())
}
@ -141,7 +152,7 @@ pub struct NebulaConfig {
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub local_range: Option<Ipv4Net>
pub local_range: Option<Ipv4Net>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -156,7 +167,7 @@ pub struct NebulaConfigPki {
pub blocklist: Vec<String>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disconnect_invalid: bool
pub disconnect_invalid: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -190,7 +201,7 @@ pub struct NebulaConfigLighthouseDns {
pub host: String,
#[serde(default = "u16_53")]
#[serde(skip_serializing_if = "is_u16_53")]
pub port: u16
pub port: u16,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -207,7 +218,7 @@ pub struct NebulaConfigListen {
#[serde(skip_serializing_if = "is_none")]
pub read_buffer: Option<u32>,
#[serde(skip_serializing_if = "is_none")]
pub write_buffer: Option<u32>
pub write_buffer: Option<u32>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -220,7 +231,7 @@ pub struct NebulaConfigPunchy {
pub respond: bool,
#[serde(default = "string_1s")]
#[serde(skip_serializing_if = "is_string_1s")]
pub delay: String
pub delay: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -228,7 +239,7 @@ pub enum NebulaConfigCipher {
#[serde(rename = "aes")]
Aes,
#[serde(rename = "chachapoly")]
ChaChaPoly
ChaChaPoly,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -241,7 +252,7 @@ pub struct NebulaConfigRelay {
pub am_relay: bool,
#[serde(default = "bool_true")]
#[serde(skip_serializing_if = "is_bool_true")]
pub use_relays: bool
pub use_relays: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -268,13 +279,13 @@ pub struct NebulaConfigTun {
pub routes: Vec<NebulaConfigTunRouteOverride>,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub unsafe_routes: Vec<NebulaConfigTunUnsafeRoute>
pub unsafe_routes: Vec<NebulaConfigTunUnsafeRoute>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunRouteOverride {
pub mtu: u64,
pub route: Ipv4Net
pub route: Ipv4Net,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -286,7 +297,7 @@ pub struct NebulaConfigTunUnsafeRoute {
pub mtu: u64,
#[serde(default = "i64_100")]
#[serde(skip_serializing_if = "is_i64_100")]
pub metric: i64
pub metric: i64,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -302,7 +313,7 @@ pub struct NebulaConfigLogging {
pub disable_timestamp: bool,
#[serde(default = "timestamp")]
#[serde(skip_serializing_if = "is_timestamp")]
pub timestamp_format: String
pub timestamp_format: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -318,7 +329,7 @@ pub enum NebulaConfigLoggingLevel {
#[serde(rename = "info")]
Info,
#[serde(rename = "debug")]
Debug
Debug,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -326,7 +337,7 @@ pub enum NebulaConfigLoggingFormat {
#[serde(rename = "json")]
Json,
#[serde(rename = "text")]
Text
Text,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -338,7 +349,7 @@ pub struct NebulaConfigSshd {
pub host_key: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub authorized_users: Vec<NebulaConfigSshdAuthorizedUser>
pub authorized_users: Vec<NebulaConfigSshdAuthorizedUser>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -346,7 +357,7 @@ pub struct NebulaConfigSshdAuthorizedUser {
pub user: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub keys: Vec<String>
pub keys: Vec<String>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -355,7 +366,7 @@ pub enum NebulaConfigStats {
#[serde(rename = "graphite")]
Graphite(NebulaConfigStatsGraphite),
#[serde(rename = "prometheus")]
Prometheus(NebulaConfigStatsPrometheus)
Prometheus(NebulaConfigStatsPrometheus),
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -373,7 +384,7 @@ pub struct NebulaConfigStatsGraphite {
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -381,7 +392,7 @@ pub enum NebulaConfigStatsGraphiteProtocol {
#[serde(rename = "tcp")]
Tcp,
#[serde(rename = "udp")]
Udp
Udp,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -400,7 +411,7 @@ pub struct NebulaConfigStatsPrometheus {
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -428,7 +439,7 @@ pub struct NebulaConfigFirewallConntrack {
pub udp_timeout: String,
#[serde(default = "string_10m")]
#[serde(skip_serializing_if = "is_string_10m")]
pub default_timeout: String
pub default_timeout: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
@ -456,82 +467,175 @@ pub struct NebulaConfigFirewallRule {
pub groups: Option<Vec<String>>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub cidr: Option<String>
pub cidr: Option<String>,
}
// Default values for serde
fn string_12m() -> String { "12m".to_string() }
fn is_string_12m(s: &str) -> bool { s == "12m" }
fn string_12m() -> String {
"12m".to_string()
}
fn is_string_12m(s: &str) -> bool {
s == "12m"
}
fn string_3m() -> String { "3m".to_string() }
fn is_string_3m(s: &str) -> bool { s == "3m" }
fn string_3m() -> String {
"3m".to_string()
}
fn is_string_3m(s: &str) -> bool {
s == "3m"
}
fn string_10m() -> String { "10m".to_string() }
fn is_string_10m(s: &str) -> bool { s == "10m" }
fn string_10m() -> String {
"10m".to_string()
}
fn is_string_10m(s: &str) -> bool {
s == "10m"
}
fn empty_vec<T>() -> Vec<T> { vec![] }
fn is_empty_vec<T>(v: &Vec<T>) -> bool { v.is_empty() }
fn empty_vec<T>() -> Vec<T> {
vec![]
}
fn is_empty_vec<T>(v: &Vec<T>) -> bool {
v.is_empty()
}
fn empty_hashmap<A, B>() -> HashMap<A, B> { HashMap::new() }
fn is_empty_hashmap<A, B>(h: &HashMap<A, B>) -> bool { h.is_empty() }
fn empty_hashmap<A, B>() -> HashMap<A, B> {
HashMap::new()
}
fn is_empty_hashmap<A, B>(h: &HashMap<A, B>) -> bool {
h.is_empty()
}
fn bool_false() -> bool { false }
fn is_bool_false(b: &bool) -> bool { !*b }
fn bool_false() -> bool {
false
}
fn is_bool_false(b: &bool) -> bool {
!*b
}
fn bool_true() -> bool { true }
fn is_bool_true(b: &bool) -> bool { *b }
fn bool_true() -> bool {
true
}
fn is_bool_true(b: &bool) -> bool {
*b
}
fn u16_53() -> u16 { 53 }
fn is_u16_53(u: &u16) -> bool { *u == 53 }
fn u16_53() -> u16 {
53
}
fn is_u16_53(u: &u16) -> bool {
*u == 53
}
fn u32_10() -> u32 { 10 }
fn is_u32_10(u: &u32) -> bool { *u == 10 }
fn u32_10() -> u32 {
10
}
fn is_u32_10(u: &u32) -> bool {
*u == 10
}
fn ipv4_0000() -> Ipv4Addr { Ipv4Addr::new(0, 0, 0, 0) }
fn is_ipv4_0000(i: &Ipv4Addr) -> bool { *i == ipv4_0000() }
fn u16_0() -> u16 {
0
}
fn is_u16_0(u: &u16) -> bool {
*u == 0
}
fn u16_0() -> u16 { 0 }
fn is_u16_0(u: &u16) -> bool { *u == 0 }
fn u32_64() -> u32 {
64
}
fn is_u32_64(u: &u32) -> bool {
*u == 64
}
fn u32_64() -> u32 { 64 }
fn is_u32_64(u: &u32) -> bool { *u == 64 }
fn string_1s() -> String {
"1s".to_string()
}
fn is_string_1s(s: &str) -> bool {
s == "1s"
}
fn string_1s() -> String { "1s".to_string() }
fn is_string_1s(s: &str) -> bool { s == "1s" }
fn cipher_aes() -> NebulaConfigCipher {
NebulaConfigCipher::Aes
}
fn is_cipher_aes(c: &NebulaConfigCipher) -> bool {
matches!(c, NebulaConfigCipher::Aes)
}
fn cipher_aes() -> NebulaConfigCipher { NebulaConfigCipher::Aes }
fn is_cipher_aes(c: &NebulaConfigCipher) -> bool { matches!(c, NebulaConfigCipher::Aes) }
fn u64_500() -> u64 {
500
}
fn is_u64_500(u: &u64) -> bool {
*u == 500
}
fn u64_500() -> u64 { 500 }
fn is_u64_500(u: &u64) -> bool { *u == 500 }
fn u64_1300() -> u64 {
1300
}
fn is_u64_1300(u: &u64) -> bool {
*u == 1300
}
fn u64_1300() -> u64 { 1300 }
fn is_u64_1300(u: &u64) -> bool { *u == 1300 }
fn i64_100() -> i64 {
100
}
fn is_i64_100(i: &i64) -> bool {
*i == 100
}
fn i64_100() -> i64 { 100 }
fn is_i64_100(i: &i64) -> bool { *i == 100 }
fn loglevel_info() -> NebulaConfigLoggingLevel {
NebulaConfigLoggingLevel::Info
}
fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool {
matches!(l, NebulaConfigLoggingLevel::Info)
}
fn loglevel_info() -> NebulaConfigLoggingLevel { NebulaConfigLoggingLevel::Info }
fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool { matches!(l, NebulaConfigLoggingLevel::Info) }
fn format_text() -> NebulaConfigLoggingFormat {
NebulaConfigLoggingFormat::Text
}
fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool {
matches!(f, NebulaConfigLoggingFormat::Text)
}
fn format_text() -> NebulaConfigLoggingFormat { NebulaConfigLoggingFormat::Text }
fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool { matches!(f, NebulaConfigLoggingFormat::Text) }
fn timestamp() -> String {
"2006-01-02T15:04:05Z07:00".to_string()
}
fn is_timestamp(s: &str) -> bool {
s == "2006-01-02T15:04:05Z07:00"
}
fn timestamp() -> String { "2006-01-02T15:04:05Z07:00".to_string() }
fn is_timestamp(s: &str) -> bool { s == "2006-01-02T15:04:05Z07:00" }
fn u64_1() -> u64 {
1
}
fn is_u64_1(u: &u64) -> bool {
*u == 1
}
fn u64_1() -> u64 { 1 }
fn is_u64_1(u: &u64) -> bool { *u == 1 }
fn string_nebula() -> String {
"nebula".to_string()
}
fn is_string_nebula(s: &str) -> bool {
s == "nebula"
}
fn string_nebula() -> String { "nebula".to_string() }
fn is_string_nebula(s: &str) -> bool { s == "nebula" }
fn string_empty() -> String {
String::new()
}
fn is_string_empty(s: &str) -> bool {
s.is_empty()
}
fn string_empty() -> String { String::new() }
fn is_string_empty(s: &str) -> bool { s == "" }
fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol {
NebulaConfigStatsGraphiteProtocol::Tcp
}
fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool {
matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp)
}
fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol { NebulaConfigStatsGraphiteProtocol::Tcp }
fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool { matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp) }
fn none<T>() -> Option<T> { None }
fn is_none<T>(o: &Option<T>) -> bool { o.is_none() }
fn none<T>() -> Option<T> {
None
}
fn is_none<T>(o: &Option<T>) -> bool {
o.is_none()
}

View file

@ -1,7 +1,7 @@
use log::{error, info};
use std::sync::mpsc;
use std::sync::mpsc::Sender;
use std::thread;
use log::{error, info};
use crate::apiworker::{apiworker_main, APIWorkerMessage};
use crate::config::load_config;
@ -44,28 +44,49 @@ pub fn daemon_main(name: String, server: String) {
match ctrlc::set_handler(move || {
info!("Ctrl-C detected. Stopping threads...");
match mainthread_transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) {
match mainthread_transmitter
.nebula_thread
.send(NebulaWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to nebula worker thread: {}", e);
error!(
"Error sending shutdown message to nebula worker thread: {}",
e
);
}
}
match mainthread_transmitter.api_thread.send(APIWorkerMessage::Shutdown) {
match mainthread_transmitter
.api_thread
.send(APIWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to api worker thread: {}", e);
}
}
match mainthread_transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) {
match mainthread_transmitter
.socket_thread
.send(SocketWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to socket worker thread: {}", e);
error!(
"Error sending shutdown message to socket worker thread: {}",
e
);
}
}
match mainthread_transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) {
match mainthread_transmitter
.timer_thread
.send(TimerWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to timer worker thread: {}", e);
error!(
"Error sending shutdown message to timer worker thread: {}",
e
);
}
}
}) {
@ -81,21 +102,21 @@ pub fn daemon_main(name: String, server: String) {
let config_api = config.clone();
let transmitter_api = transmitter.clone();
let name_api = name.clone();
let server_api = server.clone();
let server_api = server;
let api_thread = thread::spawn(move || {
apiworker_main(config_api, name_api, server_api,transmitter_api, rx_api);
apiworker_main(config_api, name_api, server_api, transmitter_api, rx_api);
});
info!("Starting Nebula thread...");
let config_nebula = config.clone();
let transmitter_nebula = transmitter.clone();
let name_nebula = name.clone();
let nebula_thread = thread::spawn(move || {
nebulaworker_main(config_nebula, name_nebula, transmitter_nebula, rx_nebula);
});
//let nebula_thread = thread::spawn(move || {
// nebulaworker_main(config_nebula, name_nebula, transmitter_nebula, rx_nebula);
//});
info!("Starting socket worker thread...");
let name_socket = name.clone();
let name_socket = name;
let config_socket = config.clone();
let tx_socket = transmitter.clone();
let socket_thread = thread::spawn(move || {
@ -104,7 +125,7 @@ pub fn daemon_main(name: String, server: String) {
info!("Starting timer thread...");
if !config.disable_automatic_config_updates {
let timer_transmitter = transmitter.clone();
let timer_transmitter = transmitter;
let timer_thread = thread::spawn(move || {
timer_main(timer_transmitter, rx_timer);
});
@ -142,13 +163,13 @@ pub fn daemon_main(name: String, server: String) {
info!("API thread exited");
info!("Waiting for Nebula thread to exit...");
match nebula_thread.join() {
Ok(_) => (),
Err(_) => {
error!("Error waiting for nebula thread to exit.");
std::process::exit(1);
}
}
//match nebula_thread.join() {
// Ok(_) => (),
// Err(_) => {
// error!("Error waiting for nebula thread to exit.");
// std::process::exit(1);
// }
//}
info!("Nebula thread exited");
info!("All threads exited");
@ -159,5 +180,5 @@ pub struct ThreadMessageSender {
pub socket_thread: Sender<SocketWorkerMessage>,
pub api_thread: Sender<APIWorkerMessage>,
pub nebula_thread: Sender<NebulaWorkerMessage>,
pub timer_thread: Sender<TimerWorkerMessage>
}
pub timer_thread: Sender<TimerWorkerMessage>,
}

View file

@ -22,4 +22,4 @@ pub fn get_cdata_file(instance: &str) -> Option<PathBuf> {
pub fn get_nebulaconfig_file(instance: &str) -> Option<PathBuf> {
get_cdata_dir(instance).map(|f| f.join("nebula.sk_embedded.yml"))
}
}

View file

@ -1,3 +1,6 @@
use crate::dirs::get_data_dir;
use crate::util::sha256;
use log::debug;
use std::error::Error;
use std::fs;
use std::fs::File;
@ -5,9 +8,6 @@ use std::io::Write;
use std::os::unix::fs::PermissionsExt;
use std::path::PathBuf;
use std::process::{Child, Command};
use log::debug;
use crate::dirs::get_data_dir;
use crate::util::sha256;
pub fn extract_embedded_nebula() -> Result<PathBuf, Box<dyn Error>> {
let data_dir = get_data_dir().ok_or("Unable to get platform-specific data dir")?;
@ -25,7 +25,11 @@ pub fn extract_embedded_nebula() -> Result<PathBuf, Box<dyn Error>> {
}
let executable_postfix = if cfg!(windows) { ".exe" } else { "" };
let executable_name = format!("nebula-{}{}", crate::nebula_bin::NEBULA_VERSION, executable_postfix);
let executable_name = format!(
"nebula-{}{}",
crate::nebula_bin::NEBULA_VERSION,
executable_postfix
);
let file_path = hash_dir.join(executable_name);
@ -49,7 +53,10 @@ pub fn extract_embedded_nebula_cert() -> Result<PathBuf, Box<dyn Error>> {
}
let bin_dir = data_dir.join("cache/");
let hash_dir = bin_dir.join(format!("{}/", sha256(crate::nebula_cert_bin::NEBULA_CERT_BIN)));
let hash_dir = bin_dir.join(format!(
"{}/",
sha256(crate::nebula_cert_bin::NEBULA_CERT_BIN)
));
if !hash_dir.exists() {
fs::create_dir_all(&hash_dir)?;
@ -57,7 +64,11 @@ pub fn extract_embedded_nebula_cert() -> Result<PathBuf, Box<dyn Error>> {
}
let executable_postfix = if cfg!(windows) { ".exe" } else { "" };
let executable_name = format!("nebula-cert-{}{}", crate::nebula_cert_bin::NEBULA_CERT_VERSION, executable_postfix);
let executable_name = format!(
"nebula-cert-{}{}",
crate::nebula_cert_bin::NEBULA_CERT_VERSION,
executable_postfix
);
let file_path = hash_dir.join(executable_name);
@ -101,4 +112,4 @@ pub fn run_embedded_nebula_cert(args: &[String]) -> Result<Child, Box<dyn Error>
debug!("Running {} with args {:?}", path.as_path().display(), args);
_setup_permissions(&path)?;
Ok(Command::new(path).args(args).spawn()?)
}
}

View file

@ -14,16 +14,16 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
pub mod embedded_nebula;
pub mod dirs;
pub mod util;
pub mod nebulaworker;
pub mod daemon;
pub mod config;
pub mod apiworker;
pub mod socketworker;
pub mod config;
pub mod daemon;
pub mod dirs;
pub mod embedded_nebula;
pub mod nebulaworker;
pub mod socketclient;
pub mod socketworker;
pub mod timerworker;
pub mod util;
pub mod nebula_bin {
include!(concat!(env!("OUT_DIR"), "/nebula.bin.rs"));
@ -32,15 +32,14 @@ pub mod nebula_cert_bin {
include!(concat!(env!("OUT_DIR"), "/nebula_cert.bin.rs"));
}
use std::fs;
use clap::{Parser, ArgAction, Subcommand};
use log::{error, info};
use simple_logger::SimpleLogger;
use crate::config::load_config;
use crate::dirs::get_data_dir;
use crate::embedded_nebula::{run_embedded_nebula, run_embedded_nebula_cert};
use clap::{ArgAction, Parser, Subcommand};
use log::{error, info};
use simple_logger::SimpleLogger;
#[derive(Parser)]
#[command(author = "c0repwn3r", version, about, long_about = None)]
@ -52,7 +51,7 @@ struct Cli {
version: bool,
#[command(subcommand)]
subcommand: Commands
subcommand: Commands,
}
#[derive(Subcommand)]
@ -60,14 +59,14 @@ enum Commands {
/// Run the `nebula` binary. This is useful if you want to do debugging with tfclient's internal nebula.
RunNebula {
/// Arguments to pass to the `nebula` binary
#[clap(trailing_var_arg=true, allow_hyphen_values=true)]
args: Vec<String>
#[clap(trailing_var_arg = true, allow_hyphen_values = true)]
args: Vec<String>,
},
/// Run the `nebula-cert` binary. This is useful if you want to mess with certificates. Note: tfclient does not actually use nebula-cert for certificate operations, and instead uses trifid-pki internally
RunNebulaCert {
/// Arguments to pass to the `nebula-cert` binary
#[clap(trailing_var_arg=true, allow_hyphen_values=true)]
args: Vec<String>
#[clap(trailing_var_arg = true, allow_hyphen_values = true)]
args: Vec<String>,
},
/// Clear any cached data that tfclient may have added
ClearCache {},
@ -79,7 +78,7 @@ enum Commands {
name: String,
#[clap(short, long)]
/// Server to use for API calls.
server: String
server: String,
},
/// Enroll this host using a trifid-api enrollment code
@ -97,7 +96,7 @@ enum Commands {
#[clap(short, long, default_value = "tfclient")]
/// Service name specified on install
name: String,
}
},
}
fn main() {
@ -110,34 +109,28 @@ fn main() {
}
match args.subcommand {
Commands::RunNebula { args } => {
match run_embedded_nebula(&args) {
Ok(mut c) => {
match c.wait() {
Ok(stat) => {
match stat.code() {
Some(code) => {
if code != 0 {
error!("Nebula process exited with nonzero status code {}", code);
}
std::process::exit(code);
},
None => {
info!("Nebula process terminated by signal");
std::process::exit(0);
}
}
},
Err(e) => {
error!("Unable to wait for child to exit: {}", e);
std::process::exit(1);
Commands::RunNebula { args } => match run_embedded_nebula(&args) {
Ok(mut c) => match c.wait() {
Ok(stat) => match stat.code() {
Some(code) => {
if code != 0 {
error!("Nebula process exited with nonzero status code {}", code);
}
std::process::exit(code);
}
None => {
info!("Nebula process terminated by signal");
std::process::exit(0);
}
},
Err(e) => {
error!("Unable to start nebula binary: {}", e);
error!("Unable to wait for child to exit: {}", e);
std::process::exit(1);
}
},
Err(e) => {
error!("Unable to start nebula binary: {}", e);
std::process::exit(1);
}
},
Commands::ClearCache { .. } => {
@ -159,37 +152,34 @@ fn main() {
info!("Removed all cached data.");
std::process::exit(0);
},
Commands::RunNebulaCert { args } => {
match run_embedded_nebula_cert(&args) {
Ok(mut c) => {
match c.wait() {
Ok(stat) => {
match stat.code() {
Some(code) => {
if code != 0 {
error!("nebula-cert process exited with nonzero status code {}", code);
}
std::process::exit(code);
},
None => {
info!("nebula-cert process terminated by signal");
std::process::exit(0);
}
}
},
Err(e) => {
error!("Unable to wait for child to exit: {}", e);
std::process::exit(1);
}
Commands::RunNebulaCert { args } => match run_embedded_nebula_cert(&args) {
Ok(mut c) => match c.wait() {
Ok(stat) => match stat.code() {
Some(code) => {
if code != 0 {
error!(
"nebula-cert process exited with nonzero status code {}",
code
);
}
std::process::exit(code);
}
None => {
info!("nebula-cert process terminated by signal");
std::process::exit(0);
}
},
Err(e) => {
error!("Unable to start nebula-cert binary: {}", e);
error!("Unable to wait for child to exit: {}", e);
std::process::exit(1);
}
},
Err(e) => {
error!("Unable to start nebula-cert binary: {}", e);
std::process::exit(1);
}
}
},
Commands::Run { name, server } => {
daemon::daemon_main(name, server);
}
@ -209,7 +199,7 @@ fn main() {
std::process::exit(1);
}
};
},
}
Commands::Update { name } => {
info!("Loading config...");
let config = match load_config(&name) {
@ -231,5 +221,11 @@ fn main() {
}
fn print_version() {
println!("tfclient v{} linked to trifid-pki v{}, embedding nebula v{} and nebula-cert v{}", env!("CARGO_PKG_VERSION"), trifid_pki::TRIFID_PKI_VERSION, crate::nebula_bin::NEBULA_VERSION, crate::nebula_cert_bin::NEBULA_CERT_VERSION);
}
println!(
"tfclient v{} linked to trifid-pki v{}, embedding nebula v{} and nebula-cert v{}",
env!("CARGO_PKG_VERSION"),
trifid_pki::TRIFID_PKI_VERSION,
crate::nebula_bin::NEBULA_VERSION,
crate::nebula_cert_bin::NEBULA_CERT_VERSION
);
}

View file

@ -1,29 +1,34 @@
// Code to handle the nebula worker
use std::error::Error;
use std::fs;
use std::sync::mpsc::{Receiver, TryRecvError};
use std::time::{Duration, SystemTime};
use log::{debug, error, info};
use crate::config::{load_cdata, NebulaConfig, TFClientConfig};
use crate::daemon::ThreadMessageSender;
use crate::dirs::get_nebulaconfig_file;
use crate::embedded_nebula::run_embedded_nebula;
use log::{debug, error, info};
use std::error::Error;
use std::fs;
use std::sync::mpsc::Receiver;
use std::time::{Duration, SystemTime};
pub enum NebulaWorkerMessage {
Shutdown,
ConfigUpdated,
WakeUp
WakeUp,
}
fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
if !get_nebulaconfig_file(instance).ok_or("Could not get config file location")?.exists() {
if !get_nebulaconfig_file(instance)
.ok_or("Could not get config file location")?
.exists()
{
return Ok(()); // cant insert private key into a file that does not exist - BUT. we can gracefully handle nebula crashing - we cannot gracefully handle this fn failing
}
let cdata = load_cdata(instance)?;
let key = cdata.dh_privkey.ok_or("Missing private key")?;
let config_str = fs::read_to_string(get_nebulaconfig_file(instance).ok_or("Could not get config file location")?)?;
let config_str = fs::read_to_string(
get_nebulaconfig_file(instance).ok_or("Could not get config file location")?,
)?;
let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?;
config.pki.key = Some(String::from_utf8(key)?);
@ -31,12 +36,20 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
debug!("inserted private key into config: {:?}", config);
let config_str = serde_yaml::to_string(&config)?;
fs::write(get_nebulaconfig_file(instance).ok_or("Could not get config file location")?, config_str)?;
fs::write(
get_nebulaconfig_file(instance).ok_or("Could not get config file location")?,
config_str,
)?;
Ok(())
}
pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter: ThreadMessageSender, rx: Receiver<NebulaWorkerMessage>) {
pub fn nebulaworker_main(
_config: TFClientConfig,
instance: String,
_transmitter: ThreadMessageSender,
rx: Receiver<NebulaWorkerMessage>,
) {
let _cdata = match load_cdata(&instance) {
Ok(data) => data,
Err(e) => {
@ -50,7 +63,7 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
match insert_private_key(&instance) {
Ok(_) => {
info!("config fixed (private-key embedded)");
},
}
Err(e) => {
error!("unable to fix config: {}", e);
error!("nebula thread exiting with error");
@ -58,7 +71,14 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
}
}
info!("starting nebula child...");
let mut child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) {
let mut child = match run_embedded_nebula(&[
"-config".to_string(),
get_nebulaconfig_file(&instance)
.unwrap()
.to_str()
.unwrap()
.to_string(),
]) {
Ok(c) => c,
Err(e) => {
error!("unable to start embedded nebula binary: {}", e);
@ -75,7 +95,14 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
if let Ok(e) = child.try_wait() {
if e.is_some() && SystemTime::now() > last_restart_time + Duration::from_secs(5) {
info!("nebula process has exited, restarting");
child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) {
child = match run_embedded_nebula(&[
"-config".to_string(),
get_nebulaconfig_file(&instance)
.unwrap()
.to_str()
.unwrap()
.to_string(),
]) {
Ok(c) => c,
Err(e) => {
error!("unable to start embedded nebula binary: {}", e);
@ -88,59 +115,64 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
}
}
match rx.recv() {
Ok(msg) => {
match msg {
NebulaWorkerMessage::WakeUp => {
continue;
},
NebulaWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
info!("shutting down nebula binary");
match child.kill() {
Ok(_) => {
debug!("nebula process exited");
},
Err(e) => {
error!("nebula process already exited: {}", e);
}
Ok(msg) => match msg {
NebulaWorkerMessage::WakeUp => {
continue;
}
NebulaWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
info!("shutting down nebula binary");
match child.kill() {
Ok(_) => {
debug!("nebula process exited");
}
info!("nebula shut down");
break;
},
NebulaWorkerMessage::ConfigUpdated => {
info!("our configuration has been updated - restarting");
debug!("killing existing process");
match child.kill() {
Ok(_) => {
debug!("nebula process exited");
},
Err(e) => {
error!("nebula process already exited: {}", e);
}
Err(e) => {
error!("nebula process already exited: {}", e);
}
debug!("fixing config...");
match insert_private_key(&instance) {
Ok(_) => {
debug!("config fixed (private-key embedded)");
},
Err(e) => {
error!("unable to fix config: {}", e);
error!("nebula thread exiting with error");
return;
}
}
debug!("restarting nebula process");
child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) {
Ok(c) => c,
Err(e) => {
error!("unable to start embedded nebula binary: {}", e);
error!("nebula thread exiting with error");
return;
}
};
last_restart_time = SystemTime::now();
debug!("nebula process restarted");
}
info!("nebula shut down");
break;
}
NebulaWorkerMessage::ConfigUpdated => {
info!("our configuration has been updated - restarting");
debug!("killing existing process");
match child.kill() {
Ok(_) => {
debug!("nebula process exited");
}
Err(e) => {
error!("nebula process already exited: {}", e);
}
}
debug!("fixing config...");
match insert_private_key(&instance) {
Ok(_) => {
debug!("config fixed (private-key embedded)");
}
Err(e) => {
error!("unable to fix config: {}", e);
error!("nebula thread exiting with error");
return;
}
}
debug!("restarting nebula process");
child = match run_embedded_nebula(&[
"-config".to_string(),
get_nebulaconfig_file(&instance)
.unwrap()
.to_str()
.unwrap()
.to_string(),
]) {
Ok(c) => c,
Err(e) => {
error!("unable to start embedded nebula binary: {}", e);
error!("nebula thread exiting with error");
return;
}
};
last_restart_time = SystemTime::now();
debug!("nebula process restarted");
}
},
Err(e) => {
@ -149,4 +181,4 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter
}
}
}
}
}

View file

@ -1,13 +1,16 @@
use crate::config::TFClientConfig;
use crate::socketworker::{ctob, DisconnectReason, JsonMessage, JSON_API_VERSION};
use log::{error, info};
use std::error::Error;
use std::io::{BufRead, BufReader, Write};
use std::net::{IpAddr, SocketAddr, TcpStream};
use log::{error, info};
use crate::config::TFClientConfig;
use crate::socketworker::{ctob, DisconnectReason, JSON_API_VERSION, JsonMessage};
pub fn enroll(code: &str, config: &TFClientConfig) -> Result<(), Box<dyn Error>> {
info!("Connecting to local command socket...");
let mut stream = TcpStream::connect(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?;
let mut stream = TcpStream::connect(SocketAddr::new(
IpAddr::from([127, 0, 0, 1]),
config.listen_port,
))?;
let stream2 = stream.try_clone()?;
let mut reader = BufReader::new(&stream2);
@ -52,7 +55,10 @@ pub fn enroll(code: &str, config: &TFClientConfig) -> Result<(), Box<dyn Error>>
pub fn update(config: &TFClientConfig) -> Result<(), Box<dyn Error>> {
info!("Connecting to local command socket...");
let mut stream = TcpStream::connect(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?;
let mut stream = TcpStream::connect(SocketAddr::new(
IpAddr::from([127, 0, 0, 1]),
config.listen_port,
))?;
let stream2 = stream.try_clone()?;
let mut reader = BufReader::new(&stream2);
@ -98,4 +104,4 @@ fn read_msg(reader: &mut BufReader<&TcpStream>) -> Result<JsonMessage, Box<dyn E
reader.read_line(&mut str)?;
let msg: JsonMessage = serde_json::from_str(&str)?;
Ok(msg)
}
}

View file

@ -1,25 +1,30 @@
// Code to handle the nebula worker
use std::error::Error;
use std::{io, thread};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::io::{BufRead, BufReader, Write};
use std::net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream};
use std::sync::mpsc::{Receiver, TryRecvError};
use std::sync::mpsc::Receiver;
use std::{io, thread};
use log::{debug, error, info, trace, warn};
use serde::{Deserialize, Serialize};
use crate::apiworker::APIWorkerMessage;
use crate::config::{load_cdata, TFClientConfig};
use crate::daemon::ThreadMessageSender;
use crate::nebulaworker::NebulaWorkerMessage;
use crate::timerworker::TimerWorkerMessage;
use log::{debug, error, info, trace, warn};
use serde::{Deserialize, Serialize};
pub enum SocketWorkerMessage {
Shutdown,
WakeUp
WakeUp,
}
pub fn socketworker_main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSender, rx: Receiver<SocketWorkerMessage>) {
pub fn socketworker_main(
config: TFClientConfig,
instance: String,
transmitter: ThreadMessageSender,
rx: Receiver<SocketWorkerMessage>,
) {
info!("socketworker_main called, entering realmain");
match _main(config, instance, transmitter, rx) {
Ok(_) => (),
@ -29,8 +34,16 @@ pub fn socketworker_main(config: TFClientConfig, instance: String, transmitter:
};
}
fn _main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSender, rx: Receiver<SocketWorkerMessage>) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?;
fn _main(
config: TFClientConfig,
instance: String,
transmitter: ThreadMessageSender,
rx: Receiver<SocketWorkerMessage>,
) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(SocketAddr::new(
IpAddr::from([127, 0, 0, 1]),
config.listen_port,
))?;
listener.set_nonblocking(true)?;
loop {
@ -47,21 +60,21 @@ fn _main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSen
}
}
});
},
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
Err(e) => { Err(e)?; }
Err(e) => {
Err(e)?;
}
}
match rx.recv() {
Ok(msg) => {
match msg {
SocketWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
break;
},
SocketWorkerMessage::WakeUp => {
continue;
}
Ok(msg) => match msg {
SocketWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
break;
}
SocketWorkerMessage::WakeUp => {
continue;
}
},
Err(e) => {
@ -74,22 +87,27 @@ fn _main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSen
Ok(())
}
fn handle_stream(stream: (TcpStream, SocketAddr), transmitter: ThreadMessageSender, config: TFClientConfig, instance: String) -> Result<(), io::Error> {
fn handle_stream(
stream: (TcpStream, SocketAddr),
transmitter: ThreadMessageSender,
config: TFClientConfig,
instance: String,
) -> Result<(), io::Error> {
info!("Incoming client");
match handle_client(stream.0, transmitter, config, instance) {
Ok(()) => (),
Err(e) if e.kind() == io::ErrorKind::TimedOut => {
warn!("Client timed out, connection aborted");
},
}
Err(e) if e.kind() == io::ErrorKind::NotConnected => {
warn!("Client connection severed");
},
}
Err(e) if e.kind() == io::ErrorKind::BrokenPipe => {
warn!("Client connection returned error: broken pipe");
},
}
Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => {
warn!("Client aborted connection");
},
}
Err(e) => {
error!("Error in client handler: {}", e);
return Err(e);
@ -98,15 +116,18 @@ fn handle_stream(stream: (TcpStream, SocketAddr), transmitter: ThreadMessageSend
Ok(())
}
fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender, config: TFClientConfig, instance: String) -> Result<(), io::Error> {
fn handle_client(
stream: TcpStream,
transmitter: ThreadMessageSender,
_config: TFClientConfig,
instance: String,
) -> Result<(), io::Error> {
info!("Handling connection from {}", stream.peer_addr()?);
let mut client = Client {
state: ClientState::WaitHello,
reader: BufReader::new(&stream),
writer: BufWriter::new(&stream),
stream: &stream,
config,
instance,
};
@ -118,18 +139,14 @@ fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender, config: TF
trace!("recv {:?} from {}", command, client.stream.peer_addr()?);
let should_disconnect;
let should_disconnect = match client.state {
ClientState::WaitHello => waithello_handle(&mut client, &transmitter, command)?,
ClientState::SentHello => senthello_handle(&mut client, &transmitter, command)?,
};
match client.state {
ClientState::WaitHello => {
should_disconnect = waithello_handle(&mut client, &transmitter, command)?;
}
ClientState::SentHello => {
should_disconnect = senthello_handle(&mut client, &transmitter, command)?;
}
if should_disconnect {
break;
}
if should_disconnect { break; }
}
// Gracefully close the connection
@ -141,13 +158,15 @@ fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender, config: TF
struct Client<'a> {
state: ClientState,
reader: BufReader<&'a TcpStream>,
writer: BufWriter<&'a TcpStream>,
stream: &'a TcpStream,
config: TFClientConfig,
instance: String
instance: String,
}
fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, command: JsonMessage) -> Result<bool, io::Error> {
fn waithello_handle(
client: &mut Client,
_transmitter: &ThreadMessageSender,
command: JsonMessage,
) -> Result<bool, io::Error> {
trace!("state: WaitHello, handing with waithello_handle");
let mut should_disconnect = false;
@ -158,20 +177,20 @@ fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, com
client.stream.write_all(&ctob(JsonMessage::Goodbye {
reason: DisconnectReason::UnsupportedVersion {
expected: JSON_API_VERSION,
got: version
}
got: version,
},
}))?;
}
client.stream.write_all(&ctob(JsonMessage::Hello {
version: JSON_API_VERSION
version: JSON_API_VERSION,
}))?;
client.state = ClientState::SentHello;
trace!("setting state to SentHello");
},
}
JsonMessage::Goodbye { reason } => {
info!("Client sent disconnect: {:?}", reason);
should_disconnect = true;
},
}
_ => {
debug!("message type unexpected in WaitHello state");
should_disconnect = true;
@ -184,7 +203,11 @@ fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, com
Ok(should_disconnect)
}
fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, command: JsonMessage) -> Result<bool, io::Error> {
fn senthello_handle(
client: &mut Client,
transmitter: &ThreadMessageSender,
command: JsonMessage,
) -> Result<bool, io::Error> {
trace!("state: SentHello, handing with senthello_handle");
let mut should_disconnect = false;
@ -192,14 +215,20 @@ fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, comm
JsonMessage::Goodbye { reason } => {
info!("Client sent disconnect: {:?}", reason);
should_disconnect = true;
},
}
JsonMessage::Shutdown {} => {
info!("Requested to shutdown by local control socket. Sending shutdown message to threads");
match transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) {
match transmitter
.nebula_thread
.send(NebulaWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to nebula worker thread: {}", e);
error!(
"Error sending shutdown message to nebula worker thread: {}",
e
);
}
}
match transmitter.api_thread.send(APIWorkerMessage::Shutdown) {
@ -208,19 +237,28 @@ fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, comm
error!("Error sending shutdown message to api worker thread: {}", e);
}
}
match transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) {
match transmitter
.socket_thread
.send(SocketWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to socket worker thread: {}", e);
error!(
"Error sending shutdown message to socket worker thread: {}",
e
);
}
}
match transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) {
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to timer worker thread: {}", e);
error!(
"Error sending shutdown message to timer worker thread: {}",
e
);
}
}
},
}
JsonMessage::GetHostID {} => {
let data = match load_cdata(&client.instance) {
@ -232,20 +270,26 @@ fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, comm
};
client.stream.write_all(&ctob(JsonMessage::HostID {
has_id: data.creds.is_some(),
id: data.creds.map(|c| c.host_id)
id: data.creds.map(|c| c.host_id),
}))?;
},
}
JsonMessage::Enroll { code } => {
info!("Client sent enroll with code {}", code);
info!("Sending enroll request to apiworker");
transmitter.api_thread.send(APIWorkerMessage::Enroll { code }).unwrap();
},
transmitter
.api_thread
.send(APIWorkerMessage::Enroll { code })
.unwrap();
}
JsonMessage::Update {} => {
info!("Client sent update request.");
info!("Telling apiworker to update configuration");
transmitter.api_thread.send(APIWorkerMessage::Update).unwrap();
transmitter
.api_thread
.send(APIWorkerMessage::Update)
.unwrap();
}
_ => {
@ -267,7 +311,7 @@ pub fn ctob(command: JsonMessage) -> Vec<u8> {
enum ClientState {
WaitHello,
SentHello
SentHello,
}
pub const JSON_API_VERSION: i32 = 1;
@ -276,28 +320,19 @@ pub const JSON_API_VERSION: i32 = 1;
#[serde(tag = "method")]
pub enum JsonMessage {
#[serde(rename = "hello")]
Hello {
version: i32
},
Hello { version: i32 },
#[serde(rename = "goodbye")]
Goodbye {
reason: DisconnectReason
},
Goodbye { reason: DisconnectReason },
#[serde(rename = "shutdown")]
Shutdown {},
#[serde(rename = "get_host_id")]
GetHostID {},
#[serde(rename = "host_id")]
HostID {
has_id: bool,
id: Option<String>
},
HostID { has_id: bool, id: Option<String> },
#[serde(rename = "enroll")]
Enroll {
code: String
},
Enroll { code: String },
#[serde(rename = "update")]
Update {}
Update {},
}
#[derive(Serialize, Deserialize, Debug)]
@ -308,5 +343,5 @@ pub enum DisconnectReason {
#[serde(rename = "unexpected_message_type")]
UnexpectedMessageType,
#[serde(rename = "done")]
Done
}
Done,
}

View file

@ -1,15 +1,15 @@
use std::ops::Add;
use std::sync::mpsc::{Receiver, TryRecvError};
use std::thread;
use std::time::{Duration, SystemTime};
use log::{error, info};
use crate::apiworker::APIWorkerMessage;
use crate::daemon::ThreadMessageSender;
use crate::nebulaworker::NebulaWorkerMessage;
use crate::socketworker::SocketWorkerMessage;
use log::{error, info};
use std::ops::Add;
use std::sync::mpsc::{Receiver, TryRecvError};
use std::thread;
use std::time::{Duration, SystemTime};
pub enum TimerWorkerMessage {
Shutdown
Shutdown,
}
pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>) {
@ -19,23 +19,19 @@ pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>) {
thread::sleep(Duration::from_secs(10));
match rx.try_recv() {
Ok(msg) => {
match msg {
TimerWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
break;
}
Ok(msg) => match msg {
TimerWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
break;
}
},
Err(e) => {
match e {
TryRecvError::Empty => {}
TryRecvError::Disconnected => {
error!("timerworker command socket disconnected, shutting down to prevent orphaning");
break;
}
Err(e) => match e {
TryRecvError::Empty => {}
TryRecvError::Disconnected => {
error!("timerworker command socket disconnected, shutting down to prevent orphaning");
break;
}
}
},
}
if SystemTime::now().gt(&api_reload_timer) {
@ -52,15 +48,21 @@ pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>) {
match tx.nebula_thread.send(NebulaWorkerMessage::WakeUp) {
Ok(_) => (),
Err(e) => {
error!("Error sending wakeup message to nebula worker thread: {}", e);
error!(
"Error sending wakeup message to nebula worker thread: {}",
e
);
}
}
match tx.socket_thread.send(SocketWorkerMessage::WakeUp) {
Ok(_) => (),
Err(e) => {
error!("Error sending wakeup message to socket worker thread: {}", e);
error!(
"Error sending wakeup message to socket worker thread: {}",
e
);
}
}
}
}
}

View file

@ -1,6 +1,6 @@
use log::{error, warn};
use sha2::Sha256;
use sha2::Digest;
use sha2::Sha256;
use url::Url;
pub fn sha256(bytes: &[u8]) -> String {
@ -11,7 +11,7 @@ pub fn sha256(bytes: &[u8]) -> String {
}
pub fn check_server_url(server: &str) {
let api_base = match Url::parse(&server) {
let api_base = match Url::parse(server) {
Ok(u) => u,
Err(e) => {
error!("Invalid server url `{}`: {}", server, e);
@ -19,11 +19,16 @@ pub fn check_server_url(server: &str) {
}
};
match api_base.scheme() {
"http" => { warn!("HTTP api urls are not reccomended. Please switch to HTTPS if possible.") },
"http" => {
warn!("HTTP api urls are not reccomended. Please switch to HTTPS if possible.")
}
"https" => (),
_ => {
error!("Unsupported protocol `{}` (expected one of http, https)", api_base.scheme());
error!(
"Unsupported protocol `{}` (expected one of http, https)",
api_base.scheme()
);
std::process::exit(1);
}
}
}
}

View file

@ -12,8 +12,9 @@ actix-request-identifier = "4" # Web framework
serde = { version = "1", features = ["derive"] } # Serialization and deserialization
serde_json = "1.0.95" # Serialization and deserialization (cursors)
once_cell = "1" # Config
toml = "0.7" # Config / Serialization and deserialization
once_cell = "1" # Config
toml = "0.7" # Config / Serialization and deserialization
serde_yaml = "0.9.21" # Config / Serialization and deserialization
log = "0.4" # Logging
simple_logger = "4" # Logging
@ -28,5 +29,9 @@ totp-rs = { version = "5.0.1", features = ["gen_secret", "otpauth"] } # Misc.
base64 = "0.21.0" # Misc.
chrono = "0.4.24" # Misc.
trifid-pki = { version = "0.1.9" } # Cryptography
aes-gcm = "0.10.1" # Cryptography
trifid-pki = { version = "0.1.9", features = ["serde_derive"] } # Cryptography
aes-gcm = "0.10.1" # Cryptography
ed25519-dalek = "2.0.0-rc.2" # Cryptography
dnapi-rs = "0.1.9" # API message types
ipnet = "2.7.2" # API message types

View file

@ -120,4 +120,10 @@ url = "your-database-url-here"
# ------- WARNING -------
# Do not change this value in a production instance. It will make existing data inaccessible until changed back.
# ------- WARNING -------
data-key = "edd600bcebea461381ea23791b6967c8667e12827ac8b94dc022f189a5dc59a2"
data-key = "edd600bcebea461381ea23791b6967c8667e12827ac8b94dc022f189a5dc59a2"
# The data directory used for storing keys, configuration, signing keys, etc. Must be writable by this instance.
# This directory will be used to store very sensitive data - protect it like a password! It should be writable by
# this instance and ONLY this instance.
# Do not modify any files in this directory manually unless directed to do so by trifid.
local_keystore_directory = "./trifid_data"

View file

@ -0,0 +1,338 @@
use std::collections::HashMap;
use std::error::Error;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use std::time::{Duration, SystemTime};
use actix_web::web::Data;
use crate::config::{NebulaConfig, NebulaConfigCipher, NebulaConfigLighthouse, NebulaConfigListen, NebulaConfigPki, NebulaConfigPunchy, NebulaConfigRelay, NebulaConfigTun, CONFIG, NebulaConfigFirewall, NebulaConfigFirewallRule};
use crate::crypto::{decrypt_with_nonce, encrypt_with_nonce, get_cipher_from_config};
use crate::AppState;
use ed25519_dalek::SigningKey;
use ipnet::Ipv4Net;
use log::{debug, error};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use trifid_api_entities::entity::{firewall_rule, host, host_config_override, host_static_address, network, organization, signing_ca};
use trifid_pki::cert::{
deserialize_ed25519_private, deserialize_nebula_certificate_from_pem, NebulaCertificate,
NebulaCertificateDetails,
};
pub struct CodegenRequiredInfo {
pub host: host::Model,
pub host_static_addresses: HashMap<String, Vec<SocketAddrV4>>,
pub host_config_overrides: Vec<host_config_override::Model>,
pub network: network::Model,
pub organization: organization::Model,
pub dh_pubkey: Vec<u8>,
pub ca: signing_ca::Model,
pub other_cas: Vec<signing_ca::Model>,
pub relay_ips: Vec<Ipv4Addr>,
pub lighthouse_ips: Vec<Ipv4Addr>,
pub blocked_hosts: Vec<String>,
pub firewall_rules: Vec<NebulaConfigFirewallRule>
}
pub async fn generate_config(
data: &Data<AppState>,
info: &CodegenRequiredInfo,
) -> Result<(NebulaConfig, NebulaCertificate), Box<dyn Error>> {
debug!("chk: deserialize CA cert {:x?}", hex::decode(&info.ca.cert)?);
// decode the CA data
let ca_cert = deserialize_nebula_certificate_from_pem(&hex::decode(&info.ca.cert)?)?;
// generate the client's new cert
let mut cert = NebulaCertificate {
details: NebulaCertificateDetails {
name: info.host.name.clone(),
ips: vec![Ipv4Net::new(
Ipv4Addr::from_str(&info.host.ip).unwrap(),
Ipv4Net::from_str(&info.network.cidr).unwrap().prefix_len(),
)
.unwrap()],
subnets: vec![],
groups: vec![
format!("role:{}", info.host.role)
],
not_before: SystemTime::now(),
not_after: SystemTime::now() + Duration::from_secs(CONFIG.crypto.certs_expiry_time),
public_key: info.dh_pubkey.clone().try_into().unwrap(),
is_ca: false,
issuer: ca_cert.sha256sum()?,
},
signature: vec![],
};
// decrypt the private key
let private_pem = decrypt_with_nonce(
&hex::decode(&info.ca.key)?,
hex::decode(&info.ca.nonce)?.try_into().unwrap(),
&get_cipher_from_config(&CONFIG)?,
)
.map_err(|_| "Encryption error")?;
let private_key = deserialize_ed25519_private(&private_pem)?;
let signing_key = SigningKey::from_keypair_bytes(&private_key.try_into().unwrap()).unwrap();
cert.sign(&signing_key)?;
// cas
let mut cas = String::new();
for ca in &info.other_cas {
cas += &String::from_utf8(hex::decode(&ca.cert)?)?;
}
// blocked hosts
let mut blocked_hosts_fingerprints = vec![];
for host in &info.blocked_hosts {
if let Some(host) = data.keystore.hosts.iter().find(|u| &u.id == host) {
for cert in &host.certs {
blocked_hosts_fingerprints.push(cert.cert.sha256sum()?);
}
}
}
let nebula_config = NebulaConfig {
pki: NebulaConfigPki {
ca: cas,
cert: String::from_utf8(cert.serialize_to_pem()?)?,
key: None,
blocklist: blocked_hosts_fingerprints,
disconnect_invalid: true,
},
static_host_map: info
.host_static_addresses
.iter()
.map(|(u, addrs)| (Ipv4Addr::from_str(u).unwrap(), addrs.clone()))
.collect(),
lighthouse: match info.host.is_lighthouse {
true => Some(NebulaConfigLighthouse {
am_lighthouse: true,
serve_dns: false,
dns: None,
interval: 60,
hosts: vec![],
remote_allow_list: HashMap::new(),
local_allow_list: HashMap::new(),
}),
false => Some(NebulaConfigLighthouse {
am_lighthouse: false,
serve_dns: false,
dns: None,
interval: 60,
hosts: info.lighthouse_ips.to_vec(),
remote_allow_list: HashMap::new(),
local_allow_list: HashMap::new(),
}),
},
listen: match info.host.is_lighthouse || info.host.is_relay {
true => Some(NebulaConfigListen {
host: "[::]".to_string(),
port: info.host.listen_port as u16,
batch: 64,
read_buffer: Some(10485760),
write_buffer: Some(10485760),
}),
false => Some(NebulaConfigListen {
host: "[::]".to_string(),
port: 0u16,
batch: 64,
read_buffer: Some(10485760),
write_buffer: Some(10485760),
}),
},
punchy: Some(NebulaConfigPunchy {
punch: true,
respond: true,
delay: "".to_string(),
}),
cipher: NebulaConfigCipher::Aes,
preferred_ranges: vec![],
relay: Some(NebulaConfigRelay {
relays: info.relay_ips.to_vec(),
am_relay: info.host.is_relay,
use_relays: true,
}),
tun: Some(NebulaConfigTun {
disabled: false,
dev: Some("trifid1".to_string()),
drop_local_broadcast: true,
drop_multicast: true,
tx_queue: 500,
mtu: 1300,
routes: vec![],
unsafe_routes: vec![],
}),
logging: None,
sshd: None,
firewall: Some(NebulaConfigFirewall {
conntrack: None,
inbound: Some(info.firewall_rules.clone()),
outbound: Some(vec![
NebulaConfigFirewallRule {
port: Some("any".to_string()),
proto: Some("any".to_string()),
ca_name: None,
ca_sha: None,
host: Some("any".to_string()),
group: None,
groups: None,
cidr: None,
}
]),
}),
routines: 0,
stats: None,
local_range: None,
};
Ok((nebula_config, cert))
}
pub async fn collect_info<'a>(
db: &'a Data<AppState>,
host: &'a str,
dh_pubkey: &'a [u8],
) -> Result<CodegenRequiredInfo, Box<dyn Error>> {
// load host info
let host = trifid_api_entities::entity::host::Entity::find()
.filter(host::Column::Id.eq(host))
.one(&db.conn)
.await?;
let host = match host {
Some(host) => host,
None => return Err("Host does not exist".into()),
};
let host_config_overrides = trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Id.eq(&host.id))
.all(&db.conn)
.await?;
let _host_static_addresses = trifid_api_entities::entity::host_static_address::Entity::find()
.filter(host_static_address::Column::Id.eq(&host.id))
.all(&db.conn)
.await?;
// load network info
let network = trifid_api_entities::entity::network::Entity::find()
.filter(network::Column::Id.eq(&host.network))
.one(&db.conn)
.await?;
let network = match network {
Some(network) => network,
None => {
return Err("Network does not exist".into());
}
};
// get all lighthouses and relays and get all of their static addresses, and get internal addresses of relays
let mut host_x_static_addresses = HashMap::new();
let mut relays = vec![];
let mut lighthouses = vec![];
let mut blocked_hosts = vec![];
let hosts = trifid_api_entities::entity::host::Entity::find()
.filter(host::Column::Network.eq(&network.id))
.filter(host::Column::IsRelay.eq(true))
.filter(host::Column::IsLighthouse.eq(true))
.all(&db.conn)
.await?;
for host in hosts {
if host.is_relay {
relays.push(Ipv4Addr::from_str(&host.ip).unwrap());
} else if host.is_lighthouse {
lighthouses.push(Ipv4Addr::from_str(&host.ip).unwrap());
}
if host.is_blocked {
blocked_hosts.push(host.id.clone());
}
let static_addresses = trifid_api_entities::entity::host_static_address::Entity::find()
.filter(host_static_address::Column::Host.eq(host.id))
.all(&db.conn)
.await?;
let static_addresses: Vec<SocketAddrV4> = static_addresses
.iter()
.map(|u| SocketAddrV4::from_str(&u.address).unwrap())
.collect();
host_x_static_addresses.insert(host.ip.clone(), static_addresses);
}
// load org info
let org = trifid_api_entities::entity::organization::Entity::find()
.filter(organization::Column::Id.eq(&network.organization))
.one(&db.conn)
.await?;
let org = match org {
Some(org) => org,
None => {
return Err("Organization does not exist".into());
}
};
// get the CA that is closest to expiry, but *not* expired
let available_cas = trifid_api_entities::entity::signing_ca::Entity::find()
.filter(signing_ca::Column::Organization.eq(&org.id))
.all(&db.conn)
.await?;
let mut best_ca: Option<signing_ca::Model> = None;
let mut all_cas = vec![];
for ca in available_cas {
if let Some(existing_best) = &best_ca {
if ca.expires < existing_best.expires {
best_ca = Some(ca.clone());
}
} else {
best_ca = Some(ca.clone());
}
all_cas.push(ca);
}
if best_ca.is_none() {
error!(
"!!! NO AVAILABLE CAS !!! while trying to sign cert for {}",
org.id
);
return Err("No signing CAs available".into());
}
let best_ca = best_ca.unwrap();
// pull our role's firewall rules
let firewall_rules = trifid_api_entities::entity::firewall_rule::Entity::find().filter(firewall_rule::Column::Role.eq(&host.id)).all(&db.conn).await?;
let firewall_rules = firewall_rules.iter().map(|u| {
NebulaConfigFirewallRule {
port: Some(if u.port_range_from == 0 && u.port_range_to == 65535 { "any".to_string() } else { format!("{}-{}", u.port_range_from, u.port_range_to) }),
proto: Some(u.protocol.clone()),
ca_name: None,
ca_sha: None,
host: if u.allowed_role_id.is_some() { None } else { Some("any".to_string()) },
groups: if u.allowed_role_id.is_some() { Some(vec![format!("role:{}", u.allowed_role_id.clone().unwrap())])} else { None },
group: None,
cidr: None,
}
}).collect();
Ok(CodegenRequiredInfo {
host,
host_static_addresses: host_x_static_addresses,
host_config_overrides,
network,
organization: org,
dh_pubkey: dh_pubkey.to_vec(),
ca: best_ca,
other_cas: all_cas,
relay_ips: relays,
lighthouse_ips: lighthouses,
blocked_hosts,
firewall_rules
})
}

View file

@ -14,11 +14,14 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use ipnet::{IpNet, Ipv4Net};
use log::error;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::path::PathBuf;
pub static CONFIG: Lazy<TrifidConfig> = Lazy::new(|| {
let config_str = match fs::read_to_string("/etc/trifid/config.toml") {
@ -88,6 +91,9 @@ pub struct TrifidConfigTokens {
#[derive(Serialize, Deserialize, Debug)]
pub struct TrifidConfigCryptography {
pub data_encryption_key: String,
pub local_keystore_directory: PathBuf,
#[serde(default = "certs_expiry_time")]
pub certs_expiry_time: u64,
}
fn max_connections_default() -> u32 {
@ -120,3 +126,534 @@ fn mfa_tokens_expiry_time() -> u64 {
fn enrollment_tokens_expiry_time() -> u64 {
600
} // 10 minutes
fn certs_expiry_time() -> u64 {
3600 * 24 * 31 * 12 // 1 year
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfig {
pub pki: NebulaConfigPki,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub static_host_map: HashMap<Ipv4Addr, Vec<SocketAddrV4>>,
#[serde(skip_serializing_if = "is_none")]
pub lighthouse: Option<NebulaConfigLighthouse>,
#[serde(skip_serializing_if = "is_none")]
pub listen: Option<NebulaConfigListen>,
#[serde(skip_serializing_if = "is_none")]
pub punchy: Option<NebulaConfigPunchy>,
#[serde(default = "cipher_aes")]
#[serde(skip_serializing_if = "is_cipher_aes")]
pub cipher: NebulaConfigCipher,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub preferred_ranges: Vec<IpNet>,
#[serde(skip_serializing_if = "is_none")]
pub relay: Option<NebulaConfigRelay>,
#[serde(skip_serializing_if = "is_none")]
pub tun: Option<NebulaConfigTun>,
#[serde(skip_serializing_if = "is_none")]
pub logging: Option<NebulaConfigLogging>,
#[serde(skip_serializing_if = "is_none")]
pub sshd: Option<NebulaConfigSshd>,
#[serde(skip_serializing_if = "is_none")]
pub firewall: Option<NebulaConfigFirewall>,
#[serde(default = "u64_1")]
#[serde(skip_serializing_if = "is_u64_1")]
pub routines: u64,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub stats: Option<NebulaConfigStats>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub local_range: Option<Ipv4Net>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigPki {
pub ca: String,
pub cert: String,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub key: Option<String>,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub blocklist: Vec<String>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disconnect_invalid: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLighthouse {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub am_lighthouse: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub serve_dns: bool,
#[serde(skip_serializing_if = "is_none")]
pub dns: Option<NebulaConfigLighthouseDns>,
#[serde(default = "u32_10")]
#[serde(skip_serializing_if = "is_u32_10")]
pub interval: u32,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub hosts: Vec<Ipv4Addr>,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub remote_allow_list: HashMap<Ipv4Net, bool>,
#[serde(default = "empty_hashmap")]
#[serde(skip_serializing_if = "is_empty_hashmap")]
pub local_allow_list: HashMap<Ipv4Net, bool>, // `interfaces` is not supported
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLighthouseDns {
#[serde(default = "string_empty")]
#[serde(skip_serializing_if = "is_string_empty")]
pub host: String,
#[serde(default = "u16_53")]
#[serde(skip_serializing_if = "is_u16_53")]
pub port: u16,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigListen {
#[serde(default = "string_empty")]
#[serde(skip_serializing_if = "is_string_empty")]
pub host: String,
#[serde(default = "u16_0")]
#[serde(skip_serializing_if = "is_u16_0")]
pub port: u16,
#[serde(default = "u32_64")]
#[serde(skip_serializing_if = "is_u32_64")]
pub batch: u32,
#[serde(skip_serializing_if = "is_none")]
pub read_buffer: Option<u32>,
#[serde(skip_serializing_if = "is_none")]
pub write_buffer: Option<u32>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigPunchy {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub punch: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub respond: bool,
#[serde(default = "string_1s")]
#[serde(skip_serializing_if = "is_string_1s")]
pub delay: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigCipher {
#[serde(rename = "aes")]
Aes,
#[serde(rename = "chachapoly")]
ChaChaPoly,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigRelay {
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub relays: Vec<Ipv4Addr>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub am_relay: bool,
#[serde(default = "bool_true")]
#[serde(skip_serializing_if = "is_bool_true")]
pub use_relays: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTun {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disabled: bool,
#[serde(skip_serializing_if = "is_none")]
pub dev: Option<String>,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub drop_local_broadcast: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub drop_multicast: bool,
#[serde(default = "u64_500")]
#[serde(skip_serializing_if = "is_u64_500")]
pub tx_queue: u64,
#[serde(default = "u64_1300")]
#[serde(skip_serializing_if = "is_u64_1300")]
pub mtu: u64,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub routes: Vec<NebulaConfigTunRouteOverride>,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub unsafe_routes: Vec<NebulaConfigTunUnsafeRoute>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunRouteOverride {
pub mtu: u64,
pub route: Ipv4Net,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigTunUnsafeRoute {
pub route: Ipv4Net,
pub via: Ipv4Addr,
#[serde(default = "u64_1300")]
#[serde(skip_serializing_if = "is_u64_1300")]
pub mtu: u64,
#[serde(default = "i64_100")]
#[serde(skip_serializing_if = "is_i64_100")]
pub metric: i64,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigLogging {
#[serde(default = "loglevel_info")]
#[serde(skip_serializing_if = "is_loglevel_info")]
pub level: NebulaConfigLoggingLevel,
#[serde(default = "format_text")]
#[serde(skip_serializing_if = "is_format_text")]
pub format: NebulaConfigLoggingFormat,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub disable_timestamp: bool,
#[serde(default = "timestamp")]
#[serde(skip_serializing_if = "is_timestamp")]
pub timestamp_format: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigLoggingLevel {
#[serde(rename = "panic")]
Panic,
#[serde(rename = "fatal")]
Fatal,
#[serde(rename = "error")]
Error,
#[serde(rename = "warning")]
Warning,
#[serde(rename = "info")]
Info,
#[serde(rename = "debug")]
Debug,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigLoggingFormat {
#[serde(rename = "json")]
Json,
#[serde(rename = "text")]
Text,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigSshd {
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub enabled: bool,
pub listen: SocketAddrV4,
pub host_key: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub authorized_users: Vec<NebulaConfigSshdAuthorizedUser>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigSshdAuthorizedUser {
pub user: String,
#[serde(default = "empty_vec")]
#[serde(skip_serializing_if = "is_empty_vec")]
pub keys: Vec<String>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "type")]
pub enum NebulaConfigStats {
#[serde(rename = "graphite")]
Graphite(NebulaConfigStatsGraphite),
#[serde(rename = "prometheus")]
Prometheus(NebulaConfigStatsPrometheus),
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigStatsGraphite {
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub prefix: String,
#[serde(default = "protocol_tcp")]
#[serde(skip_serializing_if = "is_protocol_tcp")]
pub protocol: NebulaConfigStatsGraphiteProtocol,
pub host: SocketAddrV4,
pub interval: String,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum NebulaConfigStatsGraphiteProtocol {
#[serde(rename = "tcp")]
Tcp,
#[serde(rename = "udp")]
Udp,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigStatsPrometheus {
pub listen: String,
pub path: String,
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub namespace: String,
#[serde(default = "string_nebula")]
#[serde(skip_serializing_if = "is_string_nebula")]
pub subsystem: String,
pub interval: String,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub message_metrics: bool,
#[serde(default = "bool_false")]
#[serde(skip_serializing_if = "is_bool_false")]
pub lighthouse_metrics: bool,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewall {
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub conntrack: Option<NebulaConfigFirewallConntrack>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub inbound: Option<Vec<NebulaConfigFirewallRule>>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub outbound: Option<Vec<NebulaConfigFirewallRule>>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewallConntrack {
#[serde(default = "string_12m")]
#[serde(skip_serializing_if = "is_string_12m")]
pub tcp_timeout: String,
#[serde(default = "string_3m")]
#[serde(skip_serializing_if = "is_string_3m")]
pub udp_timeout: String,
#[serde(default = "string_10m")]
#[serde(skip_serializing_if = "is_string_10m")]
pub default_timeout: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct NebulaConfigFirewallRule {
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub port: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub proto: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub ca_name: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub ca_sha: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub host: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub group: Option<String>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub groups: Option<Vec<String>>,
#[serde(default = "none")]
#[serde(skip_serializing_if = "is_none")]
pub cidr: Option<String>,
}
// Default values for serde
fn string_12m() -> String {
"12m".to_string()
}
fn is_string_12m(s: &str) -> bool {
s == "12m"
}
fn string_3m() -> String {
"3m".to_string()
}
fn is_string_3m(s: &str) -> bool {
s == "3m"
}
fn string_10m() -> String {
"10m".to_string()
}
fn is_string_10m(s: &str) -> bool {
s == "10m"
}
fn empty_vec<T>() -> Vec<T> {
vec![]
}
fn is_empty_vec<T>(v: &Vec<T>) -> bool {
v.is_empty()
}
fn empty_hashmap<A, B>() -> HashMap<A, B> {
HashMap::new()
}
fn is_empty_hashmap<A, B>(h: &HashMap<A, B>) -> bool {
h.is_empty()
}
fn bool_false() -> bool {
false
}
fn is_bool_false(b: &bool) -> bool {
!*b
}
fn bool_true() -> bool {
true
}
fn is_bool_true(b: &bool) -> bool {
*b
}
fn u16_53() -> u16 {
53
}
fn is_u16_53(u: &u16) -> bool {
*u == 53
}
fn u32_10() -> u32 {
10
}
fn is_u32_10(u: &u32) -> bool {
*u == 10
}
fn u16_0() -> u16 {
0
}
fn is_u16_0(u: &u16) -> bool {
*u == 0
}
fn u32_64() -> u32 {
64
}
fn is_u32_64(u: &u32) -> bool {
*u == 64
}
fn string_1s() -> String {
"1s".to_string()
}
fn is_string_1s(s: &str) -> bool {
s == "1s"
}
fn cipher_aes() -> NebulaConfigCipher {
NebulaConfigCipher::Aes
}
fn is_cipher_aes(c: &NebulaConfigCipher) -> bool {
matches!(c, NebulaConfigCipher::Aes)
}
fn u64_500() -> u64 {
500
}
fn is_u64_500(u: &u64) -> bool {
*u == 500
}
fn u64_1300() -> u64 {
1300
}
fn is_u64_1300(u: &u64) -> bool {
*u == 1300
}
fn i64_100() -> i64 {
100
}
fn is_i64_100(i: &i64) -> bool {
*i == 100
}
fn loglevel_info() -> NebulaConfigLoggingLevel {
NebulaConfigLoggingLevel::Info
}
fn is_loglevel_info(l: &NebulaConfigLoggingLevel) -> bool {
matches!(l, NebulaConfigLoggingLevel::Info)
}
fn format_text() -> NebulaConfigLoggingFormat {
NebulaConfigLoggingFormat::Text
}
fn is_format_text(f: &NebulaConfigLoggingFormat) -> bool {
matches!(f, NebulaConfigLoggingFormat::Text)
}
fn timestamp() -> String {
"2006-01-02T15:04:05Z07:00".to_string()
}
fn is_timestamp(s: &str) -> bool {
s == "2006-01-02T15:04:05Z07:00"
}
fn u64_1() -> u64 {
1
}
fn is_u64_1(u: &u64) -> bool {
*u == 1
}
fn string_nebula() -> String {
"nebula".to_string()
}
fn is_string_nebula(s: &str) -> bool {
s == "nebula"
}
fn string_empty() -> String {
String::new()
}
fn is_string_empty(s: &str) -> bool {
s.is_empty()
}
fn protocol_tcp() -> NebulaConfigStatsGraphiteProtocol {
NebulaConfigStatsGraphiteProtocol::Tcp
}
fn is_protocol_tcp(p: &NebulaConfigStatsGraphiteProtocol) -> bool {
matches!(p, NebulaConfigStatsGraphiteProtocol::Tcp)
}
fn none<T>() -> Option<T> {
None
}
fn is_none<T>(o: &Option<T>) -> bool {
o.is_none()
}

View file

@ -0,0 +1,83 @@
use crate::config::{NebulaConfig, CONFIG};
use ed25519_dalek::{SigningKey, VerifyingKey};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use trifid_pki::cert::NebulaCertificate;
use trifid_pki::x25519_dalek::PublicKey;
#[derive(Serialize, Deserialize)]
pub struct Keystore {
#[serde(default = "default_vec")]
pub hosts: Vec<KeystoreHostInformation>,
}
fn default_vec<T>() -> Vec<T> {
vec![]
}
pub fn keystore_init() -> Result<Keystore, Box<dyn Error>> {
let mut ks_fp = CONFIG.crypto.local_keystore_directory.clone();
ks_fp.push("/tfks.toml");
if !ks_fp.exists() {
return Ok(Keystore {
hosts: vec![]
})
}
let f_str = fs::read_to_string(ks_fp)?;
let keystore: Keystore = toml::from_str(&f_str)?;
Ok(keystore)
}
pub fn keystore_flush(ks: &Keystore) -> Result<(), Box<dyn Error>> {
let mut ks_fp = CONFIG.crypto.local_keystore_directory.clone();
ks_fp.push("/tfks.toml");
fs::write(ks_fp, toml::to_string(ks)?)?;
Ok(())
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KeystoreHostInformation {
pub id: String,
pub current_signing_key: u64,
pub current_client_key: u64,
pub current_config: u64,
pub current_cert: u64,
pub certs: Vec<KSCert>,
pub config: Vec<KSConfig>,
pub signing_keys: Vec<KSSigningKey>,
pub client_keys: Vec<KSClientKey>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSCert {
pub id: u64,
pub cert: NebulaCertificate,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSConfig {
pub id: u64,
pub config: NebulaConfig,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSSigningKey {
pub id: u64,
pub key: SigningKey,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct KSClientKey {
pub id: u64,
pub dh_pub: PublicKey,
pub ed_pub: VerifyingKey,
}

View file

@ -26,14 +26,17 @@ use std::time::Duration;
use crate::config::CONFIG;
use crate::error::{APIError, APIErrorsResponse};
use crate::keystore::{keystore_init, Keystore};
use crate::tokens::random_id_no_id;
use trifid_api_migration::{Migrator, MigratorTrait};
pub mod auth_tokens;
pub mod codegen;
pub mod config;
pub mod crypto;
pub mod cursor;
pub mod error;
pub mod keystore;
pub mod magic_link;
pub mod routes;
pub mod timers;
@ -41,12 +44,17 @@ pub mod tokens;
pub struct AppState {
pub conn: DatabaseConnection,
pub keystore: Keystore,
}
#[actix_web::main]
async fn main() -> Result<(), Box<dyn Error>> {
simple_logger::init_with_level(Level::Debug).unwrap();
info!("Creating keystore...");
let keystore = keystore_init()?;
info!("Connecting to database at {}...", CONFIG.database.url);
let mut opt = ConnectOptions::new(CONFIG.database.url.clone());
@ -64,7 +72,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
info!("Performing database migration...");
Migrator::up(&db, None).await?;
let data = Data::new(AppState { conn: db });
let data = Data::new(AppState { conn: db, keystore });
HttpServer::new(move || {
App::new()
@ -103,6 +111,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
.service(routes::v1::hosts::block_host)
.service(routes::v1::hosts::enroll_host)
.service(routes::v1::hosts::create_host_and_enrollment_code)
.service(routes::v2::enroll::enroll)
})
.bind(CONFIG.server.bind)?
.run()

View file

@ -1 +1,2 @@
pub mod v1;
pub mod v2;

View file

@ -33,7 +33,7 @@ use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, Query
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use trifid_api_entities::entity::{network, organization, signing_ca};
use trifid_pki::cert::{serialize_x25519_private, NebulaCertificate, NebulaCertificateDetails};
use trifid_pki::cert::{serialize_x25519_private, NebulaCertificate, NebulaCertificateDetails, serialize_ed25519_private};
use trifid_pki::ed25519_dalek::SigningKey;
use trifid_pki::rand_core::OsRng;
@ -146,7 +146,7 @@ pub async fn create_org_request(
}
// PEM-encode the CA key
let ca_key_pem = serialize_x25519_private(&private_key.to_keypair_bytes());
let ca_key_pem = serialize_ed25519_private(&private_key.to_keypair_bytes());
// PEM-encode the CA cert
let ca_cert_pem = match cert.serialize_to_pem() {
Ok(pem) => pem,
@ -204,8 +204,8 @@ pub async fn create_org_request(
let signing_ca = signing_ca::Model {
id: random_id("ca"),
organization: org.id.clone(),
cert: ca_key_encrypted,
key: ca_crt,
cert: ca_crt,
key: ca_key_encrypted,
expires: cert
.details
.not_after

View file

@ -0,0 +1,233 @@
use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest, HttpResponse};
use dnapi_rs::message::{
APIError, EnrollRequest, EnrollResponse, EnrollResponseData, EnrollResponseDataOrg,
};
use ed25519_dalek::{SigningKey, VerifyingKey};
use log::{debug, error, trace};
use rand::rngs::OsRng;
use sea_orm::{ColumnTrait, EntityTrait, ModelTrait, QueryFilter};
use crate::codegen::{collect_info, generate_config};
use crate::keystore::{KSCert, KSClientKey, KSConfig, KSSigningKey, KeystoreHostInformation};
use crate::AppState;
use trifid_api_entities::entity::host_enrollment_code;
use trifid_pki::cert::{
deserialize_ed25519_public, deserialize_x25519_public, serialize_ed25519_public,
};
use trifid_pki::x25519_dalek::PublicKey;
use crate::timers::expired;
#[post("/v2/enroll")]
pub async fn enroll(
req: Json<EnrollRequest>,
_req_info: HttpRequest,
db: Data<AppState>,
) -> HttpResponse {
debug!("{:x?} {:x?}", req.dh_pubkey, req.ed_pubkey);
// pull enroll information from the db
let code_info = match host_enrollment_code::Entity::find()
.filter(host_enrollment_code::Column::Id.eq(&req.code))
.one(&db.conn)
.await
{
Ok(ci) => ci,
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message:
"There was an error with the database request. Please try again later."
.to_string(),
path: None,
}],
});
}
};
let enroll_info = match code_info {
Some(ei) => ei,
None => {
return HttpResponse::Unauthorized().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_UNAUTHORIZED".to_string(),
message: "That code is invalid or has expired.".to_string(),
path: None,
}],
})
}
};
if expired(enroll_info.expires_on as u64) {
return HttpResponse::Unauthorized().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_UNAUTHORIZED".to_string(),
message: "That code is invalid or has expired.".to_string(),
path: None,
}],
});
}
// deserialize
let dh_pubkey = match deserialize_x25519_public(&req.dh_pubkey) {
Ok(k) => k,
Err(e) => {
error!("public key deserialization error: {}", e);
return HttpResponse::BadRequest().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_BAD_DH_PUB".to_string(),
message: "Unable to deserialize the DH public key.".to_string(),
path: None,
}],
});
}
};
let ed_pubkey = match deserialize_ed25519_public(&req.ed_pubkey) {
Ok(k) => k,
Err(e) => {
error!("public key deserialization error: {}", e);
return HttpResponse::BadRequest().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_BAD_ED_PUB".to_string(),
message: "Unable to deserialize the ED25519 public key.".to_string(),
path: None,
}],
});
}
};
// destroy the enrollment code before doing anything else
match enroll_info.clone().delete(&db.conn).await {
Ok(_) => (),
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message:
"There was an error with the database request. Please try again later."
.to_string(),
path: None,
}],
});
}
}
let info = match collect_info(&db, &enroll_info.host, &dh_pubkey).await {
Ok(i) => i,
Err(e) => {
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: e.to_string(),
path: None,
}],
});
}
};
// codegen: handoff to dedicated codegen module, we have collected all information
let (cfg, cert) = match generate_config(&db, &info).await {
Ok(cfg) => cfg,
Err(e) => {
error!("error generating configuration: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: "There was an error generating the host configuration.".to_string(),
path: None,
}],
});
}
};
let host_in_ks = db.keystore.hosts.iter().find(|u| u.id == enroll_info.id);
let host_in_ks = match host_in_ks {
Some(ksinfo) => {
let mut ks = ksinfo.clone();
ks.certs.push(KSCert {
id: ks.current_cert + 1,
cert,
});
ks.current_cert += 1;
ks.config.push(KSConfig {
id: ks.current_config + 1,
config: cfg.clone(),
});
ks.current_config += 1;
ks.signing_keys.push(KSSigningKey {
id: ks.current_signing_key,
key: SigningKey::generate(&mut OsRng),
});
ks.current_signing_key += 1;
let dh_pubkey_typed: [u8; 32] = dh_pubkey.clone().try_into().unwrap();
ks.client_keys.push(KSClientKey {
id: ks.current_client_key + 1,
dh_pub: PublicKey::from(dh_pubkey_typed),
ed_pub: VerifyingKey::from_bytes(&ed_pubkey.try_into().unwrap()).unwrap(),
});
ks.current_client_key += 1;
ks
}
None => {
let dh_pubkey_typed: [u8; 32] = dh_pubkey.clone().try_into().unwrap();
KeystoreHostInformation {
id: enroll_info.id.clone(),
current_signing_key: 1,
current_client_key: 1,
current_config: 1,
current_cert: 1,
certs: vec![KSCert { id: 1, cert }],
config: vec![KSConfig {
id: 1,
config: cfg.clone(),
}],
signing_keys: vec![KSSigningKey {
id: 1,
key: SigningKey::generate(&mut OsRng),
}],
client_keys: vec![KSClientKey {
id: 1,
dh_pub: PublicKey::from(dh_pubkey_typed),
ed_pub: VerifyingKey::from_bytes(&ed_pubkey.try_into().unwrap()).unwrap(),
}],
}
}
};
HttpResponse::Ok().json(EnrollResponse::Success {
data: EnrollResponseData {
config: match serde_yaml::to_string(&cfg) {
Ok(cfg) => cfg.as_bytes().to_vec(),
Err(e) => {
error!("serialization error: {}", e);
return HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![
APIError {
code: "ERR_CFG_SERIALIZATION_ERROR".to_string(),
message: "There was an error serializing the node's configuration. Please try again later.".to_string(),
path: None,
}
],
});
}
},
host_id: enroll_info.host.clone(),
counter: host_in_ks.current_config as u32,
trusted_keys: serialize_ed25519_public(host_in_ks.signing_keys.iter().find(|u| u.id == host_in_ks.current_signing_key).unwrap().key.verifying_key().as_bytes().as_slice()).to_vec(),
organization: EnrollResponseDataOrg { id: info.organization.id.clone(), name: info.organization.name.clone() },
},
})
}

View file

@ -0,0 +1 @@
pub mod enroll;

View file

@ -1,6 +1,6 @@
[package]
name = "trifid-pki"
version = "0.1.10"
version = "0.1.11"
edition = "2021"
description = "A rust implementation of the Nebula PKI system"
license = "AGPL-3.0-or-later"

View file

@ -1,14 +1,14 @@
//! Structs to represent a pool of CA's and blacklisted certificates
use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate};
use ed25519_dalek::VerifyingKey;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::time::SystemTime;
use ed25519_dalek::VerifyingKey;
use crate::cert::{deserialize_nebula_certificate_from_pem, NebulaCertificate};
#[cfg(feature = "serde_derive")]
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// A pool of trusted CA certificates, and certificates that should be blocked.
/// This is equivalent to the `pki` section in a typical Nebula config.yml.
@ -20,7 +20,7 @@ pub struct NebulaCAPool {
/// The list of blocklisted certificate fingerprints
pub cert_blocklist: Vec<String>,
/// True if any of the member CAs certificates are expired. Must be handled.
pub expired: bool
pub expired: bool,
}
impl NebulaCAPool {
@ -41,8 +41,12 @@ impl NebulaCAPool {
for cert in pems {
match pool.add_ca_certificate(pem::encode(&cert).as_bytes()) {
Ok(did_expire) => if did_expire { pool.expired = true },
Err(e) => return Err(e)
Ok(did_expire) => {
if did_expire {
pool.expired = true
}
}
Err(e) => return Err(e),
}
}
@ -56,21 +60,23 @@ impl NebulaCAPool {
let cert = deserialize_nebula_certificate_from_pem(bytes)?;
if !cert.details.is_ca {
return Err(CaPoolError::NotACA.into())
return Err(CaPoolError::NotACA.into());
}
if !cert.check_signature(&VerifyingKey::from_bytes(&cert.details.public_key)?)? {
return Err(CaPoolError::NotSelfSigned.into())
return Err(CaPoolError::NotSelfSigned.into());
}
let fingerprint = cert.sha256sum()?;
let expired = cert.expired(SystemTime::now());
if expired { self.expired = true }
if expired {
self.expired = true
}
self.cas.insert(fingerprint, cert);
Ok(expired)
Ok(expired)
}
/// Blocklist the given certificate in the CA pool
@ -92,9 +98,12 @@ impl NebulaCAPool {
/// Gets the CA certificate used to sign the given certificate
/// # Errors
/// This function will return an error if the certificate does not have an issuer attached (it is self-signed)
pub fn get_ca_for_cert(&self, cert: &NebulaCertificate) -> Result<Option<&NebulaCertificate>, Box<dyn Error>> {
pub fn get_ca_for_cert(
&self,
cert: &NebulaCertificate,
) -> Result<Option<&NebulaCertificate>, Box<dyn Error>> {
if cert.details.issuer == String::new() {
return Err(CaPoolError::NoIssuer.into())
return Err(CaPoolError::NoIssuer.into());
}
Ok(self.cas.get(&cert.details.issuer))
@ -115,7 +124,7 @@ pub enum CaPoolError {
/// Tried to add a non-self-signed cert to the CA pool (all CAs must be root certificates)
NotSelfSigned,
/// Tried to look up a certificate that does not have an issuer field
NoIssuer
NoIssuer,
}
impl Error for CaPoolError {}
#[cfg(not(tarpaulin_include))]
@ -127,4 +136,4 @@ impl Display for CaPoolError {
Self::NoIssuer => write!(f, "Tried to look up a certificate with a null issuer field")
}
}
}
}

View file

@ -1,22 +1,22 @@
//! Manage Nebula PKI Certificates
//! This is pretty much a direct port of nebula/cert/cert.go
use crate::ca::NebulaCAPool;
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use ipnet::Ipv4Net;
use pem::Pem;
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use sha2::Digest;
use sha2::Sha256;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::net::Ipv4Addr;
use std::ops::Add;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use ipnet::{Ipv4Net};
use pem::Pem;
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use sha2::Sha256;
use crate::ca::NebulaCAPool;
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails};
use sha2::Digest;
#[cfg(feature = "serde_derive")]
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// The length, in bytes, of public keys
pub const PUBLIC_KEY_LENGTH: i32 = 32;
@ -39,7 +39,7 @@ pub struct NebulaCertificate {
/// The signed data of this certificate
pub details: NebulaCertificateDetails,
/// The Ed25519 signature of this certificate
pub signature: Vec<u8>
pub signature: Vec<u8>,
}
/// The signed details contained in a Nebula PKI certificate
@ -63,7 +63,7 @@ pub struct NebulaCertificateDetails {
/// Is this node a CA?
pub is_ca: bool,
/// SHA256 of issuer certificate. If blank, this cert is self-signed.
pub issuer: String
pub issuer: String,
}
/// A list of errors that can occur parsing certificates
@ -87,7 +87,7 @@ pub enum CertificateError {
/// This certificate either is not yet valid or has already expired
Expired,
/// The public key does not match the expected value
KeyMismatch
KeyMismatch,
}
#[cfg(not(tarpaulin_include))]
impl Display for CertificateError {
@ -95,13 +95,29 @@ impl Display for CertificateError {
match self {
Self::EmptyByteArray => write!(f, "Certificate bytearray is empty"),
Self::NilDetails => write!(f, "The encoded Details field is null"),
Self::IpsNotPairs => write!(f, "encoded IPs should be in pairs, an odd number was found"),
Self::SubnetsNotPairs => write!(f, "encoded subnets should be in pairs, an odd number was found"),
Self::WrongSigLength => write!(f, "Signature should be 64 bytes but is a different size"),
Self::WrongKeyLength => write!(f, "Public keys are expected to be 32 bytes but the public key on this cert is not"),
Self::WrongPemTag => write!(f, "Certificates should have the PEM tag `NEBULA CERTIFICATE`, but this block did not"),
Self::Expired => write!(f, "This certificate either is not yet valid or has already expired"),
Self::KeyMismatch => write!(f, "Key does not match expected value")
Self::IpsNotPairs => {
write!(f, "encoded IPs should be in pairs, an odd number was found")
}
Self::SubnetsNotPairs => write!(
f,
"encoded subnets should be in pairs, an odd number was found"
),
Self::WrongSigLength => {
write!(f, "Signature should be 64 bytes but is a different size")
}
Self::WrongKeyLength => write!(
f,
"Public keys are expected to be 32 bytes but the public key on this cert is not"
),
Self::WrongPemTag => write!(
f,
"Certificates should have the PEM tag `NEBULA CERTIFICATE`, but this block did not"
),
Self::Expired => write!(
f,
"This certificate either is not yet valid or has already expired"
),
Self::KeyMismatch => write!(f, "Key does not match expected value"),
}
}
}
@ -110,7 +126,10 @@ impl Error for CertificateError {}
fn map_cidr_pairs(pairs: &[u32]) -> Result<Vec<Ipv4Net>, Box<dyn Error>> {
let mut res_vec = vec![];
for pair in pairs.chunks(2) {
res_vec.push(Ipv4Net::with_netmask(Ipv4Addr::from(pair[0]), Ipv4Addr::from(pair[1]))?);
res_vec.push(Ipv4Net::with_netmask(
Ipv4Addr::from(pair[0]),
Ipv4Addr::from(pair[1]),
)?);
}
Ok(res_vec)
}
@ -129,7 +148,11 @@ impl Display for NebulaCertificate {
writeln!(f, " Not after: {:?}", self.details.not_after)?;
writeln!(f, " Is CA: {}", self.details.is_ca)?;
writeln!(f, " Issuer: {}", self.details.issuer)?;
writeln!(f, " Public key: {}", hex::encode(self.details.public_key))?;
writeln!(
f,
" Public key: {}",
hex::encode(self.details.public_key)
)?;
writeln!(f, " }}")?;
writeln!(f, " Fingerprint: {}", self.sha256sum().unwrap())?;
writeln!(f, " Signature: {}", hex::encode(self.signature.clone()))?;
@ -143,7 +166,7 @@ impl Display for NebulaCertificate {
/// # Panics
pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate, Box<dyn Error>> {
if bytes.is_empty() {
return Err(CertificateError::EmptyByteArray.into())
return Err(CertificateError::EmptyByteArray.into());
}
let mut reader = BytesReader::from_bytes(bytes);
@ -153,11 +176,11 @@ pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate,
let details = raw_cert.Details.ok_or(CertificateError::NilDetails)?;
if details.Ips.len() % 2 != 0 {
return Err(CertificateError::IpsNotPairs.into())
return Err(CertificateError::IpsNotPairs.into());
}
if details.Subnets.len() % 2 != 0 {
return Err(CertificateError::SubnetsNotPairs.into())
return Err(CertificateError::SubnetsNotPairs.into());
}
let mut nebula_cert;
@ -168,8 +191,13 @@ pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate,
name: details.Name.to_string(),
ips: map_cidr_pairs(&details.Ips)?,
subnets: map_cidr_pairs(&details.Subnets)?,
groups: details.Groups.iter().map(std::string::ToString::to_string).collect(),
not_before: SystemTime::UNIX_EPOCH.add(Duration::from_secs(details.NotBefore as u64)),
groups: details
.Groups
.iter()
.map(std::string::ToString::to_string)
.collect(),
not_before: SystemTime::UNIX_EPOCH
.add(Duration::from_secs(details.NotBefore as u64)),
not_after: SystemTime::UNIX_EPOCH.add(Duration::from_secs(details.NotAfter as u64)),
public_key: [0u8; 32],
is_ca: details.IsCA,
@ -182,10 +210,13 @@ pub fn deserialize_nebula_certificate(bytes: &[u8]) -> Result<NebulaCertificate,
nebula_cert.signature = raw_cert.Signature;
if details.PublicKey.len() != 32 {
return Err(CertificateError::WrongKeyLength.into())
return Err(CertificateError::WrongKeyLength.into());
}
#[allow(clippy::unwrap_used)] { nebula_cert.details.public_key = details.PublicKey.try_into().unwrap(); }
#[allow(clippy::unwrap_used)]
{
nebula_cert.details.public_key = details.PublicKey.try_into().unwrap();
}
Ok(nebula_cert)
}
@ -199,28 +230,32 @@ pub enum KeyError {
/// Ed25519 private keys are 64 bytes
Not64Bytes,
/// X25519 private keys are 32 bytes
Not32Bytes
Not32Bytes,
}
#[cfg(not(tarpaulin_include))]
impl Display for KeyError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::WrongPemTag => write!(f, "Keys should have their associated PEM tags but this had the wrong one"),
Self::WrongPemTag => write!(
f,
"Keys should have their associated PEM tags but this had the wrong one"
),
Self::Not64Bytes => write!(f, "Ed25519 private keys are 64 bytes"),
Self::Not32Bytes => write!(f, "X25519 private keys are 32 bytes")
Self::Not32Bytes => write!(f, "X25519 private keys are 32 bytes"),
}
}
}
impl Error for KeyError {}
/// Deserialize the first PEM block in the given byte array into a `NebulaCertificate`
/// # Errors
/// This function will return an error if the PEM data is invalid, or if there is an error parsing the certificate (see `deserialize_nebula_certificate`)
pub fn deserialize_nebula_certificate_from_pem(bytes: &[u8]) -> Result<NebulaCertificate, Box<dyn Error>> {
pub fn deserialize_nebula_certificate_from_pem(
bytes: &[u8],
) -> Result<NebulaCertificate, Box<dyn Error>> {
let pem = pem::parse(bytes)?;
if pem.tag != CERT_BANNER {
return Err(CertificateError::WrongPemTag.into())
return Err(CertificateError::WrongPemTag.into());
}
deserialize_nebula_certificate(&pem.contents)
}
@ -230,7 +265,9 @@ pub fn serialize_x25519_private(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem {
tag: X25519_PRIVATE_KEY_BANNER.to_string(),
contents: bytes.to_vec(),
}).as_bytes().to_vec()
})
.as_bytes()
.to_vec()
}
/// Simple helper to PEM encode an X25519 public key
@ -238,7 +275,9 @@ pub fn serialize_x25519_public(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem {
tag: X25519_PUBLIC_KEY_BANNER.to_string(),
contents: bytes.to_vec(),
}).as_bytes().to_vec()
})
.as_bytes()
.to_vec()
}
/// Attempt to deserialize a PEM encoded X25519 private key
@ -247,10 +286,10 @@ pub fn serialize_x25519_public(bytes: &[u8]) -> Vec<u8> {
pub fn deserialize_x25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?;
if pem.tag != X25519_PRIVATE_KEY_BANNER {
return Err(KeyError::WrongPemTag.into())
return Err(KeyError::WrongPemTag.into());
}
if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into())
return Err(KeyError::Not32Bytes.into());
}
Ok(pem.contents)
}
@ -261,10 +300,10 @@ pub fn deserialize_x25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error
pub fn deserialize_x25519_public(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?;
if pem.tag != X25519_PUBLIC_KEY_BANNER {
return Err(KeyError::WrongPemTag.into())
return Err(KeyError::WrongPemTag.into());
}
if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into())
return Err(KeyError::Not32Bytes.into());
}
Ok(pem.contents)
}
@ -274,7 +313,9 @@ pub fn serialize_ed25519_private(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem {
tag: ED25519_PRIVATE_KEY_BANNER.to_string(),
contents: bytes.to_vec(),
}).as_bytes().to_vec()
})
.as_bytes()
.to_vec()
}
/// Simple helper to PEM encode an Ed25519 public key
@ -282,7 +323,9 @@ pub fn serialize_ed25519_public(bytes: &[u8]) -> Vec<u8> {
pem::encode(&Pem {
tag: ED25519_PUBLIC_KEY_BANNER.to_string(),
contents: bytes.to_vec(),
}).as_bytes().to_vec()
})
.as_bytes()
.to_vec()
}
/// Attempt to deserialize a PEM encoded Ed25519 private key
@ -291,10 +334,10 @@ pub fn serialize_ed25519_public(bytes: &[u8]) -> Vec<u8> {
pub fn deserialize_ed25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?;
if pem.tag != ED25519_PRIVATE_KEY_BANNER {
return Err(KeyError::WrongPemTag.into())
return Err(KeyError::WrongPemTag.into());
}
if pem.contents.len() != 64 {
return Err(KeyError::Not64Bytes.into())
return Err(KeyError::Not64Bytes.into());
}
Ok(pem.contents)
}
@ -305,10 +348,10 @@ pub fn deserialize_ed25519_private(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Erro
pub fn deserialize_ed25519_public(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
let pem = pem::parse(bytes)?;
if pem.tag != ED25519_PUBLIC_KEY_BANNER {
return Err(KeyError::WrongPemTag.into())
return Err(KeyError::WrongPemTag.into());
}
if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into())
return Err(KeyError::Not32Bytes.into());
}
Ok(pem.contents)
}
@ -322,10 +365,10 @@ pub fn deserialize_ed25519_public_many(bytes: &[u8]) -> Result<Vec<Vec<u8>>, Box
for pem in pems {
if pem.tag != ED25519_PUBLIC_KEY_BANNER {
return Err(KeyError::WrongPemTag.into())
return Err(KeyError::WrongPemTag.into());
}
if pem.contents.len() != 32 {
return Err(KeyError::Not32Bytes.into())
return Err(KeyError::Not32Bytes.into());
}
keys.push(pem.contents);
}
@ -367,7 +410,11 @@ impl NebulaCertificate {
/// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
/// # Errors
/// This function will return an error if there is an error parsing the cert or the CA pool.
pub fn verify(&self, time: SystemTime, ca_pool: &NebulaCAPool) -> Result<CertificateValidity, Box<dyn Error>> {
pub fn verify(
&self,
time: SystemTime,
ca_pool: &NebulaCAPool,
) -> Result<CertificateValidity, Box<dyn Error>> {
if ca_pool.is_blocklisted(self) {
return Ok(CertificateValidity::Blocklisted);
}
@ -375,15 +422,15 @@ impl NebulaCertificate {
let Some(signer) = ca_pool.get_ca_for_cert(self)? else { return Ok(CertificateValidity::NotSignedByThisCAPool) };
if signer.expired(time) {
return Ok(CertificateValidity::RootCertExpired)
return Ok(CertificateValidity::RootCertExpired);
}
if self.expired(time) {
return Ok(CertificateValidity::CertExpired)
return Ok(CertificateValidity::CertExpired);
}
if !self.check_signature(&VerifyingKey::from_bytes(&signer.details.public_key)?)? {
return Ok(CertificateValidity::BadSignature)
return Ok(CertificateValidity::BadSignature);
}
Ok(self.check_root_constraints(signer))
@ -392,7 +439,10 @@ impl NebulaCertificate {
/// Make sure that this certificate does not break any of the constraints set by the signing certificate
pub fn check_root_constraints(&self, signer: &Self) -> CertificateValidity {
// Make sure this cert doesn't expire after the signer
println!("{:?} {:?}", signer.details.not_before, self.details.not_before);
println!(
"{:?} {:?}",
signer.details.not_before, self.details.not_before
);
if signer.details.not_before < self.details.not_before {
return CertificateValidity::CertExpiresAfterSigner;
}
@ -404,7 +454,10 @@ impl NebulaCertificate {
// If the signer contains a limited set of groups, make sure this cert only has a subset of them
if !signer.details.groups.is_empty() {
println!("root groups: {:?}, child groups: {:?}", signer.details.groups, self.details.groups);
println!(
"root groups: {:?}, child groups: {:?}",
signer.details.groups, self.details.groups
);
for group in &self.details.groups {
if !signer.details.groups.contains(group) {
return CertificateValidity::GroupNotPresentOnSigner;
@ -443,10 +496,9 @@ impl NebulaCertificate {
if self.details.is_ca {
// convert the keys
if key.len() != 64 {
return Err("key not 64-bytes long".into())
return Err("key not 64-bytes long".into());
}
let secret = SigningKey::from_keypair_bytes(key.try_into().unwrap())?;
let pub_key = secret.verifying_key().to_bytes();
if pub_key != self.details.public_key {
@ -457,13 +509,17 @@ impl NebulaCertificate {
}
if key.len() != 32 {
return Err("key not 32-bytes long".into())
return Err("key not 32-bytes long".into());
}
let pubkey_raw = SigningKey::from_bytes(key.try_into()?).verifying_key();
let pubkey = pubkey_raw.as_bytes();
println!("{} {}", hex::encode(pubkey), hex::encode(self.details.public_key));
println!(
"{} {}",
hex::encode(pubkey),
hex::encode(self.details.public_key)
);
if *pubkey != self.details.public_key {
return Err(CertificateError::KeyMismatch.into());
}
@ -471,21 +527,34 @@ impl NebulaCertificate {
Ok(())
}
/// Get a protobuf-ready raw struct, ready for serialization
#[allow(clippy::expect_used)]
#[allow(clippy::cast_possible_wrap)]
/// # Panics
/// This function will panic if time went backwards, or if the certificate contains extremely invalid data.
pub fn get_raw_details(&self) -> RawNebulaCertificateDetails {
let mut raw = RawNebulaCertificateDetails {
Name: self.details.name.clone(),
Ips: vec![],
Subnets: vec![],
Groups: self.details.groups.iter().map(std::convert::Into::into).collect(),
NotBefore: self.details.not_before.duration_since(UNIX_EPOCH).expect("Time went backwards").as_secs() as i64,
NotAfter: self.details.not_after.duration_since(UNIX_EPOCH).expect("Time went backwards").as_secs() as i64,
Groups: self
.details
.groups
.iter()
.map(std::convert::Into::into)
.collect(),
NotBefore: self
.details
.not_before
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs() as i64,
NotAfter: self
.details
.not_after
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs() as i64,
PublicKey: self.details.public_key.into(),
IsCA: self.details.is_ca,
Issuer: hex::decode(&self.details.issuer).expect("Issuer was not a hex-encoded value"),
@ -529,7 +598,9 @@ impl NebulaCertificate {
Ok(pem::encode(&Pem {
tag: CERT_BANNER.to_string(),
contents: pbuf_bytes,
}).as_bytes().to_vec())
})
.as_bytes()
.to_vec())
}
/// Get the fingerprint of this certificate
@ -570,7 +641,7 @@ pub enum CertificateValidity {
/// An IP present on this certificate is not present on the signer's certificate
IPNotPresentOnSigner,
/// A subnet on this certificate is not present on the signer's certificate
SubnetNotPresentOnSigner
SubnetNotPresentOnSigner,
}
fn net_match(cert_ip: Ipv4Net, root_ips: &Vec<Ipv4Net>) -> bool {
@ -580,4 +651,4 @@ fn net_match(cert_ip: Ipv4Net, root_ips: &Vec<Ipv4Net>) -> bool {
}
}
false
}
}

View file

@ -32,7 +32,6 @@
//! // }
//! ```
#![warn(clippy::pedantic)]
#![warn(clippy::nursery)]
#![deny(clippy::unwrap_used)]
@ -46,8 +45,8 @@
#![allow(clippy::module_name_repetitions)]
pub use ed25519_dalek;
pub use x25519_dalek;
pub use rand_core;
pub use x25519_dalek;
extern crate core;
@ -60,4 +59,4 @@ pub(crate) mod cert_codec;
pub mod test;
/// Get the compiled version of trifid-pki.
pub const TRIFID_PKI_VERSION: &str = env!("CARGO_PKG_VERSION");
pub const TRIFID_PKI_VERSION: &str = env!("CARGO_PKG_VERSION");

View file

@ -1,18 +1,24 @@
#![allow(clippy::unwrap_used)]
#![allow(clippy::expect_used)]
use crate::ca::NebulaCAPool;
use crate::cert::{
deserialize_ed25519_private, deserialize_ed25519_public, deserialize_ed25519_public_many,
deserialize_nebula_certificate, deserialize_nebula_certificate_from_pem,
deserialize_x25519_private, deserialize_x25519_public, serialize_ed25519_private,
serialize_ed25519_public, serialize_x25519_private, serialize_x25519_public,
CertificateValidity, NebulaCertificate, NebulaCertificateDetails,
};
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails};
use crate::netmask;
use std::net::Ipv4Addr;
use std::ops::{Add, Sub};
use std::time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH};
use ipnet::Ipv4Net;
use crate::cert::{CertificateValidity, deserialize_ed25519_private, deserialize_ed25519_public, deserialize_ed25519_public_many, deserialize_nebula_certificate, deserialize_nebula_certificate_from_pem, deserialize_x25519_private, deserialize_x25519_public, NebulaCertificate, NebulaCertificateDetails, serialize_ed25519_private, serialize_ed25519_public, serialize_x25519_private, serialize_x25519_public};
use std::str::FromStr;
use ed25519_dalek::{SigningKey, VerifyingKey};
use ipnet::Ipv4Net;
use quick_protobuf::{MessageWrite, Writer};
use rand::rngs::OsRng;
use crate::ca::{NebulaCAPool};
use crate::cert_codec::{RawNebulaCertificate, RawNebulaCertificateDetails};
use std::net::Ipv4Addr;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH};
/// This is a cert that we (e3team) actually use in production, and it's a known-good certificate.
pub const KNOWN_GOOD_CERT: &[u8; 258] = b"-----BEGIN NEBULA CERTIFICATE-----\nCkkKF2UzdGVhbSBJbnRlcm5hbCBOZXR3b3JrKJWev5wGMJWFxKsGOiCvpwoHyKY5\n8Q5+2XxDjtoCf/zlNY/EUdB8bwXQSwEo50ABEkB0Dx76lkMqc3IyH5+ml2dKjTyv\nB4Jiw6x3abf5YZcf8rDuVEgQpvFdJmo3xJyIb3C9vKZ6kXsUxjw6s1JdWgkA\n-----END NEBULA CERTIFICATE-----";
@ -29,14 +35,18 @@ fn certificate_serialization() {
ips: vec![
netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0")
netmask!("10.1.1.3", "255.0.0.0"),
],
subnets: vec![
netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0")
netmask!("9.1.1.3", "255.255.0.0"),
],
groups: vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
],
groups: vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()],
not_before: before,
not_after: after,
public_key: *pub_key,
@ -59,17 +69,29 @@ fn certificate_serialization() {
assert_eq!(cert.details.ips.len(), deserialized.details.ips.len());
for item in &cert.details.ips {
assert!(deserialized.details.ips.contains(item), "deserialized does not contain from source");
assert!(
deserialized.details.ips.contains(item),
"deserialized does not contain from source"
);
}
assert_eq!(cert.details.subnets.len(), deserialized.details.subnets.len());
assert_eq!(
cert.details.subnets.len(),
deserialized.details.subnets.len()
);
for item in &cert.details.subnets {
assert!(deserialized.details.subnets.contains(item), "deserialized does not contain from source");
assert!(
deserialized.details.subnets.contains(item),
"deserialized does not contain from source"
);
}
assert_eq!(cert.details.groups.len(), deserialized.details.groups.len());
for item in &cert.details.groups {
assert!(deserialized.details.groups.contains(item), "deserialized does not contain from source");
assert!(
deserialized.details.groups.contains(item),
"deserialized does not contain from source"
);
}
}
@ -85,14 +107,18 @@ fn certificate_serialization_pem() {
ips: vec![
netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0")
netmask!("10.1.1.3", "255.0.0.0"),
],
subnets: vec![
netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0")
netmask!("9.1.1.3", "255.255.0.0"),
],
groups: vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
],
groups: vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()],
not_before: before,
not_after: after,
public_key: *pub_key,
@ -115,17 +141,29 @@ fn certificate_serialization_pem() {
assert_eq!(cert.details.ips.len(), deserialized.details.ips.len());
for item in &cert.details.ips {
assert!(deserialized.details.ips.contains(item), "deserialized does not contain from source");
assert!(
deserialized.details.ips.contains(item),
"deserialized does not contain from source"
);
}
assert_eq!(cert.details.subnets.len(), deserialized.details.subnets.len());
assert_eq!(
cert.details.subnets.len(),
deserialized.details.subnets.len()
);
for item in &cert.details.subnets {
assert!(deserialized.details.subnets.contains(item), "deserialized does not contain from source");
assert!(
deserialized.details.subnets.contains(item),
"deserialized does not contain from source"
);
}
assert_eq!(cert.details.groups.len(), deserialized.details.groups.len());
for item in &cert.details.groups {
assert!(deserialized.details.groups.contains(item), "deserialized does not contain from source");
assert!(
deserialized.details.groups.contains(item),
"deserialized does not contain from source"
);
}
}
@ -141,14 +179,18 @@ fn cert_signing() {
ips: vec![
netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0")
netmask!("10.1.1.3", "255.0.0.0"),
],
subnets: vec![
netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0")
netmask!("9.1.1.3", "255.255.0.0"),
],
groups: vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
],
groups: vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()],
not_before: before,
not_after: after,
public_key: *pub_key,
@ -287,9 +329,15 @@ fn cert_deserialize_wrong_pubkey_len() {
#[test]
fn x25519_serialization() {
let bytes = [0u8; 32];
assert_eq!(deserialize_x25519_private(&serialize_x25519_private(&bytes)).unwrap(), bytes);
assert_eq!(
deserialize_x25519_private(&serialize_x25519_private(&bytes)).unwrap(),
bytes
);
assert!(deserialize_x25519_private(&[0u8; 32]).is_err());
assert_eq!(deserialize_x25519_public(&serialize_x25519_public(&bytes)).unwrap(), bytes);
assert_eq!(
deserialize_x25519_public(&serialize_x25519_public(&bytes)).unwrap(),
bytes
);
assert!(deserialize_x25519_public(&[0u8; 32]).is_err());
}
@ -297,9 +345,15 @@ fn x25519_serialization() {
fn ed25519_serialization() {
let bytes = [0u8; 64];
let bytes2 = [0u8; 32];
assert_eq!(deserialize_ed25519_private(&serialize_ed25519_private(&bytes)).unwrap(), bytes);
assert_eq!(
deserialize_ed25519_private(&serialize_ed25519_private(&bytes)).unwrap(),
bytes
);
assert!(deserialize_ed25519_private(&[0u8; 32]).is_err());
assert_eq!(deserialize_ed25519_public(&serialize_ed25519_public(&bytes2)).unwrap(), bytes2);
assert_eq!(
deserialize_ed25519_public(&serialize_ed25519_public(&bytes2)).unwrap(),
bytes2
);
assert!(deserialize_ed25519_public(&[0u8; 64]).is_err());
let mut bytes = vec![];
@ -315,29 +369,87 @@ fn ed25519_serialization() {
#[test]
fn cert_verify() {
let (ca_cert, ca_key, _ca_pub) = test_ca_cert(round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(), vec![], vec![], vec!["groupa".to_string()]);
let (ca_cert, ca_key, _ca_pub) = test_ca_cert(
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(),
vec![],
vec![],
vec!["groupa".to_string()],
);
let (cert, _, _) = test_cert(&ca_cert, &ca_key, SystemTime::now(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![], vec![]);
let (cert, _, _) = test_cert(
&ca_cert,
&ca_key,
SystemTime::now(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![],
vec![],
);
let mut ca_pool = NebulaCAPool::new();
ca_pool.add_ca_certificate(&ca_cert.serialize_to_pem().unwrap()).unwrap();
ca_pool
.add_ca_certificate(&ca_cert.serialize_to_pem().unwrap())
.unwrap();
let fingerprint = cert.sha256sum().unwrap();
ca_pool.blocklist_fingerprint(&fingerprint);
assert!(matches!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Blocklisted));
assert!(matches!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Blocklisted
));
ca_pool.reset_blocklist();
assert!(matches!(cert.verify(SystemTime::now() + Duration::from_secs(60 * 60 * 60), &ca_pool).unwrap(), CertificateValidity::RootCertExpired));
assert!(matches!(cert.verify(SystemTime::now() + Duration::from_secs(60 * 60 * 6), &ca_pool).unwrap(), CertificateValidity::CertExpired));
assert!(matches!(
cert.verify(
SystemTime::now() + Duration::from_secs(60 * 60 * 60),
&ca_pool
)
.unwrap(),
CertificateValidity::RootCertExpired
));
assert!(matches!(
cert.verify(
SystemTime::now() + Duration::from_secs(60 * 60 * 6),
&ca_pool
)
.unwrap(),
CertificateValidity::CertExpired
));
let (cert_with_bad_group, _, _) = test_cert(&ca_cert, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), in_a_minute(), vec![], vec![], vec!["group-not-present on parent".to_string()]);
assert_eq!(cert_with_bad_group.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::GroupNotPresentOnSigner);
let (cert_with_good_group, _, _) = test_cert(&ca_cert, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), in_a_minute(), vec![], vec![], vec!["groupa".to_string()]);
assert_eq!(cert_with_good_group.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
let (cert_with_bad_group, _, _) = test_cert(
&ca_cert,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
in_a_minute(),
vec![],
vec![],
vec!["group-not-present on parent".to_string()],
);
assert_eq!(
cert_with_bad_group
.verify(SystemTime::now(), &ca_pool)
.unwrap(),
CertificateValidity::GroupNotPresentOnSigner
);
let (cert_with_good_group, _, _) = test_cert(
&ca_cert,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
in_a_minute(),
vec![],
vec![],
vec!["groupa".to_string()],
);
assert_eq!(
cert_with_good_group
.verify(SystemTime::now(), &ca_pool)
.unwrap(),
CertificateValidity::Ok
);
}
#[test]
@ -345,7 +457,13 @@ fn cert_verify_ip() {
let ca_ip_1 = Ipv4Net::from_str("10.0.0.0/16").unwrap();
let ca_ip_2 = Ipv4Net::from_str("192.168.0.0/24").unwrap();
let (ca, ca_key, _ca_pub) = test_ca_cert(round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(), vec![ca_ip_1, ca_ip_2], vec![], vec![]);
let (ca, ca_key, _ca_pub) = test_ca_cert(
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(),
vec![ca_ip_1, ca_ip_2],
vec![],
vec![],
);
let ca_pem = ca.serialize_to_pem().unwrap();
@ -355,49 +473,137 @@ fn cert_verify_ip() {
// ip is outside the network
let cip1 = netmask!("10.1.0.0", "255.255.255.0");
let cip2 = netmask!("192.198.0.1", "255.255.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip is outside the network - reversed order from above
let cip1 = netmask!("192.198.0.1", "255.255.255.0");
let cip2 = netmask!("10.1.0.0", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip is within the network but mask is outside
let cip1 = netmask!("10.0.1.0", "255.254.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip is within the network but mask is outside - reversed order from above
let cip1 = netmask!("192.168.0.1", "255.255.255.0");
let cip2 = netmask!("10.0.1.0", "255.254.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::IPNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::IPNotPresentOnSigner
);
// ip and mask are within the network
let cip1 = netmask!("10.0.1.0", "255.255.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.128");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![ca_ip_1, ca_ip_2], vec![], vec![]);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![ca_ip_1, ca_ip_2],
vec![],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![ca_ip_2, ca_ip_1], vec![], vec![]);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![ca_ip_2, ca_ip_1],
vec![],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed with just one
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![ca_ip_2], vec![], vec![]);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![ca_ip_2],
vec![],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
}
#[test]
@ -405,7 +611,13 @@ fn cert_verify_subnet() {
let ca_ip_1 = Ipv4Net::from_str("10.0.0.0/16").unwrap();
let ca_ip_2 = Ipv4Net::from_str("192.168.0.0/24").unwrap();
let (ca, ca_key, _ca_pub) = test_ca_cert(round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(), vec![],vec![ca_ip_1, ca_ip_2], vec![]);
let (ca, ca_key, _ca_pub) = test_ca_cert(
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 10)).unwrap(),
vec![],
vec![ca_ip_1, ca_ip_2],
vec![],
);
let ca_pem = ca.serialize_to_pem().unwrap();
@ -415,63 +627,170 @@ fn cert_verify_subnet() {
// ip is outside the network
let cip1 = netmask!("10.1.0.0", "255.255.255.0");
let cip2 = netmask!("192.198.0.1", "255.255.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![cip1, cip2], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip is outside the network - reversed order from above
let cip1 = netmask!("192.198.0.1", "255.255.255.0");
let cip2 = netmask!("10.1.0.0", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![cip1, cip2], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip is within the network but mask is outside
let cip1 = netmask!("10.0.1.0", "255.254.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![cip1, cip2], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip is within the network but mask is outside - reversed order from above
let cip1 = netmask!("192.168.0.1", "255.255.255.0");
let cip2 = netmask!("10.0.1.0", "255.254.0.0");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![cip1, cip2], vec![], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![cip1, cip2],
vec![],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::SubnetNotPresentOnSigner);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::SubnetNotPresentOnSigner
);
// ip and mask are within the network
let cip1 = netmask!("10.0.1.0", "255.255.0.0");
let cip2 = netmask!("192.168.0.1", "255.255.255.128");
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![cip1, cip2], vec![]);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![cip1, cip2],
vec![],
);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![],vec![ca_ip_1, ca_ip_2], vec![]);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![ca_ip_1, ca_ip_2],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![ca_ip_2, ca_ip_1], vec![]);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![ca_ip_2, ca_ip_1],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
// Exact matches reversed with just one
let (cert, _, _) = test_cert(&ca, &ca_key, round_systime_to_secs(SystemTime::now()).unwrap(), round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(), vec![], vec![ca_ip_2], vec![]);
assert_eq!(cert.verify(SystemTime::now(), &ca_pool).unwrap(), CertificateValidity::Ok);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
round_systime_to_secs(SystemTime::now()).unwrap(),
round_systime_to_secs(SystemTime::now() + Duration::from_secs(60 * 60 * 5)).unwrap(),
vec![],
vec![ca_ip_2],
vec![],
);
assert_eq!(
cert.verify(SystemTime::now(), &ca_pool).unwrap(),
CertificateValidity::Ok
);
}
#[test]
fn cert_private_key() {
let (ca, ca_key, _) = test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
let (ca, ca_key, _) =
test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
ca.verify_private_key(&ca_key.to_keypair_bytes()).unwrap();
let (_, ca_key2, _) = test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
ca.verify_private_key(&ca_key2.to_keypair_bytes()).unwrap_err();
let (_, ca_key2, _) =
test_ca_cert(SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
ca.verify_private_key(&ca_key2.to_keypair_bytes())
.unwrap_err();
let (cert, priv_key, _) = test_cert(&ca, &ca_key, SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
let (cert, priv_key, _) = test_cert(
&ca,
&ca_key,
SystemTime::now(),
SystemTime::now(),
vec![],
vec![],
vec![],
);
cert.verify_private_key(&priv_key.to_bytes()).unwrap();
let (cert2, _, _) = test_cert(&ca, &ca_key, SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
let (cert2, _, _) = test_cert(
&ca,
&ca_key,
SystemTime::now(),
SystemTime::now(),
vec![],
vec![],
vec![],
);
cert2.verify_private_key(&priv_key.to_bytes()).unwrap_err();
}
@ -489,7 +808,8 @@ CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
-----END NEBULA CERTIFICATE-----";
let with_newlines = b"# Current provisional, Remove once everything moves over to the real root.
let with_newlines =
b"# Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
@ -511,24 +831,64 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
-----END NEBULA CERTIFICATE-----";
let pool_a = NebulaCAPool::new_from_pem(no_newlines).unwrap();
assert_eq!(pool_a.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"].details.name, "nebula root ca".to_string());
assert_eq!(pool_a.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"].details.name, "nebula root ca 01".to_string());
assert_eq!(
pool_a.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"]
.details
.name,
"nebula root ca".to_string()
);
assert_eq!(
pool_a.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"]
.details
.name,
"nebula root ca 01".to_string()
);
assert!(!pool_a.expired);
let pool_b = NebulaCAPool::new_from_pem(with_newlines).unwrap();
assert_eq!(pool_b.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"].details.name, "nebula root ca".to_string());
assert_eq!(pool_b.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"].details.name, "nebula root ca 01".to_string());
assert_eq!(
pool_b.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"]
.details
.name,
"nebula root ca".to_string()
);
assert_eq!(
pool_b.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"]
.details
.name,
"nebula root ca 01".to_string()
);
assert!(!pool_b.expired);
let pool_c = NebulaCAPool::new_from_pem(expired).unwrap();
assert!(pool_c.expired);
assert_eq!(pool_c.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"].details.name, "expired");
assert_eq!(
pool_c.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"]
.details
.name,
"expired"
);
let mut pool_d = NebulaCAPool::new_from_pem(with_newlines).unwrap();
pool_d.add_ca_certificate(expired).unwrap();
assert_eq!(pool_d.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"].details.name, "nebula root ca".to_string());
assert_eq!(pool_d.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"].details.name, "nebula root ca 01".to_string());
assert_eq!(pool_d.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"].details.name, "expired");
assert_eq!(
pool_d.cas["c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522"]
.details
.name,
"nebula root ca".to_string()
);
assert_eq!(
pool_d.cas["5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd"]
.details
.name,
"nebula root ca 01".to_string()
);
assert_eq!(
pool_d.cas["152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0"]
.details
.name,
"expired"
);
assert!(pool_d.expired);
assert_eq!(pool_d.get_fingerprints().len(), 3);
}
@ -536,7 +896,11 @@ WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
#[macro_export]
macro_rules! netmask {
($ip:expr,$mask:expr) => {
Ipv4Net::with_netmask(Ipv4Addr::from_str($ip).unwrap(), Ipv4Addr::from_str($mask).unwrap()).unwrap()
Ipv4Net::with_netmask(
Ipv4Addr::from_str($ip).unwrap(),
Ipv4Addr::from_str($mask).unwrap(),
)
.unwrap()
};
}
@ -545,7 +909,13 @@ fn round_systime_to_secs(time: SystemTime) -> Result<SystemTime, SystemTimeError
Ok(SystemTime::UNIX_EPOCH.add(Duration::from_secs(secs)))
}
fn test_ca_cert(before: SystemTime, after: SystemTime, ips: Vec<Ipv4Net>, subnets: Vec<Ipv4Net>, groups: Vec<String>) -> (NebulaCertificate, SigningKey, VerifyingKey) {
fn test_ca_cert(
before: SystemTime,
after: SystemTime,
ips: Vec<Ipv4Net>,
subnets: Vec<Ipv4Net>,
groups: Vec<String>,
) -> (NebulaCertificate, SigningKey, VerifyingKey) {
let mut csprng = OsRng;
let key = SigningKey::generate(&mut csprng);
let pub_key = key.verifying_key();
@ -710,17 +1080,45 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
fn ca_pool_add_non_ca() {
let mut ca_pool = NebulaCAPool::new();
let (ca, ca_key, _) = test_ca_cert(SystemTime::now(), SystemTime::now() + Duration::from_secs(3600), vec![], vec![], vec![]);
let (cert, _, _) = test_cert(&ca, &ca_key, SystemTime::now(), SystemTime::now(), vec![], vec![], vec![]);
let (ca, ca_key, _) = test_ca_cert(
SystemTime::now(),
SystemTime::now() + Duration::from_secs(3600),
vec![],
vec![],
vec![],
);
let (cert, _, _) = test_cert(
&ca,
&ca_key,
SystemTime::now(),
SystemTime::now(),
vec![],
vec![],
vec![],
);
ca_pool.add_ca_certificate(&cert.serialize_to_pem().unwrap()).unwrap_err();
ca_pool
.add_ca_certificate(&cert.serialize_to_pem().unwrap())
.unwrap_err();
}
fn test_cert(ca: &NebulaCertificate, key: &SigningKey, before: SystemTime, after: SystemTime, ips: Vec<Ipv4Net>, subnets: Vec<Ipv4Net>, groups: Vec<String>) -> (NebulaCertificate, SigningKey, VerifyingKey) {
fn test_cert(
ca: &NebulaCertificate,
key: &SigningKey,
before: SystemTime,
after: SystemTime,
ips: Vec<Ipv4Net>,
subnets: Vec<Ipv4Net>,
groups: Vec<String>,
) -> (NebulaCertificate, SigningKey, VerifyingKey) {
let issuer = ca.sha256sum().unwrap();
let real_groups = if groups.is_empty() {
vec!["test-group1".to_string(), "test-group2".to_string(), "test-group3".to_string()]
vec![
"test-group1".to_string(),
"test-group2".to_string(),
"test-group3".to_string(),
]
} else {
groups
};
@ -729,7 +1127,7 @@ fn test_cert(ca: &NebulaCertificate, key: &SigningKey, before: SystemTime, after
vec![
netmask!("10.1.1.1", "255.255.255.0"),
netmask!("10.1.1.2", "255.255.0.0"),
netmask!("10.1.1.3", "255.0.0.0")
netmask!("10.1.1.3", "255.0.0.0"),
]
} else {
ips
@ -739,7 +1137,7 @@ fn test_cert(ca: &NebulaCertificate, key: &SigningKey, before: SystemTime, after
vec![
netmask!("9.1.1.1", "255.255.255.128"),
netmask!("9.1.1.2", "255.255.255.0"),
netmask!("9.1.1.3", "255.255.0.0")
netmask!("9.1.1.3", "255.255.0.0"),
]
} else {
subnets
@ -774,4 +1172,4 @@ fn in_a_minute() -> SystemTime {
#[allow(dead_code)]
fn a_minute_ago() -> SystemTime {
round_systime_to_secs(SystemTime::now().sub(Duration::from_secs(60))).unwrap()
}
}