0.3 error fixes, refmt
/ build (push) Successful in 49s Details
/ build_x64 (push) Successful in 2m6s Details
/ build_arm64 (push) Successful in 2m40s Details
/ build_win64 (push) Successful in 2m36s Details

This commit is contained in:
core 2023-11-23 15:23:52 -05:00
parent 2a5a2bb910
commit 646340b637
Signed by: core
GPG Key ID: FDBF740DADDCEECF
41 changed files with 1441 additions and 878 deletions

View File

@ -254,7 +254,10 @@ impl Client {
ed_privkey.verify(b64_msg_bytes, &Signature::from_slice(&signature)?)?; ed_privkey.verify(b64_msg_bytes, &Signature::from_slice(&signature)?)?;
debug!("signature valid via clientside check"); debug!("signature valid via clientside check");
debug!("signed with key: {:x?}", ed_privkey.verifying_key().as_bytes()); debug!(
"signed with key: {:x?}",
ed_privkey.verifying_key().as_bytes()
);
let body = RequestV1 { let body = RequestV1 {
version: 1, version: 1,

View File

@ -0,0 +1,18 @@
# API Support
This document is valid for: `trifid-api 0.3.0`.
This document is only useful for developers, and it lists what endpoint versions are currently supported by trifid-api. This is subject to change at any time according to SemVer constraints.
Endpoint types:
- **Documented** is an endpoint available in the official documentation
- **Reverse-engineered** is an endpoint that was reverse-engineered
| Endpoint Name | Version | Endpoint | Type | Added In |
|---------------------------|---------|------------------------------------|--------------------|----------------|
| Signup | v1 | POST /v1/signup | Reverse-engineered | 0.3.0/79b1765e |
| Get Magic Link | v1 | POST /v1/auth/magic-link | Reverse-engineered | 0.3.0/52049947 |
| Verify Magic Link | v1 | POST /v1/auth/verify-magic-link | Reverse-engineered | 0.3.0/51b6d3a8 |
| Create TOTP Authenticator | v1 | POST /v1/totp-authenticators | Reverse-engineered | 0.3.0/4180bdd1 |
| Verify TOTP Authenticator | v1 | POST /v1/verify-totp-authenticator | Reverse-engineered | 0.3.0/19332e51 |
| Authenticate with TOTP | v1 | POST /v1/auth/totp | Reverse-engineered | 0.3.0/19332e51 |

View File

@ -1,7 +1,7 @@
use std::{env, process};
use std::path::PathBuf;
use bindgen::CargoCallbacks; use bindgen::CargoCallbacks;
use std::path::Path; use std::path::Path;
use std::path::PathBuf;
use std::{env, process};
fn get_cargo_target_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> { fn get_cargo_target_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?); let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?);
@ -20,7 +20,6 @@ fn get_cargo_target_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Erro
} }
fn main() { fn main() {
// Find compiler: // Find compiler:
// 1. GOC // 1. GOC
// 2. /usr/local/go/bin/go // 2. /usr/local/go/bin/go
@ -49,7 +48,14 @@ fn main() {
let out = out_path.join(out_file); let out = out_path.join(out_file);
let mut command = process::Command::new(compiler); let mut command = process::Command::new(compiler);
command.args(["build", "-buildmode", link_type().as_str(), "-o", out.display().to_string().as_str(), "main.go"]); command.args([
"build",
"-buildmode",
link_type().as_str(),
"-o",
out.display().to_string().as_str(),
"main.go",
]);
command.env("CGO_ENABLED", "1"); command.env("CGO_ENABLED", "1");
command.env("CC", c_compiler.path()); command.env("CC", c_compiler.path());
command.env("GOARCH", goarch()); command.env("GOARCH", goarch());
@ -68,7 +74,10 @@ fn main() {
copy_if_windows(); copy_if_windows();
print_link(); print_link();
println!("cargo:rustc-link-search=native={}", env::var("OUT_DIR").unwrap()); println!(
"cargo:rustc-link-search=native={}",
env::var("OUT_DIR").unwrap()
);
//let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); //let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
@ -85,7 +94,6 @@ fn main() {
.generate() .generate()
.expect("Error generating CFFI bindings"); .expect("Error generating CFFI bindings");
bindings bindings
.write_to_file(out_path.join("bindings.rs")) .write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!"); .expect("Couldn't write bindings!");
@ -125,8 +133,9 @@ fn goarch() -> String {
"powerpc64" => "ppc64", "powerpc64" => "ppc64",
"arm" => "arm", "arm" => "arm",
"aarch64" => "arm64", "aarch64" => "arm64",
arch => panic!("unsupported architecture {arch}") arch => panic!("unsupported architecture {arch}"),
}.to_string() }
.to_string()
} }
fn goos() -> String { fn goos() -> String {
match env::var("CARGO_CFG_TARGET_OS").unwrap().as_str() { match env::var("CARGO_CFG_TARGET_OS").unwrap().as_str() {
@ -139,8 +148,9 @@ fn goos() -> String {
"dragonfly" => "dragonfly", "dragonfly" => "dragonfly",
"openbsd" => "openbsd", "openbsd" => "openbsd",
"netbsd" => "netbsd", "netbsd" => "netbsd",
os => panic!("unsupported operating system {os}") os => panic!("unsupported operating system {os}"),
}.to_string() }
.to_string()
} }
fn print_link() { fn print_link() {

View File

@ -25,7 +25,6 @@
#![deny(clippy::missing_panics_doc)] #![deny(clippy::missing_panics_doc)]
#![deny(clippy::missing_safety_doc)] #![deny(clippy::missing_safety_doc)]
#[allow(non_upper_case_globals)] #[allow(non_upper_case_globals)]
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[allow(non_snake_case)] #[allow(non_snake_case)]
@ -36,12 +35,11 @@ pub mod generated {
include!(concat!(env!("OUT_DIR"), "/bindings.rs")); include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
} }
use generated::GoString;
use std::error::Error; use std::error::Error;
use std::ffi::{c_char, CString}; use std::ffi::{c_char, CString};
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::path::{Path}; use std::path::Path;
use generated::GoString;
impl From<&str> for GoString { impl From<&str> for GoString {
#[allow(clippy::cast_possible_wrap)] #[allow(clippy::cast_possible_wrap)]
@ -51,7 +49,7 @@ impl From<&str> for GoString {
let ptr = c_str.as_ptr(); let ptr = c_str.as_ptr();
let go_string = GoString { let go_string = GoString {
p: ptr, p: ptr,
n: c_str.as_bytes().len() as isize n: c_str.as_bytes().len() as isize,
}; };
go_string go_string
} }
@ -73,14 +71,18 @@ impl NebulaInstance {
/// # Panics /// # Panics
/// This function will panic if memory is corrupted while communicating with Go. /// This function will panic if memory is corrupted while communicating with Go.
pub fn new(config_path: &Path, config_test: bool) -> Result<Self, Box<dyn Error>> { pub fn new(config_path: &Path, config_test: bool) -> Result<Self, Box<dyn Error>> {
let mut config_path_bytes = unsafe { config_path.display().to_string().as_bytes_mut().to_vec() }; let mut config_path_bytes =
unsafe { config_path.display().to_string().as_bytes_mut().to_vec() };
config_path_bytes.push(0u8); config_path_bytes.push(0u8);
let config_test_u8 = u8::from(config_test); let config_test_u8 = u8::from(config_test);
let res; let res;
unsafe { unsafe {
res = generated::NebulaSetup(config_path_bytes.as_mut_ptr().cast::<c_char>(), config_test_u8); res = generated::NebulaSetup(
config_path_bytes.as_mut_ptr().cast::<c_char>(),
config_test_u8,
);
} }
let res = cstring_to_string(res); let res = cstring_to_string(res);
@ -194,18 +196,18 @@ pub enum NebulaError {
/// Returned by nebula when the TUN/TAP device already exists /// Returned by nebula when the TUN/TAP device already exists
DeviceOrResourceBusy { DeviceOrResourceBusy {
/// The complete error string returned by the Nebula wrapper /// The complete error string returned by the Nebula wrapper
error_str: String error_str: String,
}, },
/// An unknown error that the error parser couldn't figure out how to parse. /// An unknown error that the error parser couldn't figure out how to parse.
Unknown { Unknown {
/// The complete error string returned by the Nebula wrapper /// The complete error string returned by the Nebula wrapper
error_str: String error_str: String,
}, },
/// Occurs if you call a function before NebulaSetup has been called /// Occurs if you call a function before NebulaSetup has been called
NebulaNotSetup { NebulaNotSetup {
/// The complete error string returned by the Nebula wrapper /// The complete error string returned by the Nebula wrapper
error_str: String error_str: String,
} },
} }
impl Display for NebulaError { impl Display for NebulaError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
@ -223,11 +225,17 @@ impl NebulaError {
#[must_use] #[must_use]
pub fn from_string(string: &str) -> Self { pub fn from_string(string: &str) -> Self {
if string.starts_with("device or resource busy") { if string.starts_with("device or resource busy") {
Self::DeviceOrResourceBusy { error_str: string.to_string() } Self::DeviceOrResourceBusy {
error_str: string.to_string(),
}
} else if string.starts_with("NebulaSetup has not yet been called") { } else if string.starts_with("NebulaSetup has not yet been called") {
Self::NebulaNotSetup { error_str: string.to_string() } Self::NebulaNotSetup {
error_str: string.to_string(),
}
} else { } else {
Self::Unknown { error_str: string.to_string() } Self::Unknown {
error_str: string.to_string(),
}
} }
} }
} }

View File

@ -1,38 +1,48 @@
use crate::api::APIErrorResponse;
use crate::AccountCommands;
use serde::{Deserialize, Serialize};
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
use crate::AccountCommands;
use crate::api::APIErrorResponse;
pub async fn account_main(command: AccountCommands, server: Url) -> Result<(), Box<dyn Error>> { pub async fn account_main(command: AccountCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command { match command {
AccountCommands::Create { email } => create_account(email, server).await, AccountCommands::Create { email } => create_account(email, server).await,
AccountCommands::MagicLink { magic_link_token } => auth_magic_link(magic_link_token, server).await, AccountCommands::MagicLink { magic_link_token } => {
auth_magic_link(magic_link_token, server).await
}
AccountCommands::MfaSetup {} => create_mfa_authenticator(server).await, AccountCommands::MfaSetup {} => create_mfa_authenticator(server).await,
AccountCommands::MfaSetupFinish {code, token} => finish_mfa_authenticator(token, code, server).await, AccountCommands::MfaSetupFinish { code, token } => {
AccountCommands::Mfa {code} => mfa_auth(code, server).await, finish_mfa_authenticator(token, code, server).await
AccountCommands::Login { email } => login_account(email, server).await }
AccountCommands::Mfa { code } => mfa_auth(code, server).await,
AccountCommands::Login { email } => login_account(email, server).await,
} }
} }
#[derive(Serialize)] #[derive(Serialize)]
pub struct CreateAccountBody { pub struct CreateAccountBody {
pub email: String pub email: String,
} }
pub async fn create_account(email: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn create_account(email: String, server: Url) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let res = client.post(server.join("/v1/signup")?).json(&CreateAccountBody { email }).send().await?; let res = client
.post(server.join("/v1/signup")?)
.json(&CreateAccountBody { email })
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
println!("Account created successfully, check your email."); println!("Account created successfully, check your email.");
println!("Finish creating your account with 'tfcli account magic-link --magic-link-token [magic-link-token]'."); println!("Finish creating your account with 'tfcli account magic-link --magic-link-token [magic-link-token]'.");
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating account: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error creating account: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -42,21 +52,27 @@ pub async fn create_account(email: String, server: Url) -> Result<(), Box<dyn Er
#[derive(Serialize)] #[derive(Serialize)]
pub struct LoginAccountBody { pub struct LoginAccountBody {
pub email: String pub email: String,
} }
pub async fn login_account(email: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn login_account(email: String, server: Url) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let res = client.post(server.join("/v1/auth/magic-link")?).json(&LoginAccountBody { email }).send().await?; let res = client
.post(server.join("/v1/auth/magic-link")?)
.json(&LoginAccountBody { email })
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
println!("Magic link sent, check your email."); println!("Magic link sent, check your email.");
println!("Finish creating your account with 'tfcli account magic-link --magic-link-token [magic-link-token]'."); println!("Finish creating your account with 'tfcli account magic-link --magic-link-token [magic-link-token]'.");
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error logging in: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error logging in: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -64,26 +80,31 @@ pub async fn login_account(email: String, server: Url) -> Result<(), Box<dyn Err
Ok(()) Ok(())
} }
#[derive(Serialize)] #[derive(Serialize)]
pub struct MagicLinkBody { pub struct MagicLinkBody {
#[serde(rename = "magicLinkToken")] #[serde(rename = "magicLinkToken")]
pub magic_link_token: String pub magic_link_token: String,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct MagicLinkSuccess { pub struct MagicLinkSuccess {
pub data: MagicLinkSuccessBody pub data: MagicLinkSuccessBody,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct MagicLinkSuccessBody { pub struct MagicLinkSuccessBody {
#[serde(rename = "sessionToken")] #[serde(rename = "sessionToken")]
pub session_token: String pub session_token: String,
} }
pub async fn auth_magic_link(magic_token: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn auth_magic_link(magic_token: String, server: Url) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let res = client.post(server.join("/v1/auth/verify-magic-link")?).json(&MagicLinkBody { magic_link_token: magic_token }).send().await?; let res = client
.post(server.join("/v1/auth/verify-magic-link")?)
.json(&MagicLinkBody {
magic_link_token: magic_token,
})
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: MagicLinkSuccess = res.json().await?; let resp: MagicLinkSuccess = res.json().await?;
@ -97,7 +118,10 @@ pub async fn auth_magic_link(magic_token: String, server: Url) -> Result<(), Box
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error getting session token: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error getting session token: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -135,14 +159,14 @@ pub struct WhoamiResponseMetadata {}
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct CreateMfaResponse { pub struct CreateMfaResponse {
pub data: CreateMfaResponseData pub data: CreateMfaResponseData,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct CreateMfaResponseData { pub struct CreateMfaResponseData {
#[serde(rename = "totpToken")] #[serde(rename = "totpToken")]
pub totp_token: String, pub totp_token: String,
pub secret: String, pub secret: String,
pub url: String pub url: String,
} }
pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>> { pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>> {
@ -153,14 +177,25 @@ pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>>
let session_token = fs::read_to_string(&token_store)?; let session_token = fs::read_to_string(&token_store)?;
// do we have mfa already? // do we have mfa already?
let whoami: WhoamiResponse = client.get(server.join("/v2/whoami")?).bearer_auth(&session_token).send().await?.json().await?; let whoami: WhoamiResponse = client
.get(server.join("/v2/whoami")?)
.bearer_auth(&session_token)
.send()
.await?
.json()
.await?;
if whoami.data.actor.has_totp_authenticator { if whoami.data.actor.has_totp_authenticator {
eprintln!("[error] user already has a totp authenticator, cannot add another one"); eprintln!("[error] user already has a totp authenticator, cannot add another one");
std::process::exit(1); std::process::exit(1);
} }
let res = client.post(server.join("/v1/totp-authenticators")?).bearer_auth(&session_token).body("{}").send().await?; let res = client
.post(server.join("/v1/totp-authenticators")?)
.bearer_auth(&session_token)
.body("{}")
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: CreateMfaResponse = res.json().await?; let resp: CreateMfaResponse = res.json().await?;
@ -169,14 +204,23 @@ pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>>
println!("To complete setup, you'll need a TOTP-compatible app, such as Google Authenticator or Authy."); println!("To complete setup, you'll need a TOTP-compatible app, such as Google Authenticator or Authy.");
println!("Scan the following code with your authenticator app:"); println!("Scan the following code with your authenticator app:");
qr2term::print_qr(resp.data.url)?; qr2term::print_qr(resp.data.url)?;
println!("Alternatively, enter the following secret into your authenticator app: '{}'", resp.data.secret); println!(
"Alternatively, enter the following secret into your authenticator app: '{}'",
resp.data.secret
);
println!("Once done, enable TOTP by running the following command with the code shown on your authenticator app:"); println!("Once done, enable TOTP by running the following command with the code shown on your authenticator app:");
println!("tfcli account mfa-setup-finish --token {} --code [CODE IN AUTHENTICATOR]", resp.data.totp_token); println!(
"tfcli account mfa-setup-finish --token {} --code [CODE IN AUTHENTICATOR]",
resp.data.totp_token
);
println!("This code will expire in 10 minutes."); println!("This code will expire in 10 minutes.");
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error adding MFA to account: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error adding MFA to account: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -188,27 +232,39 @@ pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>>
pub struct MfaVerifyBody { pub struct MfaVerifyBody {
#[serde(rename = "totpToken")] #[serde(rename = "totpToken")]
pub totp_token: String, pub totp_token: String,
pub code: String pub code: String,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct MFASuccess { pub struct MFASuccess {
pub data: MFASuccessBody pub data: MFASuccessBody,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct MFASuccessBody { pub struct MFASuccessBody {
#[serde(rename = "authToken")] #[serde(rename = "authToken")]
pub auth_token: String pub auth_token: String,
} }
pub async fn finish_mfa_authenticator(token: String, code: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn finish_mfa_authenticator(
token: String,
code: String,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// load session token // load session token
let token_store = dirs::config_dir().unwrap().join("tfcli-session.token"); let token_store = dirs::config_dir().unwrap().join("tfcli-session.token");
let session_token = fs::read_to_string(&token_store)?; let session_token = fs::read_to_string(&token_store)?;
let res = client.post(server.join("/v1/verify-totp-authenticators")?).json(&MfaVerifyBody {totp_token: token, code }).bearer_auth(session_token).send().await?; let res = client
.post(server.join("/v1/verify-totp-authenticators")?)
.json(&MfaVerifyBody {
totp_token: token,
code,
})
.bearer_auth(session_token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: MFASuccess = res.json().await?; let resp: MFASuccess = res.json().await?;
@ -222,7 +278,10 @@ pub async fn finish_mfa_authenticator(token: String, code: String, server: Url)
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error verifying MFA code: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error verifying MFA code: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -232,7 +291,7 @@ pub async fn finish_mfa_authenticator(token: String, code: String, server: Url)
#[derive(Serialize)] #[derive(Serialize)]
pub struct MfaAuthBody { pub struct MfaAuthBody {
pub code: String pub code: String,
} }
pub async fn mfa_auth(code: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn mfa_auth(code: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -242,7 +301,12 @@ pub async fn mfa_auth(code: String, server: Url) -> Result<(), Box<dyn Error>> {
let token_store = dirs::config_dir().unwrap().join("tfcli-session.token"); let token_store = dirs::config_dir().unwrap().join("tfcli-session.token");
let session_token = fs::read_to_string(&token_store)?; let session_token = fs::read_to_string(&token_store)?;
let res = client.post(server.join("/v1/auth/totp")?).json(&MfaAuthBody { code }).bearer_auth(session_token).send().await?; let res = client
.post(server.join("/v1/auth/totp")?)
.json(&MfaAuthBody { code })
.bearer_auth(session_token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: MFASuccess = res.json().await?; let resp: MFASuccess = res.json().await?;
@ -256,7 +320,10 @@ pub async fn mfa_auth(code: String, server: Url) -> Result<(), Box<dyn Error>> {
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error verifying MFA code: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error verifying MFA code: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }

View File

@ -2,11 +2,11 @@ use serde::Deserialize;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct APIErrorResponse { pub struct APIErrorResponse {
pub errors: Vec<APIError> pub errors: Vec<APIError>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct APIError { pub struct APIError {
pub code: String, pub code: String,
pub message: String, pub message: String,
pub path: Option<String> pub path: Option<String>,
} }

View File

@ -1,34 +1,68 @@
use crate::api::APIErrorResponse;
use crate::{HostCommands, HostOverrideCommands};
use serde::{Deserialize, Serialize};
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddrV4};
use serde::{Deserialize, Serialize}; use url::Url;
use url::{Url};
use crate::api::APIErrorResponse;
use crate::{HostCommands, HostOverrideCommands};
pub async fn host_main(command: HostCommands, server: Url) -> Result<(), Box<dyn Error>> { pub async fn host_main(command: HostCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command { match command {
HostCommands::List {} => list_hosts(server).await, HostCommands::List {} => list_hosts(server).await,
HostCommands::Create { name, network_id, role_id, ip_address, listen_port, lighthouse, relay, static_address } => create_host(name, network_id, role_id, ip_address, listen_port, lighthouse, relay, static_address, server).await, HostCommands::Create {
name,
network_id,
role_id,
ip_address,
listen_port,
lighthouse,
relay,
static_address,
} => {
create_host(
name,
network_id,
role_id,
ip_address,
listen_port,
lighthouse,
relay,
static_address,
server,
)
.await
}
HostCommands::Lookup { id } => get_host(id, server).await, HostCommands::Lookup { id } => get_host(id, server).await,
HostCommands::Delete { id } => delete_host(id, server).await, HostCommands::Delete { id } => delete_host(id, server).await,
HostCommands::Update { id, listen_port, static_address, name, ip, role } => update_host(id, listen_port, static_address, name, ip, role, server).await, HostCommands::Update {
id,
listen_port,
static_address,
name,
ip,
role,
} => update_host(id, listen_port, static_address, name, ip, role, server).await,
HostCommands::Block { id } => block_host(id, server).await, HostCommands::Block { id } => block_host(id, server).await,
HostCommands::Enroll { id } => enroll_host(id, server).await, HostCommands::Enroll { id } => enroll_host(id, server).await,
HostCommands::Overrides { command } => match command { HostCommands::Overrides { command } => match command {
HostOverrideCommands::List { id } => list_overrides(id, server).await, HostOverrideCommands::List { id } => list_overrides(id, server).await,
HostOverrideCommands::Set { id, key, boolean, string, numeric } => set_override(id, key, boolean, numeric, string, server).await, HostOverrideCommands::Set {
HostOverrideCommands::Unset { id, key } => unset_override(id, key, server).await id,
} key,
boolean,
string,
numeric,
} => set_override(id, key, boolean, numeric, string, server).await,
HostOverrideCommands::Unset { id, key } => unset_override(id, key, server).await,
},
} }
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct HostListResp { pub struct HostListResp {
pub data: Vec<Host> pub data: Vec<Host>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct HostMetadata { pub struct HostMetadata {
#[serde(rename = "lastSeenAt")] #[serde(rename = "lastSeenAt")]
@ -77,7 +111,11 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join("/v1/hosts?pageSize=5000")?).bearer_auth(token).send().await?; let res = client
.get(server.join("/v1/hosts?pageSize=5000")?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: HostListResp = res.json().await?; let resp: HostListResp = res.json().await?;
@ -88,14 +126,33 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
println!(" Network: {}", host.network_id); println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id); println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address); println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", ")); println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port); println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } ); println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked); println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at); println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version); println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform); println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available); println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at); println!(" Created: {}", host.created_at);
println!(); println!();
} }
@ -106,7 +163,10 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing hosts: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error listing hosts: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -114,7 +174,6 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
Ok(()) Ok(())
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct HostCreateBody { pub struct HostCreateBody {
pub name: String, pub name: String,
@ -134,11 +193,9 @@ pub struct HostCreateBody {
pub static_addresses: Vec<SocketAddrV4>, pub static_addresses: Vec<SocketAddrV4>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct HostGetMetadata {} pub struct HostGetMetadata {}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct HostGetResponse { pub struct HostGetResponse {
pub data: Host, pub data: Host,
@ -146,7 +203,17 @@ pub struct HostGetResponse {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn create_host(name: String, network_id: String, role_id: String, ip_address: Ipv4Addr, listen_port: Option<u16>, lighthouse: bool, relay: bool, static_address: Option<SocketAddrV4>, server: Url) -> Result<(), Box<dyn Error>> { pub async fn create_host(
name: String,
network_id: String,
role_id: String,
ip_address: Ipv4Addr,
listen_port: Option<u16>,
lighthouse: bool,
relay: bool,
static_address: Option<SocketAddrV4>,
server: Url,
) -> Result<(), Box<dyn Error>> {
if lighthouse && relay { if lighthouse && relay {
eprintln!("[error] Error creating host: a host cannot be both a lighthouse and a relay at the same time"); eprintln!("[error] Error creating host: a host cannot be both a lighthouse and a relay at the same time");
std::process::exit(1); std::process::exit(1);
@ -172,7 +239,9 @@ pub async fn create_host(name: String, network_id: String, role_id: String, ip_a
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join("/v1/hosts")?).json(&HostCreateBody { let res = client
.post(server.join("/v1/hosts")?)
.json(&HostCreateBody {
name, name,
network_id, network_id,
role_id, role_id,
@ -181,7 +250,10 @@ pub async fn create_host(name: String, network_id: String, role_id: String, ip_a
is_lighthouse: lighthouse, is_lighthouse: lighthouse,
is_relay: relay, is_relay: relay,
static_addresses: static_address.map_or(vec![], |u| vec![u]), static_addresses: static_address.map_or(vec![], |u| vec![u]),
}).bearer_auth(token).send().await?; })
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let host: Host = res.json::<HostGetResponse>().await?.data; let host: Host = res.json::<HostGetResponse>().await?.data;
@ -191,21 +263,42 @@ pub async fn create_host(name: String, network_id: String, role_id: String, ip_a
println!(" Network: {}", host.network_id); println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id); println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address); println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", ")); println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port); println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } ); println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked); println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at); println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version); println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform); println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available); println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at); println!(" Created: {}", host.created_at);
println!(); println!();
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating host: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error creating host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -224,7 +317,11 @@ pub async fn get_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}", id))?).bearer_auth(token).send().await?; let res = client
.get(server.join(&format!("/v1/hosts/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let host: Host = res.json::<HostGetResponse>().await?.data; let host: Host = res.json::<HostGetResponse>().await?.data;
@ -234,21 +331,42 @@ pub async fn get_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
println!(" Network: {}", host.network_id); println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id); println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address); println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", ")); println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port); println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } ); println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked); println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at); println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version); println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform); println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available); println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at); println!(" Created: {}", host.created_at);
println!(); println!();
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing hosts: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error listing hosts: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -267,14 +385,21 @@ pub async fn delete_host(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.delete(server.join(&format!("/v1/hosts/{}", id))?).bearer_auth(token).send().await?; let res = client
.delete(server.join(&format!("/v1/hosts/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
println!("Host removed"); println!("Host removed");
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error removing host: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error removing host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -290,10 +415,18 @@ pub struct HostUpdateBody {
pub static_addresses: Vec<SocketAddrV4>, pub static_addresses: Vec<SocketAddrV4>,
pub name: Option<String>, pub name: Option<String>,
pub ip: Option<Ipv4Addr>, pub ip: Option<Ipv4Addr>,
pub role: Option<String> pub role: Option<String>,
} }
pub async fn update_host(id: String, listen_port: Option<u16>, static_address: Option<SocketAddrV4>, name: Option<String>, ip: Option<Ipv4Addr>, role: Option<String>, server: Url) -> Result<(), Box<dyn Error>> { pub async fn update_host(
id: String,
listen_port: Option<u16>,
static_address: Option<SocketAddrV4>,
name: Option<String>,
ip: Option<Ipv4Addr>,
role: Option<String>,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// load session token // load session token
@ -304,13 +437,18 @@ pub async fn update_host(id: String, listen_port: Option<u16>, static_address: O
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.put(server.join(&format!("/v1/hosts/{}?extension=extended_hosts", id))?).json(&HostUpdateBody { let res = client
.put(server.join(&format!("/v1/hosts/{}?extension=extended_hosts", id))?)
.json(&HostUpdateBody {
listen_port: listen_port.unwrap_or(0), listen_port: listen_port.unwrap_or(0),
static_addresses: static_address.map_or_else(Vec::new, |u| vec![u]), static_addresses: static_address.map_or_else(Vec::new, |u| vec![u]),
name, name,
ip, ip,
role role,
}).bearer_auth(token).send().await?; })
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let host: Host = res.json::<HostGetResponse>().await?.data; let host: Host = res.json::<HostGetResponse>().await?.data;
@ -320,21 +458,42 @@ pub async fn update_host(id: String, listen_port: Option<u16>, static_address: O
println!(" Network: {}", host.network_id); println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id); println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address); println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", ")); println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port); println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } ); println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked); println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at); println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version); println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform); println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available); println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at); println!(" Created: {}", host.created_at);
println!(); println!();
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error updating host: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error updating host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -342,7 +501,6 @@ pub async fn update_host(id: String, listen_port: Option<u16>, static_address: O
Ok(()) Ok(())
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct EnrollmentCodeResponseMetadata {} pub struct EnrollmentCodeResponseMetadata {}
@ -376,18 +534,29 @@ pub async fn enroll_host(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join(&format!("/v1/hosts/{}/enrollment-code", id))?).header("content-length", 0).bearer_auth(token).send().await?; let res = client
.post(server.join(&format!("/v1/hosts/{}/enrollment-code", id))?)
.header("content-length", 0)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: EnrollmentResponse = res.json().await?; let resp: EnrollmentResponse = res.json().await?;
println!("Enrollment code generated. Enroll the host with the following code: {}", resp.data.enrollment_code.code); println!(
"Enrollment code generated. Enroll the host with the following code: {}",
resp.data.enrollment_code.code
);
println!("This code will be valid for {} seconds, at which point you will need to generate a new code", resp.data.enrollment_code.lifetime_seconds); println!("This code will be valid for {} seconds, at which point you will need to generate a new code", resp.data.enrollment_code.lifetime_seconds);
println!("If this host is blocked, a successful re-enrollment will unblock it."); println!("If this host is blocked, a successful re-enrollment will unblock it.");
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error blocking host: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error blocking host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -406,14 +575,22 @@ pub async fn block_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join(&format!("/v1/hosts/{}/block", id))?).header("Content-Length", "0").bearer_auth(token).send().await?; let res = client
.post(server.join(&format!("/v1/hosts/{}/block", id))?)
.header("Content-Length", "0")
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
println!("Host blocked. To unblock it, re-enroll the host."); println!("Host blocked. To unblock it, re-enroll the host.");
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error blocking host: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error blocking host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -423,18 +600,18 @@ pub async fn block_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct HostConfigOverrideResponse { pub struct HostConfigOverrideResponse {
pub data: HostConfigOverrideData pub data: HostConfigOverrideData,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct HostConfigOverrideData { pub struct HostConfigOverrideData {
pub overrides: Vec<HostConfigOverrideDataOverride> pub overrides: Vec<HostConfigOverrideDataOverride>,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct HostConfigOverrideDataOverride { pub struct HostConfigOverrideDataOverride {
pub key: String, pub key: String,
pub value: HostConfigOverrideDataOverrideValue pub value: HostConfigOverrideDataOverrideValue,
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -442,7 +619,7 @@ pub struct HostConfigOverrideDataOverride {
pub enum HostConfigOverrideDataOverrideValue { pub enum HostConfigOverrideDataOverrideValue {
Boolean(bool), Boolean(bool),
Numeric(i64), Numeric(i64),
Other(String) Other(String),
} }
pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -456,18 +633,25 @@ pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token).send().await?; let res = client
.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?; let resp: HostConfigOverrideResponse = res.json().await?;
for c_override in &resp.data.overrides { for c_override in &resp.data.overrides {
println!(" Key: {}", c_override.key); println!(" Key: {}", c_override.key);
println!("Value: {}", match &c_override.value { println!(
"Value: {}",
match &c_override.value {
HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v), HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v),
HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v), HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v),
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v) HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v),
}); }
);
} }
if resp.data.overrides.is_empty() { if resp.data.overrides.is_empty() {
@ -476,7 +660,10 @@ pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error looking up config overrides: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error looking up config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -486,14 +673,24 @@ pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct SetOverrideRequest { pub struct SetOverrideRequest {
pub overrides: Vec<HostConfigOverrideDataOverride> pub overrides: Vec<HostConfigOverrideDataOverride>,
} }
pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeric: Option<i64>, other: Option<String>, server: Url) -> Result<(), Box<dyn Error>> { pub async fn set_override(
id: String,
key: String,
boolean: Option<bool>,
numeric: Option<i64>,
other: Option<String>,
server: Url,
) -> Result<(), Box<dyn Error>> {
if boolean.is_none() && numeric.is_none() && other.is_none() { if boolean.is_none() && numeric.is_none() && other.is_none() {
eprintln!("[error] no value provided: you must provide at least --boolean, --numeric, or --string"); eprintln!("[error] no value provided: you must provide at least --boolean, --numeric, or --string");
std::process::exit(1); std::process::exit(1);
} else if boolean.is_some() && numeric.is_some() || boolean.is_some() && other.is_some() || numeric.is_some() && other.is_some() { } else if boolean.is_some() && numeric.is_some()
|| boolean.is_some() && other.is_some()
|| numeric.is_some() && other.is_some()
{
eprintln!("[error] multiple values provided: you must provide only one of --boolean, --numeric, or --string"); eprintln!("[error] multiple values provided: you must provide only one of --boolean, --numeric, or --string");
std::process::exit(1); std::process::exit(1);
} }
@ -520,7 +717,11 @@ pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeri
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).send().await?; let res = client
.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token.clone())
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?; let resp: HostConfigOverrideResponse = res.json().await?;
@ -533,25 +734,28 @@ pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeri
} }
} }
others.push(HostConfigOverrideDataOverride { others.push(HostConfigOverrideDataOverride { key, value: val });
key,
value: val,
});
let res = client.put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).json(&SetOverrideRequest { let res = client
overrides: others, .put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
}).send().await?; .bearer_auth(token.clone())
.json(&SetOverrideRequest { overrides: others })
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?; let resp: HostConfigOverrideResponse = res.json().await?;
for c_override in &resp.data.overrides { for c_override in &resp.data.overrides {
println!(" Key: {}", c_override.key); println!(" Key: {}", c_override.key);
println!("Value: {}", match &c_override.value { println!(
"Value: {}",
match &c_override.value {
HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v), HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v),
HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v), HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v),
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v) HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v),
}); }
);
} }
if resp.data.overrides.is_empty() { if resp.data.overrides.is_empty() {
@ -562,14 +766,20 @@ pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeri
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error setting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error setting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error setting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error setting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -588,7 +798,11 @@ pub async fn unset_override(id: String, key: String, server: Url) -> Result<(),
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).send().await?; let res = client
.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token.clone())
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?; let resp: HostConfigOverrideResponse = res.json().await?;
@ -601,20 +815,26 @@ pub async fn unset_override(id: String, key: String, server: Url) -> Result<(),
} }
} }
let res = client.put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).json(&SetOverrideRequest { let res = client
overrides: others, .put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
}).send().await?; .bearer_auth(token.clone())
.json(&SetOverrideRequest { overrides: others })
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?; let resp: HostConfigOverrideResponse = res.json().await?;
for c_override in &resp.data.overrides { for c_override in &resp.data.overrides {
println!(" Key: {}", c_override.key); println!(" Key: {}", c_override.key);
println!("Value: {}", match &c_override.value { println!(
"Value: {}",
match &c_override.value {
HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v), HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v),
HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v), HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v),
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v) HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v),
}); }
);
} }
if resp.data.overrides.is_empty() { if resp.data.overrides.is_empty() {
@ -625,14 +845,20 @@ pub async fn unset_override(id: String, key: String, server: Url) -> Result<(),
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error unsetting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error unsetting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error unsetting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error unsetting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }

View File

@ -1,21 +1,21 @@
use std::error::Error;
use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4};
use clap::{Parser, Subcommand};
use ipnet::Ipv4Net;
use url::Url;
use crate::account::account_main; use crate::account::account_main;
use crate::host::host_main; use crate::host::host_main;
use crate::network::network_main; use crate::network::network_main;
use crate::org::org_main; use crate::org::org_main;
use crate::role::role_main; use crate::role::role_main;
use clap::{Parser, Subcommand};
use ipnet::Ipv4Net;
use std::error::Error;
use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4};
use url::Url;
pub mod account; pub mod account;
pub mod api; pub mod api;
pub mod host;
pub mod network; pub mod network;
pub mod org; pub mod org;
pub mod role; pub mod role;
pub mod host;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
@ -25,7 +25,7 @@ pub struct Args {
command: Commands, command: Commands,
#[clap(short, long, env = "TFCLI_SERVER")] #[clap(short, long, env = "TFCLI_SERVER")]
/// The base URL of your trifid-api-old instance. Defaults to the value in $XDG_CONFIG_HOME/tfcli-server-url.conf or the TFCLI_SERVER environment variable. /// The base URL of your trifid-api-old instance. Defaults to the value in $XDG_CONFIG_HOME/tfcli-server-url.conf or the TFCLI_SERVER environment variable.
server: Option<Url> server: Option<Url>,
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -33,28 +33,28 @@ pub enum Commands {
/// Manage your trifid account /// Manage your trifid account
Account { Account {
#[command(subcommand)] #[command(subcommand)]
command: AccountCommands command: AccountCommands,
}, },
/// Manage the networks associated with your trifid account /// Manage the networks associated with your trifid account
Network { Network {
#[command(subcommand)] #[command(subcommand)]
command: NetworkCommands command: NetworkCommands,
}, },
/// Manage the organization associated with your trifid account /// Manage the organization associated with your trifid account
Org { Org {
#[command(subcommand)] #[command(subcommand)]
command: OrgCommands command: OrgCommands,
}, },
/// Manage the roles associated with your trifid organization /// Manage the roles associated with your trifid organization
Role { Role {
#[command(subcommand)] #[command(subcommand)]
command: RoleCommands command: RoleCommands,
}, },
/// Manage the hosts associated with your trifid network /// Manage the hosts associated with your trifid network
Host { Host {
#[command(subcommand)] #[command(subcommand)]
command: HostCommands command: HostCommands,
} },
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -62,17 +62,17 @@ pub enum AccountCommands {
/// Create a new trifid account on the designated server /// Create a new trifid account on the designated server
Create { Create {
#[clap(short, long)] #[clap(short, long)]
email: String email: String,
}, },
/// Log into an existing account on the designated server /// Log into an existing account on the designated server
Login { Login {
#[clap(short, long)] #[clap(short, long)]
email: String email: String,
}, },
/// Log in to your account with a magic-link token acquired via email or the trifid-api-old logs. /// Log in to your account with a magic-link token acquired via email or the trifid-api-old logs.
MagicLink { MagicLink {
#[clap(short, long)] #[clap(short, long)]
magic_link_token: String magic_link_token: String,
}, },
/// Create a new TOTP authenticator on this account to enable authorizing with 2fa and performing all management tasks. /// Create a new TOTP authenticator on this account to enable authorizing with 2fa and performing all management tasks.
MfaSetup {}, MfaSetup {},
@ -81,13 +81,13 @@ pub enum AccountCommands {
#[clap(short, long)] #[clap(short, long)]
code: String, code: String,
#[clap(short, long)] #[clap(short, long)]
token: String token: String,
}, },
/// Create a new short-lived authentication token by inputting the code shown on your authenticator app. /// Create a new short-lived authentication token by inputting the code shown on your authenticator app.
Mfa { Mfa {
#[clap(short, long)] #[clap(short, long)]
code: String code: String,
} },
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -97,8 +97,8 @@ pub enum NetworkCommands {
/// Lookup a specific network by ID. /// Lookup a specific network by ID.
Lookup { Lookup {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
} },
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -106,8 +106,8 @@ pub enum OrgCommands {
/// Create an organization on your trifid-api-old server. NOTE: This command ONLY works on trifid-api-old servers. It will NOT work on original DN servers. /// Create an organization on your trifid-api-old server. NOTE: This command ONLY works on trifid-api-old servers. It will NOT work on original DN servers.
Create { Create {
#[clap(short, long)] #[clap(short, long)]
cidr: Ipv4Net cidr: Ipv4Net,
} },
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -120,19 +120,19 @@ pub enum RoleCommands {
description: String, description: String,
/// A JSON string containing the firewall rules to add to this host /// A JSON string containing the firewall rules to add to this host
#[clap(short, long)] #[clap(short, long)]
rules_json: String rules_json: String,
}, },
/// List all roles attached to your organization /// List all roles attached to your organization
List {}, List {},
/// Lookup a specific role by it's ID /// Lookup a specific role by it's ID
Lookup { Lookup {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
}, },
/// Delete a specific role by it's ID /// Delete a specific role by it's ID
Delete { Delete {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
}, },
/// Update a specific role by it's ID. Warning: any data not provided in this update will be removed - include all data you wish to remain /// Update a specific role by it's ID. Warning: any data not provided in this update will be removed - include all data you wish to remain
Update { Update {
@ -142,8 +142,8 @@ pub enum RoleCommands {
description: String, description: String,
/// A JSON string containing the firewall rules to add to this host /// A JSON string containing the firewall rules to add to this host
#[clap(short, long)] #[clap(short, long)]
rules_json: String rules_json: String,
} },
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -165,19 +165,19 @@ pub enum HostCommands {
#[clap(short = 'R', long)] #[clap(short = 'R', long)]
relay: bool, relay: bool,
#[clap(short, long)] #[clap(short, long)]
static_address: Option<SocketAddrV4> static_address: Option<SocketAddrV4>,
}, },
/// List all hosts on your network /// List all hosts on your network
List {}, List {},
/// Lookup a specific host by it's ID /// Lookup a specific host by it's ID
Lookup { Lookup {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
}, },
/// Delete a specific host by it's ID /// Delete a specific host by it's ID
Delete { Delete {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
}, },
/// Update a specific host by it's ID, changing the listen port and static addresses, as well as the name, ip and role. The name, ip and role updates will only work on trifid-api-old compatible servers. /// Update a specific host by it's ID, changing the listen port and static addresses, as well as the name, ip and role. The name, ip and role updates will only work on trifid-api-old compatible servers.
Update { Update {
@ -192,23 +192,23 @@ pub enum HostCommands {
#[clap(short, long)] #[clap(short, long)]
role: Option<String>, role: Option<String>,
#[clap(short = 'I', long)] #[clap(short = 'I', long)]
ip: Option<Ipv4Addr> ip: Option<Ipv4Addr>,
}, },
/// Blocks the specified host from the network /// Blocks the specified host from the network
Block { Block {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
}, },
/// Enroll or re-enroll the host by generating an enrollment code /// Enroll or re-enroll the host by generating an enrollment code
Enroll { Enroll {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
}, },
/// Manage config overrides set on the host /// Manage config overrides set on the host
Overrides { Overrides {
#[command(subcommand)] #[command(subcommand)]
command: HostOverrideCommands command: HostOverrideCommands,
} },
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -216,7 +216,7 @@ pub enum HostOverrideCommands {
/// List the config overrides set on the host /// List the config overrides set on the host
List { List {
#[clap(short, long)] #[clap(short, long)]
id: String id: String,
}, },
/// Set a config override on the host /// Set a config override on the host
Set { Set {
@ -229,15 +229,15 @@ pub enum HostOverrideCommands {
#[clap(short, long)] #[clap(short, long)]
numeric: Option<i64>, numeric: Option<i64>,
#[clap(short, long)] #[clap(short, long)]
string: Option<String> string: Option<String>,
}, },
/// Unset a config override on the host /// Unset a config override on the host
Unset { Unset {
#[clap(short, long)] #[clap(short, long)]
id: String, id: String,
#[clap(short, long)] #[clap(short, long)]
key: String key: String,
} },
} }
#[tokio::main] #[tokio::main]
@ -270,7 +270,10 @@ async fn main2() -> Result<(), Box<dyn Error>> {
let url = match Url::parse(&url_s) { let url = match Url::parse(&url_s) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
eprintln!("[error] unable to parse the URL in {}", srv_url_file.display()); eprintln!(
"[error] unable to parse the URL in {}",
srv_url_file.display()
);
eprintln!("[error] urlparse returned error '{}'", e); eprintln!("[error] urlparse returned error '{}'", e);
eprintln!("[error] please correct the error and try again"); eprintln!("[error] please correct the error and try again");
std::process::exit(1); std::process::exit(1);
@ -284,6 +287,6 @@ async fn main2() -> Result<(), Box<dyn Error>> {
Commands::Network { command } => network_main(command, server).await, Commands::Network { command } => network_main(command, server).await,
Commands::Org { command } => org_main(command, server).await, Commands::Org { command } => org_main(command, server).await,
Commands::Role { command } => role_main(command, server).await, Commands::Role { command } => role_main(command, server).await,
Commands::Host { command } => host_main(command, server).await Commands::Host { command } => host_main(command, server).await,
} }
} }

View File

@ -1,20 +1,20 @@
use std::error::Error;
use std::fs;
use serde::Deserialize;
use url::Url;
use crate::api::APIErrorResponse; use crate::api::APIErrorResponse;
use crate::NetworkCommands; use crate::NetworkCommands;
use serde::Deserialize;
use std::error::Error;
use std::fs;
use url::Url;
pub async fn network_main(command: NetworkCommands, server: Url) -> Result<(), Box<dyn Error>> { pub async fn network_main(command: NetworkCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command { match command {
NetworkCommands::List {} => list_networks(server).await, NetworkCommands::List {} => list_networks(server).await,
NetworkCommands::Lookup {id} => get_network(id, server).await NetworkCommands::Lookup { id } => get_network(id, server).await,
} }
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct NetworkListResp { pub struct NetworkListResp {
pub data: Vec<Network> pub data: Vec<Network>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -29,7 +29,7 @@ pub struct Network {
pub created_at: String, pub created_at: String,
#[serde(rename = "lighthousesAsRelays")] #[serde(rename = "lighthousesAsRelays")]
pub lighthouses_as_relays: bool, pub lighthouses_as_relays: bool,
pub name: String pub name: String,
} }
pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> { pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
@ -43,7 +43,11 @@ pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join("/v1/networks")?).bearer_auth(token).send().await?; let res = client
.get(server.join("/v1/networks")?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: NetworkListResp = res.json().await?; let resp: NetworkListResp = res.json().await?;
@ -65,7 +69,10 @@ pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing networks: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error listing networks: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -75,7 +82,7 @@ pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct NetworkGetResponse { pub struct NetworkGetResponse {
pub data: Network pub data: Network,
} }
pub async fn get_network(id: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn get_network(id: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -89,7 +96,11 @@ pub async fn get_network(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/networks/{}", id))?).bearer_auth(token).send().await?; let res = client
.get(server.join(&format!("/v1/networks/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let network: Network = res.json::<NetworkGetResponse>().await?.data; let network: Network = res.json::<NetworkGetResponse>().await?.data;
@ -101,11 +112,13 @@ pub async fn get_network(id: String, server: Url) -> Result<(), Box<dyn Error>>
println!("Dedicated Relays: {}", !network.lighthouses_as_relays); println!("Dedicated Relays: {}", !network.lighthouses_as_relays);
println!(" Name: {}", network.name); println!(" Name: {}", network.name);
println!(" Created At: {}", network.created_at); println!(" Created At: {}", network.created_at);
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing networks: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error listing networks: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }

View File

@ -1,10 +1,10 @@
use std::error::Error; use crate::api::APIErrorResponse;
use std::fs; use crate::OrgCommands;
use ipnet::Ipv4Net; use ipnet::Ipv4Net;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use url::Url; use url::Url;
use crate::OrgCommands;
use crate::api::APIErrorResponse;
pub async fn org_main(command: OrgCommands, server: Url) -> Result<(), Box<dyn Error>> { pub async fn org_main(command: OrgCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command { match command {
@ -14,14 +14,14 @@ pub async fn org_main(command: OrgCommands, server: Url) -> Result<(), Box<dyn E
#[derive(Serialize)] #[derive(Serialize)]
pub struct CreateOrgBody { pub struct CreateOrgBody {
pub cidr: String pub cidr: String,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct OrgCreateResponse { pub struct OrgCreateResponse {
pub organization: String, pub organization: String,
pub ca: String, pub ca: String,
pub network: String pub network: String,
} }
pub async fn create_org(cidr: Ipv4Net, server: Url) -> Result<(), Box<dyn Error>> { pub async fn create_org(cidr: Ipv4Net, server: Url) -> Result<(), Box<dyn Error>> {
@ -34,7 +34,14 @@ pub async fn create_org(cidr: Ipv4Net, server: Url) -> Result<(), Box<dyn Error>
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join("/v1/organization")?).json(&CreateOrgBody { cidr: cidr.to_string() }).bearer_auth(token).send().await?; let res = client
.post(server.join("/v1/organization")?)
.json(&CreateOrgBody {
cidr: cidr.to_string(),
})
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: OrgCreateResponse = res.json().await?; let resp: OrgCreateResponse = res.json().await?;
@ -45,7 +52,10 @@ pub async fn create_org(cidr: Ipv4Net, server: Url) -> Result<(), Box<dyn Error>
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating org: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error creating org: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }

View File

@ -1,23 +1,31 @@
use crate::api::APIErrorResponse;
use crate::RoleCommands;
use serde::{Deserialize, Serialize};
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
use crate::api::APIErrorResponse;
use crate::{RoleCommands};
pub async fn role_main(command: RoleCommands, server: Url) -> Result<(), Box<dyn Error>> { pub async fn role_main(command: RoleCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command { match command {
RoleCommands::List {} => list_roles(server).await, RoleCommands::List {} => list_roles(server).await,
RoleCommands::Lookup {id} => get_role(id, server).await, RoleCommands::Lookup { id } => get_role(id, server).await,
RoleCommands::Create { name, description, rules_json } => create_role(name, description, rules_json, server).await, RoleCommands::Create {
name,
description,
rules_json,
} => create_role(name, description, rules_json, server).await,
RoleCommands::Delete { id } => delete_role(id, server).await, RoleCommands::Delete { id } => delete_role(id, server).await,
RoleCommands::Update { id, description, rules_json } => update_role(id, description, rules_json, server).await RoleCommands::Update {
id,
description,
rules_json,
} => update_role(id, description, rules_json, server).await,
} }
} }
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct RoleListResp { pub struct RoleListResp {
pub data: Vec<Role> pub data: Vec<Role>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -30,7 +38,7 @@ pub struct Role {
#[serde(rename = "createdAt")] #[serde(rename = "createdAt")]
pub created_at: String, pub created_at: String,
#[serde(rename = "modifiedAt")] #[serde(rename = "modifiedAt")]
pub modified_at: String pub modified_at: String,
} }
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct RoleFirewallRule { pub struct RoleFirewallRule {
@ -39,12 +47,12 @@ pub struct RoleFirewallRule {
#[serde(rename = "allowedRoleID")] #[serde(rename = "allowedRoleID")]
pub allowed_role_id: Option<String>, pub allowed_role_id: Option<String>,
#[serde(rename = "portRange")] #[serde(rename = "portRange")]
pub port_range: Option<RoleFirewallRulePortRange> pub port_range: Option<RoleFirewallRulePortRange>,
} }
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct RoleFirewallRulePortRange { pub struct RoleFirewallRulePortRange {
pub from: u16, pub from: u16,
pub to: u16 pub to: u16,
} }
pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> { pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
@ -58,7 +66,11 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join("/v1/roles")?).bearer_auth(token).send().await?; let res = client
.get(server.join("/v1/roles")?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let resp: RoleListResp = res.json().await?; let resp: RoleListResp = res.json().await?;
@ -68,9 +80,21 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
println!(" Description: {}", role.description); println!(" Description: {}", role.description);
for rule in &role.firewall_rules { for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description); println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string())); println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol); println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() }); println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
} }
println!(" Created: {}", role.created_at); println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at); println!(" Updated: {}", role.modified_at);
@ -82,7 +106,10 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing roles: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error listing roles: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -92,7 +119,7 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct RoleGetResponse { pub struct RoleGetResponse {
pub data: Role pub data: Role,
} }
pub async fn get_role(id: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn get_role(id: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -106,7 +133,11 @@ pub async fn get_role(id: String, server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/roles/{}", id))?).bearer_auth(token).send().await?; let res = client
.get(server.join(&format!("/v1/roles/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let role: Role = res.json::<RoleGetResponse>().await?.data; let role: Role = res.json::<RoleGetResponse>().await?.data;
@ -115,17 +146,31 @@ pub async fn get_role(id: String, server: Url) -> Result<(), Box<dyn Error>> {
println!(" Description: {}", role.description); println!(" Description: {}", role.description);
for rule in &role.firewall_rules { for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description); println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string())); println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol); println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() }); println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
} }
println!(" Created: {}", role.created_at); println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at); println!(" Updated: {}", role.modified_at);
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing roles: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error listing roles: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -138,10 +183,15 @@ pub struct RoleCreateBody {
pub name: String, pub name: String,
pub description: String, pub description: String,
#[serde(rename = "firewallRules")] #[serde(rename = "firewallRules")]
pub firewall_rules: Vec<RoleFirewallRule> pub firewall_rules: Vec<RoleFirewallRule>,
} }
pub async fn create_role(name: String, description: String, rules_json: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn create_role(
name: String,
description: String,
rules_json: String,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let rules: Vec<RoleFirewallRule> = match serde_json::from_str(&rules_json) { let rules: Vec<RoleFirewallRule> = match serde_json::from_str(&rules_json) {
@ -160,11 +210,16 @@ pub async fn create_role(name: String, description: String, rules_json: String,
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join("/v1/roles")?).json(&RoleCreateBody { let res = client
.post(server.join("/v1/roles")?)
.json(&RoleCreateBody {
name, name,
description, description,
firewall_rules: rules, firewall_rules: rules,
}).bearer_auth(token).send().await?; })
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let role: Role = res.json::<RoleGetResponse>().await?.data; let role: Role = res.json::<RoleGetResponse>().await?.data;
@ -173,17 +228,31 @@ pub async fn create_role(name: String, description: String, rules_json: String,
println!(" Description: {}", role.description); println!(" Description: {}", role.description);
for rule in &role.firewall_rules { for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description); println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string())); println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol); println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() }); println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
} }
println!(" Created: {}", role.created_at); println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at); println!(" Updated: {}", role.modified_at);
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating role: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error creating role: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -195,10 +264,15 @@ pub async fn create_role(name: String, description: String, rules_json: String,
pub struct RoleUpdateBody { pub struct RoleUpdateBody {
pub description: String, pub description: String,
#[serde(rename = "firewallRules")] #[serde(rename = "firewallRules")]
pub firewall_rules: Vec<RoleFirewallRule> pub firewall_rules: Vec<RoleFirewallRule>,
} }
pub async fn update_role(id: String, description: String, rules_json: String, server: Url) -> Result<(), Box<dyn Error>> { pub async fn update_role(
id: String,
description: String,
rules_json: String,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let rules: Vec<RoleFirewallRule> = match serde_json::from_str(&rules_json) { let rules: Vec<RoleFirewallRule> = match serde_json::from_str(&rules_json) {
@ -217,10 +291,15 @@ pub async fn update_role(id: String, description: String, rules_json: String, se
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.put(server.join(&format!("/v1/roles/{}", id))?).json(&RoleUpdateBody { let res = client
.put(server.join(&format!("/v1/roles/{}", id))?)
.json(&RoleUpdateBody {
description, description,
firewall_rules: rules, firewall_rules: rules,
}).bearer_auth(token).send().await?; })
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
let role: Role = res.json::<RoleGetResponse>().await?.data; let role: Role = res.json::<RoleGetResponse>().await?.data;
@ -229,17 +308,31 @@ pub async fn update_role(id: String, description: String, rules_json: String, se
println!(" Description: {}", role.description); println!(" Description: {}", role.description);
for rule in &role.firewall_rules { for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description); println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string())); println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol); println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() }); println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
} }
println!(" Created: {}", role.created_at); println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at); println!(" Updated: {}", role.modified_at);
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error updating role: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error updating role: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }
@ -258,14 +351,21 @@ pub async fn delete_role(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token); let token = format!("{} {}", session_token, auth_token);
let res = client.delete(server.join(&format!("/v1/roles/{}", id))?).bearer_auth(token).send().await?; let res = client
.delete(server.join(&format!("/v1/roles/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() { if res.status().is_success() {
println!("Role removed"); println!("Role removed");
} else { } else {
let resp: APIErrorResponse = res.json().await?; let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error removing role: {} {}", resp.errors[0].code, resp.errors[0].message); eprintln!(
"[error] Error removing role: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1); std::process::exit(1);
} }

View File

@ -112,10 +112,7 @@ pub fn apiworker_main(
cdata.creds = Some(creds); cdata.creds = Some(creds);
cdata.dh_privkey = Some(dh_privkey); cdata.dh_privkey = Some(dh_privkey);
match fs::write( match fs::write(nebula_yml(&instance), config) {
nebula_yml(&instance),
config,
) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("unable to save nebula config: {}", e); error!("unable to save nebula config: {}", e);
@ -175,10 +172,7 @@ pub fn apiworker_main(
} }
}; };
match fs::write( match fs::write(nebula_yml(&instance), config) {
nebula_yml(&instance),
config,
) {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("unable to save nebula config: {}", e); error!("unable to save nebula config: {}", e);

View File

@ -4,11 +4,11 @@ use std::error::Error;
use std::fs; use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddrV4};
use crate::dirs::{config_dir, tfclient_toml, tfdata_toml};
use dnapi_rs::client_blocking::EnrollMeta; use dnapi_rs::client_blocking::EnrollMeta;
use dnapi_rs::credentials::Credentials; use dnapi_rs::credentials::Credentials;
use log::{debug, info}; use log::{debug, info};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::dirs::{config_dir, tfclient_toml, tfdata_toml};
pub const DEFAULT_PORT: u16 = 8157; pub const DEFAULT_PORT: u16 = 8157;
fn default_port() -> u16 { fn default_port() -> u16 {
@ -39,10 +39,7 @@ pub fn create_config(instance: &str) -> Result<(), Box<dyn Error>> {
disable_automatic_config_updates: false, disable_automatic_config_updates: false,
}; };
let config_str = toml::to_string(&config)?; let config_str = toml::to_string(&config)?;
fs::write( fs::write(tfclient_toml(instance), config_str)?;
tfclient_toml(instance),
config_str,
)?;
Ok(()) Ok(())
} }
@ -72,10 +69,7 @@ pub fn create_cdata(instance: &str) -> Result<(), Box<dyn Error>> {
meta: None, meta: None,
}; };
let config_str = toml::to_string(&config)?; let config_str = toml::to_string(&config)?;
fs::write( fs::write(tfdata_toml(instance), config_str)?;
tfdata_toml(instance),
config_str,
)?;
Ok(()) Ok(())
} }

View File

@ -6,7 +6,7 @@ use std::thread;
use crate::apiworker::{apiworker_main, APIWorkerMessage}; use crate::apiworker::{apiworker_main, APIWorkerMessage};
use crate::config::load_config; use crate::config::load_config;
use crate::nebulaworker::{NebulaWorkerMessage, nebulaworker_main}; use crate::nebulaworker::{nebulaworker_main, NebulaWorkerMessage};
use crate::socketworker::{socketworker_main, SocketWorkerMessage}; use crate::socketworker::{socketworker_main, SocketWorkerMessage};
use crate::timerworker::{timer_main, TimerWorkerMessage}; use crate::timerworker::{timer_main, TimerWorkerMessage};
use crate::util::{check_server_url, shutdown}; use crate::util::{check_server_url, shutdown};
@ -83,7 +83,11 @@ pub fn daemon_main(name: String, server: String) {
let timer_transmitter = transmitter; let timer_transmitter = transmitter;
let timer_thread = thread::spawn(move || { let timer_thread = thread::spawn(move || {
timer_main(timer_transmitter, rx_timer, config.disable_automatic_config_updates); timer_main(
timer_transmitter,
rx_timer,
config.disable_automatic_config_updates,
);
}); });
info!("Waiting for timer thread to exit..."); info!("Waiting for timer thread to exit...");
match timer_thread.join() { match timer_thread.join() {
@ -95,7 +99,6 @@ pub fn daemon_main(name: String, server: String) {
} }
info!("Timer thread exited"); info!("Timer thread exited");
info!("Waiting for socket thread to exit..."); info!("Waiting for socket thread to exit...");
match socket_thread.join() { match socket_thread.join() {
Ok(_) => (), Ok(_) => (),

View File

@ -31,13 +31,19 @@ pub fn config_dir(instance: &str) -> PathBuf {
} }
pub fn tfclient_toml(instance: &str) -> PathBuf { pub fn tfclient_toml(instance: &str) -> PathBuf {
config_base().join(format!("{}/", instance)).join("tfclient.toml") config_base()
.join(format!("{}/", instance))
.join("tfclient.toml")
} }
pub fn tfdata_toml(instance: &str) -> PathBuf { pub fn tfdata_toml(instance: &str) -> PathBuf {
config_base().join(format!("{}/", instance)).join("tfdata.toml") config_base()
.join(format!("{}/", instance))
.join("tfdata.toml")
} }
pub fn nebula_yml(instance: &str) -> PathBuf { pub fn nebula_yml(instance: &str) -> PathBuf {
config_base().join(format!("{}/", instance)).join("nebula.yml") config_base()
.join(format!("{}/", instance))
.join("nebula.yml")
} }

View File

@ -44,7 +44,6 @@ struct Cli {
#[derive(Subcommand)] #[derive(Subcommand)]
enum Commands { enum Commands {
/// Run the tfclient daemon in the foreground /// Run the tfclient daemon in the foreground
Run { Run {
#[clap(short, long, default_value = "tfclient")] #[clap(short, long, default_value = "tfclient")]

View File

@ -2,13 +2,13 @@
use crate::config::{load_cdata, NebulaConfig, TFClientConfig}; use crate::config::{load_cdata, NebulaConfig, TFClientConfig};
use crate::daemon::ThreadMessageSender; use crate::daemon::ThreadMessageSender;
use crate::dirs::{nebula_yml}; use crate::dirs::nebula_yml;
use crate::util::shutdown;
use log::{debug, error, info}; use log::{debug, error, info};
use nebula_ffi::NebulaInstance;
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
use std::sync::mpsc::Receiver; use std::sync::mpsc::Receiver;
use nebula_ffi::NebulaInstance;
use crate::util::shutdown;
pub enum NebulaWorkerMessage { pub enum NebulaWorkerMessage {
Shutdown, Shutdown,
@ -23,9 +23,7 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
let cdata = load_cdata(instance)?; let cdata = load_cdata(instance)?;
let key = cdata.dh_privkey.ok_or("Missing private key")?; let key = cdata.dh_privkey.ok_or("Missing private key")?;
let config_str = fs::read_to_string( let config_str = fs::read_to_string(nebula_yml(instance))?;
nebula_yml(instance),
)?;
let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?; let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?;
config.pki.key = Some(String::from_utf8(key)?); config.pki.key = Some(String::from_utf8(key)?);
@ -33,10 +31,7 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
debug!("inserted private key into config: {:?}", config); debug!("inserted private key into config: {:?}", config);
let config_str = serde_yaml::to_string(&config)?; let config_str = serde_yaml::to_string(&config)?;
fs::write( fs::write(nebula_yml(instance), config_str)?;
nebula_yml(instance),
config_str,
)?;
Ok(()) Ok(())
} }
@ -74,7 +69,8 @@ pub fn nebulaworker_main(
error!("not enrolled, cannot start nebula"); error!("not enrolled, cannot start nebula");
} else { } else {
info!("setting up nebula..."); info!("setting up nebula...");
nebula = Some(match NebulaInstance::new(nebula_yml(&instance).as_path(), false) { nebula = Some(
match NebulaInstance::new(nebula_yml(&instance).as_path(), false) {
Ok(i) => { Ok(i) => {
info!("nebula setup"); info!("nebula setup");
info!("starting nebula..."); info!("starting nebula...");
@ -89,15 +85,15 @@ pub fn nebulaworker_main(
} }
i i
}, }
Err(e) => { Err(e) => {
error!("error setting up Nebula: {}", e); error!("error setting up Nebula: {}", e);
error!("nebula thread exiting with error"); error!("nebula thread exiting with error");
shutdown(&transmitter); shutdown(&transmitter);
return; return;
} }
},
}); );
info!("nebula process started"); info!("nebula process started");
} }
@ -157,7 +153,8 @@ pub fn nebulaworker_main(
} else { } else {
debug!("detected enrollment, starting nebula for the first time"); debug!("detected enrollment, starting nebula for the first time");
info!("setting up nebula..."); info!("setting up nebula...");
nebula = Some(match NebulaInstance::new(nebula_yml(&instance).as_path(), false) { nebula = Some(
match NebulaInstance::new(nebula_yml(&instance).as_path(), false) {
Ok(i) => { Ok(i) => {
info!("nebula setup"); info!("nebula setup");
info!("starting nebula..."); info!("starting nebula...");
@ -172,15 +169,15 @@ pub fn nebulaworker_main(
} }
i i
}, }
Err(e) => { Err(e) => {
error!("error setting up Nebula: {}", e); error!("error setting up Nebula: {}", e);
error!("nebula thread exiting with error"); error!("nebula thread exiting with error");
shutdown(&transmitter); shutdown(&transmitter);
return; return;
} }
},
}); );
info!("nebula process started"); info!("nebula process started");
} }

View File

@ -4,7 +4,7 @@
use crate::config::{load_cdata, NebulaConfig, TFClientConfig}; use crate::config::{load_cdata, NebulaConfig, TFClientConfig};
use crate::daemon::ThreadMessageSender; use crate::daemon::ThreadMessageSender;
use crate::dirs::{nebula_yml}; use crate::dirs::nebula_yml;
use log::{debug, error, info}; use log::{debug, error, info};
use std::error::Error; use std::error::Error;
use std::fs; use std::fs;
@ -23,9 +23,7 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
let cdata = load_cdata(instance)?; let cdata = load_cdata(instance)?;
let key = cdata.dh_privkey.ok_or("Missing private key")?; let key = cdata.dh_privkey.ok_or("Missing private key")?;
let config_str = fs::read_to_string( let config_str = fs::read_to_string(nebula_yml(instance))?;
nebula_yml(instance),
)?;
let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?; let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?;
config.pki.key = Some(String::from_utf8(key)?); config.pki.key = Some(String::from_utf8(key)?);
@ -33,25 +31,26 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
debug!("inserted private key into config: {:?}", config); debug!("inserted private key into config: {:?}", config);
let config_str = serde_yaml::to_string(&config)?; let config_str = serde_yaml::to_string(&config)?;
fs::write( fs::write(nebula_yml(instance), config_str)?;
nebula_yml(instance),
config_str,
)?;
Ok(()) Ok(())
} }
pub fn nebulaworker_main(
pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter: ThreadMessageSender, rx: Receiver<NebulaWorkerMessage>) { _config: TFClientConfig,
instance: String,
_transmitter: ThreadMessageSender,
rx: Receiver<NebulaWorkerMessage>,
) {
loop { loop {
match rx.recv() { match rx.recv() {
Ok(msg) => match msg { Ok(msg) => match msg {
NebulaWorkerMessage::WakeUp => { NebulaWorkerMessage::WakeUp => {
continue; continue;
}, }
NebulaWorkerMessage::Shutdown => { NebulaWorkerMessage::Shutdown => {
break; break;
}, }
NebulaWorkerMessage::ConfigUpdated => { NebulaWorkerMessage::ConfigUpdated => {
info!("our configuration has been updated - reloading"); info!("our configuration has been updated - reloading");

View File

@ -12,7 +12,11 @@ pub enum TimerWorkerMessage {
Shutdown, Shutdown,
} }
pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>, disable_config_updates: bool) { pub fn timer_main(
tx: ThreadMessageSender,
rx: Receiver<TimerWorkerMessage>,
disable_config_updates: bool,
) {
let mut api_reload_timer = SystemTime::now().add(Duration::from_secs(60)); let mut api_reload_timer = SystemTime::now().add(Duration::from_secs(60));
loop { loop {

View File

@ -1,12 +1,12 @@
use log::{error, warn};
use sha2::Digest;
use sha2::Sha256;
use url::Url;
use crate::apiworker::APIWorkerMessage; use crate::apiworker::APIWorkerMessage;
use crate::daemon::ThreadMessageSender; use crate::daemon::ThreadMessageSender;
use crate::nebulaworker::NebulaWorkerMessage; use crate::nebulaworker::NebulaWorkerMessage;
use crate::socketworker::SocketWorkerMessage; use crate::socketworker::SocketWorkerMessage;
use crate::timerworker::TimerWorkerMessage; use crate::timerworker::TimerWorkerMessage;
use log::{error, warn};
use sha2::Digest;
use sha2::Sha256;
use url::Url;
pub fn sha256(bytes: &[u8]) -> String { pub fn sha256(bytes: &[u8]) -> String {
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
@ -51,10 +51,7 @@ pub fn shutdown(transmitter: &ThreadMessageSender) {
); );
} }
} }
match transmitter match transmitter.api_thread.send(APIWorkerMessage::Shutdown) {
.api_thread
.send(APIWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("Error sending shutdown message to api worker thread: {}", e); error!("Error sending shutdown message to api worker thread: {}", e);
@ -72,10 +69,7 @@ pub fn shutdown(transmitter: &ThreadMessageSender) {
); );
} }
} }
match transmitter match transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) {
.timer_thread
.send(TimerWorkerMessage::Shutdown)
{
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!( error!(

View File

@ -15,10 +15,13 @@ use crate::crypto::{decrypt_with_nonce, get_cipher_from_config};
use crate::AppState; use crate::AppState;
use ed25519_dalek::SigningKey; use ed25519_dalek::SigningKey;
use ipnet::Ipv4Net; use ipnet::Ipv4Net;
use log::{error}; use log::error;
use sea_orm::{ColumnTrait, Condition, EntityTrait, QueryFilter}; use sea_orm::{ColumnTrait, Condition, EntityTrait, QueryFilter};
use serde_yaml::{Mapping, Value}; use serde_yaml::{Mapping, Value};
use trifid_api_entities::entity::{firewall_rule, host, host_config_override, host_static_address, keystore_entry, network, organization, signing_ca}; use trifid_api_entities::entity::{
firewall_rule, host, host_config_override, host_static_address, keystore_entry, network,
organization, signing_ca,
};
use trifid_pki::cert::{ use trifid_pki::cert::{
deserialize_ed25519_private, deserialize_nebula_certificate_from_pem, NebulaCertificate, deserialize_ed25519_private, deserialize_nebula_certificate_from_pem, NebulaCertificate,
NebulaCertificateDetails, NebulaCertificateDetails,
@ -36,7 +39,7 @@ pub struct CodegenRequiredInfo {
pub lighthouse_ips: Vec<Ipv4Addr>, pub lighthouse_ips: Vec<Ipv4Addr>,
pub blocked_hosts: Vec<String>, pub blocked_hosts: Vec<String>,
pub firewall_rules: Vec<NebulaConfigFirewallRule>, pub firewall_rules: Vec<NebulaConfigFirewallRule>,
pub config_overrides: Vec<(String, String)> pub config_overrides: Vec<(String, String)>,
} }
pub async fn generate_config( pub async fn generate_config(
@ -90,14 +93,20 @@ pub async fn generate_config(
let mut blocked_hosts_fingerprints = vec![]; let mut blocked_hosts_fingerprints = vec![];
for host in &info.blocked_hosts { for host in &info.blocked_hosts {
// check if the host exists // check if the host exists
if host::Entity::find().filter(host::Column::Id.eq(host)).one(&db.conn).await?.is_some() { if host::Entity::find()
.filter(host::Column::Id.eq(host))
.one(&db.conn)
.await?
.is_some()
{
// pull all of their certs ever and block them // pull all of their certs ever and block them
let host_entries = keystore_entry::Entity::find().filter(keystore_entry::Column::Host.eq(host)).all(&db.conn).await?; let host_entries = keystore_entry::Entity::find()
.filter(keystore_entry::Column::Host.eq(host))
.all(&db.conn)
.await?;
for entry in &host_entries { for entry in &host_entries {
// decode the cert // decode the cert
let cert = deserialize_nebula_certificate_from_pem(&entry.certificate)?; let cert = deserialize_nebula_certificate_from_pem(&entry.certificate)?;
@ -209,11 +218,18 @@ pub async fn generate_config(
let mut current_val = &mut value; let mut current_val = &mut value;
for key_iter in &key_split[..key_split.len()-1] { for key_iter in &key_split[..key_split.len() - 1] {
current_val = current_val.as_mapping_mut().unwrap().entry(Value::String(key_iter.to_string())).or_insert(Value::Mapping(Mapping::new())); current_val = current_val
.as_mapping_mut()
.unwrap()
.entry(Value::String(key_iter.to_string()))
.or_insert(Value::Mapping(Mapping::new()));
} }
current_val.as_mapping_mut().unwrap().insert(Value::String(key_split[key_split.len()-1].to_string()), serde_yaml::from_str(kv_value)?); current_val.as_mapping_mut().unwrap().insert(
Value::String(key_split[key_split.len() - 1].to_string()),
serde_yaml::from_str(kv_value)?,
);
} }
let config_str_merged = serde_yaml::to_string(&value)?; let config_str_merged = serde_yaml::to_string(&value)?;
@ -338,10 +354,10 @@ pub async fn collect_info<'a>(
let best_ca = best_ca.unwrap(); let best_ca = best_ca.unwrap();
// pull our host's config overrides // pull our host's config overrides
let config_overrides = host_config_overrides.iter().map(|u| { let config_overrides = host_config_overrides
(u.key.clone(), u.value.clone()) .iter()
}).collect(); .map(|u| (u.key.clone(), u.value.clone()))
.collect();
// pull our role's firewall rules // pull our role's firewall rules
let firewall_rules = trifid_api_entities::entity::firewall_rule::Entity::find() let firewall_rules = trifid_api_entities::entity::firewall_rule::Entity::find()
@ -386,6 +402,6 @@ pub async fn collect_info<'a>(
lighthouse_ips: lighthouses, lighthouse_ips: lighthouses,
blocked_hosts, blocked_hosts,
firewall_rules, firewall_rules,
config_overrides config_overrides,
}) })
} }

View File

@ -77,7 +77,7 @@ pub struct TrifidConfigServer {
#[serde(default = "socketaddr_8080")] #[serde(default = "socketaddr_8080")]
pub bind: SocketAddr, pub bind: SocketAddr,
#[serde(default = "default_workers")] #[serde(default = "default_workers")]
pub workers: usize pub workers: usize,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@ -733,4 +733,6 @@ fn is_none<T>(o: &Option<T>) -> bool {
o.is_none() o.is_none()
} }
fn default_workers() -> usize { 32 } fn default_workers() -> usize {
32
}

View File

@ -14,6 +14,10 @@
// You should have received a copy of the GNU General Public License // You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>. // along with this program. If not, see <https://www.gnu.org/licenses/>.
use crate::config::CONFIG;
use crate::error::{APIError, APIErrorsResponse};
use crate::tokens::random_id_no_id;
use actix_cors::Cors;
use actix_request_identifier::RequestIdentifier; use actix_request_identifier::RequestIdentifier;
use actix_web::{ use actix_web::{
web::{Data, JsonConfig}, web::{Data, JsonConfig},
@ -23,10 +27,6 @@ use log::{info, Level};
use sea_orm::{ConnectOptions, Database, DatabaseConnection}; use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use std::error::Error; use std::error::Error;
use std::time::Duration; use std::time::Duration;
use actix_cors::Cors;
use crate::config::CONFIG;
use crate::error::{APIError, APIErrorsResponse};
use crate::tokens::random_id_no_id;
use trifid_api_migration::{Migrator, MigratorTrait}; use trifid_api_migration::{Migrator, MigratorTrait};
pub mod auth_tokens; pub mod auth_tokens;
@ -37,10 +37,10 @@ pub mod cursor;
pub mod error; pub mod error;
//pub mod legacy_keystore; // TODO- Remove //pub mod legacy_keystore; // TODO- Remove
pub mod magic_link; pub mod magic_link;
pub mod response;
pub mod routes; pub mod routes;
pub mod timers; pub mod timers;
pub mod tokens; pub mod tokens;
pub mod response;
pub struct AppState { pub struct AppState {
pub conn: DatabaseConnection, pub conn: DatabaseConnection,

View File

@ -1,9 +1,9 @@
use std::fmt::{Display, Formatter};
use actix_web::{HttpRequest, HttpResponse, Responder, ResponseError};
use actix_web::body::EitherBody; use actix_web::body::EitherBody;
use actix_web::web::Json; use actix_web::web::Json;
use actix_web::{HttpRequest, HttpResponse, Responder, ResponseError};
use log::error; use log::error;
use sea_orm::DbErr; use sea_orm::DbErr;
use std::fmt::{Display, Formatter};
use crate::error::{APIError, APIErrorsResponse}; use crate::error::{APIError, APIErrorsResponse};
@ -30,13 +30,15 @@ impl Responder for ErrResponse {
impl From<DbErr> for ErrResponse { impl From<DbErr> for ErrResponse {
fn from(value: DbErr) -> Self { fn from(value: DbErr) -> Self {
error!("database error: {}", value); error!("database error: {}", value);
Self(APIErrorsResponse { errors: vec![ Self(APIErrorsResponse {
APIError { errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error performing the database query. Please try again later.".to_string(), message:
"There was an error performing the database query. Please try again later."
.to_string(),
path: None, path: None,
} }],
] }) })
} }
} }

View File

@ -10,18 +10,20 @@ use dnapi_rs::message::{
}; };
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey}; use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use log::{error, warn}; use log::{error, warn};
use sea_orm::{ActiveModelTrait, EntityTrait};
use std::clone::Clone; use std::clone::Clone;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use sea_orm::{ActiveModelTrait, EntityTrait}; use trifid_pki::cert::{
use trifid_pki::cert::{deserialize_ed25519_public, deserialize_nebula_certificate_from_pem, deserialize_x25519_public}; deserialize_ed25519_public, deserialize_nebula_certificate_from_pem, deserialize_x25519_public,
};
use trifid_api_entities::entity::{host, keystore_entry, keystore_host};
use crate::error::APIErrorsResponse;
use sea_orm::{ColumnTrait, QueryFilter, IntoActiveModel};
use sea_orm::ActiveValue::Set;
use crate::AppState;
use crate::config::NebulaConfig; use crate::config::NebulaConfig;
use crate::error::APIErrorsResponse;
use crate::tokens::random_id; use crate::tokens::random_id;
use crate::AppState;
use sea_orm::ActiveValue::Set;
use sea_orm::{ColumnTrait, IntoActiveModel, QueryFilter};
use trifid_api_entities::entity::{host, keystore_entry, keystore_host};
#[post("/v1/dnclient")] #[post("/v1/dnclient")]
pub async fn dnclient( pub async fn dnclient(
@ -109,7 +111,8 @@ pub async fn dnclient(
log::debug!("{:x?}", keystore_data.client_signing_key); log::debug!("{:x?}", keystore_data.client_signing_key);
let key = VerifyingKey::from_bytes(&keystore_data.client_signing_key.try_into().unwrap()).unwrap(); let key =
VerifyingKey::from_bytes(&keystore_data.client_signing_key.try_into().unwrap()).unwrap();
if key.verify(req.message.as_bytes(), &signature).is_err() { if key.verify(req.message.as_bytes(), &signature).is_err() {
// Be intentionally vague as the message is invalid. // Be intentionally vague as the message is invalid.
@ -176,9 +179,7 @@ pub async fn dnclient(
return HttpResponse::NotFound().json(APIErrorsResponse { return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![crate::error::APIError { errors: vec![crate::error::APIError {
code: "ERR_NOT_FOUND".to_string(), code: "ERR_NOT_FOUND".to_string(),
message: message: "resource not found".to_string(),
"resource not found"
.to_string(),
path: None, path: None,
}], }],
}); });
@ -241,11 +242,15 @@ pub async fn dnclient(
c1.pki.cert = c0.pki.cert.clone(); c1.pki.cert = c0.pki.cert.clone();
if c0 == c1 { if c0 == c1 {
// its just the cert. deserialize both and check if any details have changed // its just the cert. deserialize both and check if any details have changed
let cert0 = deserialize_nebula_certificate_from_pem(c0.pki.cert.as_bytes()).expect("generated an invalid certificate"); let cert0 = deserialize_nebula_certificate_from_pem(c0.pki.cert.as_bytes())
let mut cert1 = deserialize_nebula_certificate_from_pem(c1.pki.cert.as_bytes()).expect("generated an invalid certificate"); .expect("generated an invalid certificate");
let mut cert1 = deserialize_nebula_certificate_from_pem(c1.pki.cert.as_bytes())
.expect("generated an invalid certificate");
cert1.details.not_before = cert0.details.not_before; cert1.details.not_before = cert0.details.not_before;
cert1.details.not_after = cert0.details.not_after; cert1.details.not_after = cert0.details.not_after;
if cert0.serialize_to_pem().expect("generated invalid cert") == cert1.serialize_to_pem().expect("generated invalid cert") { if cert0.serialize_to_pem().expect("generated invalid cert")
== cert1.serialize_to_pem().expect("generated invalid cert")
{
// fake news! its fine actually // fake news! its fine actually
config_is_different = false; config_is_different = false;
} }
@ -256,7 +261,10 @@ pub async fn dnclient(
let config_update_avail = config_is_different || req.counter < keystore_header.counter as u32; let config_update_avail = config_is_different || req.counter < keystore_header.counter as u32;
host_am.last_out_of_date = Set(config_update_avail); host_am.last_out_of_date = Set(config_update_avail);
host_am.last_seen_at = Set(SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64); host_am.last_seen_at = Set(SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64);
match host_am.update(&db.conn).await { match host_am.update(&db.conn).await {
Ok(_) => (), Ok(_) => (),
@ -391,7 +399,7 @@ pub async fn dnclient(
client_dh_key: dh_pubkey, client_dh_key: dh_pubkey,
client_signing_key: ed_pubkey, client_signing_key: ed_pubkey,
config: cfg_str.clone(), config: cfg_str.clone(),
signing_key: keystore_data.signing_key.clone() signing_key: keystore_data.signing_key.clone(),
}; };
match ks_entry_model.into_active_model().insert(&db.conn).await { match ks_entry_model.into_active_model().insert(&db.conn).await {
@ -425,7 +433,8 @@ pub async fn dnclient(
} }
} }
let signing_key = SigningKey::from_bytes(&keystore_data.signing_key.try_into().unwrap()); let signing_key =
SigningKey::from_bytes(&keystore_data.signing_key.try_into().unwrap());
// get the signing key that the client last trusted based on its current config version // get the signing key that the client last trusted based on its current config version
let msg = DoUpdateResponse { let msg = DoUpdateResponse {

View File

@ -76,7 +76,9 @@ use serde::{Deserialize, Serialize};
use std::net::{Ipv4Addr, SocketAddrV4}; use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr; use std::str::FromStr;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use trifid_api_entities::entity::{host, host_config_override, host_static_address, network, organization, role}; use trifid_api_entities::entity::{
host, host_config_override, host_static_address, network, organization, role,
};
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ListHostsRequestOpts { pub struct ListHostsRequestOpts {
@ -577,13 +579,11 @@ pub async fn create_hosts_request(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_REFERENCE".to_string(), code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(), message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()), path: Some("networkID".to_string()),
} }],
],
}); });
} }
}; };
@ -592,25 +592,21 @@ pub async fn create_hosts_request(
net_id = net.id; net_id = net.id;
} else { } else {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_REFERENCE".to_string(), code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(), message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()), path: Some("networkID".to_string()),
} }],
],
}); });
} }
if net_id != req.network_id { if net_id != req.network_id {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_REFERENCE".to_string(), code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(), message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()), path: Some("networkID".to_string()),
} }],
],
}); });
} }
@ -640,7 +636,7 @@ pub async fn create_hosts_request(
code: "ERR_INVALID_VALUE".to_string(), code: "ERR_INVALID_VALUE".to_string(),
message: "lighthouse hosts must specify a static listen port".to_string(), message: "lighthouse hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()), path: Some("listenPort".to_string()),
}] }],
}); });
} else if req.listen_port == 0 && req.is_relay { } else if req.listen_port == 0 && req.is_relay {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
@ -648,7 +644,7 @@ pub async fn create_hosts_request(
code: "ERR_INVALID_VALUE".to_string(), code: "ERR_INVALID_VALUE".to_string(),
message: "relay hosts must specify a static listen port".to_string(), message: "relay hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()), path: Some("listenPort".to_string()),
}] }],
}); });
} }
@ -662,26 +658,24 @@ pub async fn create_hosts_request(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(), message:
"There was an error validating the request. Please try again later."
.to_string(),
path: Some("role".to_string()), path: Some("role".to_string()),
} }],
],
}); });
} }
}; };
if roles.is_empty() { if roles.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_REFERENCE".to_string(), code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(), message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("role".to_string()), path: Some("role".to_string()),
} }],
],
}); });
} }
} }
@ -695,26 +689,23 @@ pub async fn create_hosts_request(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(), message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("name".to_string()), path: Some("name".to_string()),
} }],
],
}); });
} }
}; };
if !matching_hostname.is_empty() { if !matching_hostname.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DUPLICATE_VALUE".to_string(), code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(), message: "value already exists".to_string(),
path: Some("name".to_string()), path: Some("name".to_string()),
} }],
],
}); });
} }
@ -727,26 +718,23 @@ pub async fn create_hosts_request(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(), message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("ipAddress".to_string()), path: Some("ipAddress".to_string()),
} }],
],
}); });
} }
}; };
if !matching_ip.is_empty() { if !matching_ip.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DUPLICATE_VALUE".to_string(), code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(), message: "value already exists".to_string(),
path: Some("ipAddress".to_string()), path: Some("ipAddress".to_string()),
} }],
],
}); });
} }
@ -1006,9 +994,7 @@ pub async fn get_host(id: Path<String>, req_info: HttpRequest, db: Data<AppState
return HttpResponse::NotFound().json(APIErrorsResponse { return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(), code: "ERR_NOT_FOUND".to_string(),
message: message: "resource not found".to_string(),
"resource not found"
.to_string(),
path: None, path: None,
}], }],
}); });
@ -1243,9 +1229,7 @@ pub async fn delete_host(
return HttpResponse::Unauthorized().json(APIErrorsResponse { return HttpResponse::Unauthorized().json(APIErrorsResponse {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(), code: "ERR_NOT_FOUND".to_string(),
message: message: "resource not found".to_string(),
"resource not found"
.to_string(),
path: None, path: None,
}], }],
}); });
@ -1500,9 +1484,7 @@ pub async fn edit_host(
return HttpResponse::NotFound().json(APIErrorsResponse { return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(), code: "ERR_NOT_FOUND".to_string(),
message: message: "resource not found".to_string(),
"resource not found"
.to_string(),
path: None, path: None,
}], }],
}); });
@ -1827,9 +1809,7 @@ pub async fn block_host(
return HttpResponse::NotFound().json(APIErrorsResponse { return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(), code: "ERR_NOT_FOUND".to_string(),
message: message: "resource not found".to_string(),
"resource not found"
.to_string(),
path: None, path: None,
}], }],
}); });
@ -2094,9 +2074,7 @@ pub async fn enroll_host(
return HttpResponse::Unauthorized().json(APIErrorsResponse { return HttpResponse::Unauthorized().json(APIErrorsResponse {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(), code: "ERR_NOT_FOUND".to_string(),
message: message: "resource not found".to_string(),
"resource not found"
.to_string(),
path: None, path: None,
}], }],
}); });
@ -2263,13 +2241,11 @@ pub async fn create_host_and_enrollment_code(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_REFERENCE".to_string(), code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(), message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()), path: Some("networkID".to_string()),
} }],
],
}); });
} }
}; };
@ -2288,13 +2264,11 @@ pub async fn create_host_and_enrollment_code(
if net_id != req.network_id { if net_id != req.network_id {
return HttpResponse::Unauthorized().json(APIErrorsResponse { return HttpResponse::Unauthorized().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_REFERENCE".to_string(), code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(), message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()), path: Some("networkID".to_string()),
} }],
],
}); });
} }
@ -2324,7 +2298,7 @@ pub async fn create_host_and_enrollment_code(
code: "ERR_INVALID_VALUE".to_string(), code: "ERR_INVALID_VALUE".to_string(),
message: "lighthouse hosts must specify a static listen port".to_string(), message: "lighthouse hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()), path: Some("listenPort".to_string()),
}] }],
}); });
} else if req.listen_port == 0 && req.is_relay { } else if req.listen_port == 0 && req.is_relay {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
@ -2332,7 +2306,7 @@ pub async fn create_host_and_enrollment_code(
code: "ERR_INVALID_VALUE".to_string(), code: "ERR_INVALID_VALUE".to_string(),
message: "relay hosts must specify a static listen port".to_string(), message: "relay hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()), path: Some("listenPort".to_string()),
}] }],
}); });
} }
@ -2346,26 +2320,24 @@ pub async fn create_host_and_enrollment_code(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(), message:
"There was an error validating the request. Please try again later."
.to_string(),
path: Some("role".to_string()), path: Some("role".to_string()),
} }],
],
}); });
} }
}; };
if roles.is_empty() { if roles.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_REFERENCE".to_string(), code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(), message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("role".to_string()), path: Some("role".to_string()),
} }],
],
}); });
} }
} }
@ -2379,26 +2351,23 @@ pub async fn create_host_and_enrollment_code(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(), message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("name".to_string()), path: Some("name".to_string()),
} }],
],
}); });
} }
}; };
if !matching_hostname.is_empty() { if !matching_hostname.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DUPLICATE_VALUE".to_string(), code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(), message: "value already exists".to_string(),
path: Some("name".to_string()), path: Some("name".to_string()),
} }],
],
}); });
} }
@ -2411,26 +2380,23 @@ pub async fn create_host_and_enrollment_code(
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(), message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("ipAddress".to_string()), path: Some("ipAddress".to_string()),
} }],
],
}); });
} }
}; };
if !matching_ip.is_empty() { if !matching_ip.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DUPLICATE_VALUE".to_string(), code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(), message: "value already exists".to_string(),
path: Some("ipAddress".to_string()), path: Some("ipAddress".to_string()),
} }],
],
}); });
} }
@ -2588,7 +2554,11 @@ pub enum HostConfigOverrideDataOverrideValue {
} }
#[get("/v1/hosts/{host_id}/config-overrides")] #[get("/v1/hosts/{host_id}/config-overrides")]
pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Data<AppState>) -> HttpResponse { pub async fn get_host_overrides(
id: Path<String>,
req_info: HttpRequest,
db: Data<AppState>,
) -> HttpResponse {
// For this endpoint, you either need to be a fully authenticated user OR a token with hosts:read // For this endpoint, you either need to be a fully authenticated user OR a token with hosts:read
let session_info = enforce_2fa(&req_info, &db.conn) let session_info = enforce_2fa(&req_info, &db.conn)
.await .await
@ -2752,7 +2722,11 @@ pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Dat
}); });
} }
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find().filter(host_config_override::Column::Host.eq(host.id)).all(&db.conn).await { let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Host.eq(host.id))
.all(&db.conn)
.await
{
Ok(h) => h, Ok(h) => h,
Err(e) => { Err(e) => {
error!("Database error: {}", e); error!("Database error: {}", e);
@ -2767,7 +2741,9 @@ pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Dat
} }
}; };
let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides.iter().map(|u| { let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides
.iter()
.map(|u| {
let val; let val;
if u.value == "true" || u.value == "false" { if u.value == "true" || u.value == "false" {
val = HostConfigOverrideDataOverrideValue::Boolean(u.value == "true"); val = HostConfigOverrideDataOverrideValue::Boolean(u.value == "true");
@ -2780,12 +2756,11 @@ pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Dat
key: u.key.clone(), key: u.key.clone(),
value: val, value: val,
} }
}).collect(); })
.collect();
HttpResponse::Ok().json(HostConfigOverrideResponse { HttpResponse::Ok().json(HostConfigOverrideResponse {
data: HostConfigOverrideData { data: HostConfigOverrideData { overrides },
overrides,
},
}) })
} }
@ -2795,7 +2770,12 @@ pub struct UpdateOverridesRequest {
} }
#[put("/v1/hosts/{host_id}/config-overrides")] #[put("/v1/hosts/{host_id}/config-overrides")]
pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRequest>, req_info: HttpRequest, db: Data<AppState>) -> HttpResponse { pub async fn update_host_overrides(
id: Path<String>,
req: Json<UpdateOverridesRequest>,
req_info: HttpRequest,
db: Data<AppState>,
) -> HttpResponse {
// For this endpoint, you either need to be a fully authenticated user OR a token with hosts:read // For this endpoint, you either need to be a fully authenticated user OR a token with hosts:read
let session_info = enforce_2fa(&req_info, &db.conn) let session_info = enforce_2fa(&req_info, &db.conn)
.await .await
@ -2959,7 +2939,11 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
}); });
} }
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find().filter(host_config_override::Column::Host.eq(&host.id)).all(&db.conn).await { let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Host.eq(&host.id))
.all(&db.conn)
.await
{
Ok(h) => h, Ok(h) => h,
Err(e) => { Err(e) => {
error!("Database error: {}", e); error!("Database error: {}", e);
@ -2982,7 +2966,8 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error with the database query. Please try again later." message:
"There was an error with the database query. Please try again later."
.to_string(), .to_string(),
path: None, path: None,
}], }],
@ -3009,7 +2994,8 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: "There was an error with the database query. Please try again later." message:
"There was an error with the database query. Please try again later."
.to_string(), .to_string(),
path: None, path: None,
}], }],
@ -3018,7 +3004,11 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
} }
} }
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find().filter(host_config_override::Column::Host.eq(&host.id)).all(&db.conn).await { let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Host.eq(&host.id))
.all(&db.conn)
.await
{
Ok(h) => h, Ok(h) => h,
Err(e) => { Err(e) => {
error!("Database error: {}", e); error!("Database error: {}", e);
@ -3033,11 +3023,18 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
} }
}; };
let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides.iter().map(|u| { let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides
.iter()
.map(|u| {
let val; let val;
if u.value == "true" || u.value == "false" { if u.value == "true" || u.value == "false" {
val = HostConfigOverrideDataOverrideValue::Boolean(u.value == "true"); val = HostConfigOverrideDataOverrideValue::Boolean(u.value == "true");
} else if u.value.chars().all(|c| c.is_numeric()) || u.value.starts_with('-') && u.value.chars().collect::<Vec<_>>()[1..].iter().all(|c| c.is_numeric()) { } else if u.value.chars().all(|c| c.is_numeric())
|| u.value.starts_with('-')
&& u.value.chars().collect::<Vec<_>>()[1..]
.iter()
.all(|c| c.is_numeric())
{
val = HostConfigOverrideDataOverrideValue::Numeric(u.value.parse().unwrap()); val = HostConfigOverrideDataOverrideValue::Numeric(u.value.parse().unwrap());
} else { } else {
val = HostConfigOverrideDataOverrideValue::Other(u.value.clone()); val = HostConfigOverrideDataOverrideValue::Other(u.value.clone());
@ -3046,11 +3043,10 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
key: u.key.clone(), key: u.key.clone(),
value: val, value: val,
} }
}).collect(); })
.collect();
HttpResponse::Ok().json(HostConfigOverrideResponse { HttpResponse::Ok().json(HostConfigOverrideResponse {
data: HostConfigOverrideData { data: HostConfigOverrideData { overrides },
overrides,
},
}) })
} }

View File

@ -238,27 +238,23 @@ pub async fn create_role_request(
if role.is_some() { if role.is_some() {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_DUPLICATE_VALUE".to_string(), code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(), message: "value already exists".to_string(),
path: Some("name".to_string()) path: Some("name".to_string()),
} }],
] });
})
} }
for (id, rule) in req.firewall_rules.iter().enumerate() { for (id, rule) in req.firewall_rules.iter().enumerate() {
if let Some(pr) = &rule.port_range { if let Some(pr) = &rule.port_range {
if pr.from < pr.to { if pr.from < pr.to {
return HttpResponse::BadRequest().json(APIErrorsResponse { return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_INVALID_VALUE".to_string(), code: "ERR_INVALID_VALUE".to_string(),
message: "from must be less than or equal to to".to_string(), message: "from must be less than or equal to to".to_string(),
path: Some(format!("firewallRules[{}].portRange", id)) path: Some(format!("firewallRules[{}].portRange", id)),
} }],
]
}); });
} }
} }

View File

@ -4,18 +4,20 @@ use actix_web::{post, HttpRequest, HttpResponse, Responder};
use dnapi_rs::message::{ use dnapi_rs::message::{
APIError, EnrollRequest, EnrollResponse, EnrollResponseData, EnrollResponseDataOrg, APIError, EnrollRequest, EnrollResponse, EnrollResponseData, EnrollResponseDataOrg,
}; };
use ed25519_dalek::{SigningKey}; use ed25519_dalek::SigningKey;
use log::{debug, error}; use log::{debug, error};
use rand::rngs::OsRng; use rand::rngs::OsRng;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter}; use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter,
};
use crate::codegen::{collect_info, generate_config}; use crate::codegen::{collect_info, generate_config};
use crate::response::ErrResponse;
use crate::AppState; use crate::AppState;
use trifid_api_entities::entity::{host_enrollment_code, keystore_entry, keystore_host}; use trifid_api_entities::entity::{host_enrollment_code, keystore_entry, keystore_host};
use trifid_pki::cert::{ use trifid_pki::cert::{
deserialize_ed25519_public, deserialize_x25519_public, serialize_ed25519_public, deserialize_ed25519_public, deserialize_x25519_public, serialize_ed25519_public,
}; };
use crate::response::ErrResponse;
use crate::timers::expired; use crate::timers::expired;
use crate::tokens::random_id; use crate::tokens::random_id;
@ -111,7 +113,8 @@ pub async fn enroll(
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return Ok(HttpResponse::InternalServerError().json(EnrollResponse::Error { return Ok(
HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(), code: "ERR_DB_ERROR".to_string(),
message: message:
@ -119,20 +122,23 @@ pub async fn enroll(
.to_string(), .to_string(),
path: None, path: None,
}], }],
})); }),
);
} }
} }
let info = match collect_info(&db, &enroll_info.host, &dh_pubkey).await { let info = match collect_info(&db, &enroll_info.host, &dh_pubkey).await {
Ok(i) => i, Ok(i) => i,
Err(e) => { Err(e) => {
return Ok(HttpResponse::InternalServerError().json(EnrollResponse::Error { return Ok(
HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(), code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: e.to_string(), message: e.to_string(),
path: None, path: None,
}], }],
})); }),
);
} }
}; };
@ -141,25 +147,34 @@ pub async fn enroll(
Ok(cfg) => cfg, Ok(cfg) => cfg,
Err(e) => { Err(e) => {
error!("error generating configuration: {}", e); error!("error generating configuration: {}", e);
return Ok(HttpResponse::InternalServerError().json(EnrollResponse::Error { return Ok(
HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError { errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(), code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: "There was an error generating the host configuration.".to_string(), message: "There was an error generating the host configuration."
.to_string(),
path: None, path: None,
}], }],
})); }),
);
} }
}; };
// delete all entries in the keystore for this host // delete all entries in the keystore for this host
let entries = keystore_entry::Entity::find().filter(keystore_entry::Column::Host.eq(&enroll_info.host)).all(&db.conn).await?; let entries = keystore_entry::Entity::find()
.filter(keystore_entry::Column::Host.eq(&enroll_info.host))
.all(&db.conn)
.await?;
for entry in entries { for entry in entries {
entry.delete(&db.conn).await?; entry.delete(&db.conn).await?;
} }
let host_info = keystore_host::Entity::find().filter(keystore_host::Column::Id.eq(&enroll_info.host)).one(&db.conn).await?; let host_info = keystore_host::Entity::find()
.filter(keystore_host::Column::Id.eq(&enroll_info.host))
.one(&db.conn)
.await?;
if let Some(old_host) = host_info { if let Some(old_host) = host_info {
old_host.delete(&db.conn).await?; old_host.delete(&db.conn).await?;
@ -183,7 +198,7 @@ pub async fn enroll(
let host_header = keystore_host::Model { let host_header = keystore_host::Model {
id: enroll_info.host.clone(), id: enroll_info.host.clone(),
counter: 1 counter: 1,
}; };
let entry = keystore_entry::Model { let entry = keystore_entry::Model {
id: random_id("ksentry"), id: random_id("ksentry"),
@ -193,7 +208,7 @@ pub async fn enroll(
client_dh_key: dh_pubkey, client_dh_key: dh_pubkey,
client_signing_key: ed_pubkey, client_signing_key: ed_pubkey,
config: cfg.clone(), config: cfg.clone(),
signing_key: key.to_bytes().to_vec() signing_key: key.to_bytes().to_vec(),
}; };
host_header.into_active_model().insert(&db.conn).await?; host_header.into_active_model().insert(&db.conn).await?;

View File

@ -19,17 +19,17 @@
// This endpoint is considered done. No major features should be added or removed, unless it fixes bugs. // This endpoint is considered done. No major features should be added or removed, unless it fixes bugs.
// This endpoint requires the `definednetworking` extension to be enabled to be used. // This endpoint requires the `definednetworking` extension to be enabled to be used.
use actix_web::{get, HttpRequest, HttpResponse}; use crate::auth_tokens::{enforce_2fa, enforce_session, TokenInfo};
use crate::error::{APIError, APIErrorsResponse};
use crate::timers::TIME_FORMAT;
use crate::AppState;
use actix_web::web::Data; use actix_web::web::Data;
use actix_web::{get, HttpRequest, HttpResponse};
use chrono::{TimeZone, Utc}; use chrono::{TimeZone, Utc};
use log::error; use log::error;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use trifid_api_entities::entity::{organization, totp_authenticator}; use trifid_api_entities::entity::{organization, totp_authenticator};
use crate::AppState;
use crate::auth_tokens::{enforce_2fa, enforce_session, TokenInfo};
use crate::error::{APIError, APIErrorsResponse};
use crate::timers::TIME_FORMAT;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct WhoamiResponse { pub struct WhoamiResponse {
@ -72,13 +72,11 @@ pub async fn whoami(req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
Err(e) => { Err(e) => {
error!("database error: {}", e); error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse { return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![ errors: vec![APIError {
APIError {
code: "ERR_UNAUTHORIZED".to_string(), code: "ERR_UNAUTHORIZED".to_string(),
message: "Your authentication token is invalid.".to_string(), message: "Your authentication token is invalid.".to_string(),
path: None, path: None,
} }],
],
}); });
} }
} }
@ -155,7 +153,9 @@ pub async fn whoami(req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
}; };
HttpResponse::Ok().json(WhoamiResponse { HttpResponse::Ok().json(WhoamiResponse {
data: WhoamiResponseData { actor_type: "user".to_string(), actor: WhoamiResponseDataActor { data: WhoamiResponseData {
actor_type: "user".to_string(),
actor: WhoamiResponseDataActor {
id: user.id, id: user.id,
organization_id: org, organization_id: org,
email: user.email, email: user.email,
@ -165,7 +165,8 @@ pub async fn whoami(req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
.format(TIME_FORMAT) .format(TIME_FORMAT)
.to_string(), .to_string(),
has_totp_authenticator: has_totp, has_totp_authenticator: has_totp,
} }, },
},
metadata: WhoamiResponseMetadata {}, metadata: WhoamiResponseMetadata {},
}) })
} }

View File

@ -6,22 +6,20 @@ pub struct Migration;
#[async_trait::async_trait] #[async_trait::async_trait]
impl MigrationTrait for Migration { impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager.create_table( manager
.create_table(
Table::create() Table::create()
.table(KeystoreHost::Table) .table(KeystoreHost::Table)
.col( .col(
ColumnDef::new(KeystoreHost::Id) ColumnDef::new(KeystoreHost::Id)
.string() .string()
.not_null() .not_null()
.primary_key() .primary_key(),
) )
.col( .col(ColumnDef::new(KeystoreHost::Counter).integer().not_null())
ColumnDef::new(KeystoreHost::Counter) .to_owned(),
.integer()
.not_null()
) )
.to_owned() .await
).await
} }
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {

View File

@ -1,5 +1,5 @@
use sea_orm_migration::prelude::*;
use crate::m20230427_170037_create_table_hosts::Host; use crate::m20230427_170037_create_table_hosts::Host;
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)] #[derive(DeriveMigrationName)]
pub struct Migration; pub struct Migration;
@ -7,54 +7,43 @@ pub struct Migration;
#[async_trait::async_trait] #[async_trait::async_trait]
impl MigrationTrait for Migration { impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager.create_table( manager
.create_table(
Table::create() Table::create()
.table(KeystoreEntry::Table) .table(KeystoreEntry::Table)
.col( .col(
ColumnDef::new(KeystoreEntry::Id) ColumnDef::new(KeystoreEntry::Id)
.string() .string()
.not_null() .not_null()
.primary_key() .primary_key(),
)
.col(
ColumnDef::new(KeystoreEntry::Host)
.string()
.not_null()
)
.col(
ColumnDef::new(KeystoreEntry::Counter)
.integer()
.not_null()
) )
.col(ColumnDef::new(KeystoreEntry::Host).string().not_null())
.col(ColumnDef::new(KeystoreEntry::Counter).integer().not_null())
.col( .col(
ColumnDef::new(KeystoreEntry::SigningKey) ColumnDef::new(KeystoreEntry::SigningKey)
.binary() .binary()
.not_null() .not_null(),
) )
.col( .col(
ColumnDef::new(KeystoreEntry::ClientSigningKey) ColumnDef::new(KeystoreEntry::ClientSigningKey)
.binary() .binary()
.not_null() .not_null(),
) )
.col( .col(
ColumnDef::new(KeystoreEntry::ClientDHKey) ColumnDef::new(KeystoreEntry::ClientDHKey)
.binary() .binary()
.not_null() .not_null(),
)
.col(
ColumnDef::new(KeystoreEntry::Config)
.string()
.not_null()
) )
.col(ColumnDef::new(KeystoreEntry::Config).string().not_null())
.col( .col(
ColumnDef::new(KeystoreEntry::Certificate) ColumnDef::new(KeystoreEntry::Certificate)
.binary() .binary()
.not_null() .not_null(),
) )
.foreign_key( .foreign_key(
ForeignKey::create() ForeignKey::create()
.from(KeystoreEntry::Table, KeystoreEntry::Host) .from(KeystoreEntry::Table, KeystoreEntry::Host)
.to(Host::Table, Host::Id) .to(Host::Table, Host::Id),
) )
.index( .index(
Index::create() Index::create()
@ -62,10 +51,11 @@ impl MigrationTrait for Migration {
.table(KeystoreEntry::Table) .table(KeystoreEntry::Table)
.col(KeystoreEntry::Host) .col(KeystoreEntry::Host)
.col(KeystoreEntry::Counter) .col(KeystoreEntry::Counter)
.unique() .unique(),
) )
.to_owned() .to_owned(),
).await )
.await
} }
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
@ -86,5 +76,5 @@ enum KeystoreEntry {
ClientSigningKey, ClientSigningKey,
ClientDHKey, ClientDHKey,
Config, Config,
Certificate Certificate,
} }

View File

@ -42,5 +42,5 @@ pub struct ConfigEmail {
pub struct ConfigTokens { pub struct ConfigTokens {
pub magic_link_expiry_seconds: u64, pub magic_link_expiry_seconds: u64,
pub session_token_expiry_seconds: u64, pub session_token_expiry_seconds: u64,
pub auth_token_expiry_seconds: u64 pub auth_token_expiry_seconds: u64,
} }

View File

@ -2,6 +2,11 @@ use actix_web::error::{JsonPayloadError, PayloadError};
use serde::Serialize; use serde::Serialize;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
#[derive(Serialize, Debug)]
pub struct APIErrorsResponse {
pub errors: Vec<APIErrorResponse>,
}
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
pub struct APIErrorResponse { pub struct APIErrorResponse {
pub code: String, pub code: String,

View File

@ -46,11 +46,11 @@ pub struct TotpAuthenticator {
pub verified: bool, pub verified: bool,
pub name: String, pub name: String,
pub created_at: SystemTime, pub created_at: SystemTime,
pub last_seen_at: SystemTime pub last_seen_at: SystemTime,
} }
#[derive( #[derive(
Queryable, Selectable, Insertable, Identifiable, Associations, Debug, PartialEq, Clone, Queryable, Selectable, Insertable, Identifiable, Associations, Debug, PartialEq, Clone,
)] )]
#[diesel(belongs_to(User))] #[diesel(belongs_to(User))]
#[diesel(table_name = crate::schema::auth_tokens)] #[diesel(table_name = crate::schema::auth_tokens)]

View File

@ -1,4 +1,4 @@
use crate::error::APIErrorResponse; use crate::error::APIErrorsResponse;
use actix_web::body::BoxBody; use actix_web::body::BoxBody;
use actix_web::error::JsonPayloadError; use actix_web::error::JsonPayloadError;
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
@ -8,7 +8,7 @@ use std::fmt::{Debug, Display, Formatter};
#[derive(Debug)] #[derive(Debug)]
pub enum JsonAPIResponse<T: Serialize + Debug> { pub enum JsonAPIResponse<T: Serialize + Debug> {
Error(StatusCode, APIErrorResponse), Error(StatusCode, APIErrorsResponse),
Success(StatusCode, T), Success(StatusCode, T),
} }
@ -87,17 +87,21 @@ macro_rules! handle_error {
#[macro_export] #[macro_export]
macro_rules! make_err { macro_rules! make_err {
($c:expr,$m:expr,$p:expr) => { ($c:expr,$m:expr,$p:expr) => {
$crate::error::APIErrorResponse { $crate::error::APIErrorsResponse {
errors: vec![$crate::error::APIErrorResponse {
code: $c.to_string(), code: $c.to_string(),
message: $m.to_string(), message: $m.to_string(),
path: Some($p.to_string()), path: Some($p.to_string()),
}],
} }
}; };
($c:expr,$m:expr) => { ($c:expr,$m:expr) => {
$crate::error::APIErrorResponse { $crate::error::APIErrorsResponse {
errors: vec![$crate::error::APIErrorResponse {
code: $c.to_string(), code: $c.to_string(),
message: $m.to_string(), message: $m.to_string(),
path: None, path: None,
}],
} }
}; };
} }

View File

@ -1,3 +1,3 @@
pub mod magic_link; pub mod magic_link;
pub mod verify_magic_link;
pub mod totp; pub mod totp;
pub mod verify_magic_link;

View File

@ -1,19 +1,19 @@
use std::time::{Duration, SystemTime};
use actix_web::http::StatusCode;
use actix_web::{HttpRequest, post};
use actix_web::web::{Data, Json};
use serde::{Deserialize, Serialize};
use crate::{AppState, auth, enforce, randid};
use crate::response::JsonAPIResponse;
use diesel::{QueryDsl, ExpressionMethods, SelectableHelper, BelongingToDsl};
use diesel_async::RunQueryDsl;
use totp_rs::{Algorithm, Secret, TOTP};
use crate::schema::{auth_tokens, users, totp_authenticators};
use crate::models::{AuthToken, TotpAuthenticator, User}; use crate::models::{AuthToken, TotpAuthenticator, User};
use crate::response::JsonAPIResponse;
use crate::schema::{auth_tokens, totp_authenticators, users};
use crate::{auth, enforce, randid, AppState};
use actix_web::http::StatusCode;
use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest};
use diesel::{BelongingToDsl, ExpressionMethods, QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime};
use totp_rs::{Algorithm, Secret, TOTP};
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct TotpAuthReq { pub struct TotpAuthReq {
pub code: String pub code: String,
} }
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
@ -32,14 +32,27 @@ pub struct TotpAuthResp {
} }
#[post("/v1/auth/totp")] #[post("/v1/auth/totp")]
pub async fn totp_req(req: Json<TotpAuthReq>, state: Data<AppState>, req_info: HttpRequest) -> JsonAPIResponse<TotpAuthResp> { pub async fn totp_req(
req: Json<TotpAuthReq>,
state: Data<AppState>,
req_info: HttpRequest,
) -> JsonAPIResponse<TotpAuthResp> {
let mut conn = handle_error!(state.pool.get().await); let mut conn = handle_error!(state.pool.get().await);
let auth_info = auth!(req_info, conn); let auth_info = auth!(req_info, conn);
let session_token = enforce!(sess auth_info); let session_token = enforce!(sess auth_info);
let user = handle_error!(users::table.find(&session_token.user_id).first::<User>(&mut conn).await); let user = handle_error!(
users::table
.find(&session_token.user_id)
.first::<User>(&mut conn)
.await
);
let authenticators: Vec<TotpAuthenticator> = handle_error!(TotpAuthenticator::belonging_to(&user).load::<TotpAuthenticator>(&mut conn).await); let authenticators: Vec<TotpAuthenticator> = handle_error!(
TotpAuthenticator::belonging_to(&user)
.load::<TotpAuthenticator>(&mut conn)
.await
);
let mut found_valid_code = false; let mut found_valid_code = false;
let mut chosen_auther = None; let mut chosen_auther = None;
@ -47,16 +60,36 @@ pub async fn totp_req(req: Json<TotpAuthReq>, state: Data<AppState>, req_info: H
for totp_auther in authenticators { for totp_auther in authenticators {
if totp_auther.verified { if totp_auther.verified {
let secret = Secret::Encoded(totp_auther.secret.clone()); let secret = Secret::Encoded(totp_auther.secret.clone());
let totp_machine = handle_error!(TOTP::new(Algorithm::SHA1, 6, 1, 30, handle_error!(secret.to_bytes()), Some("Trifid".to_string()), user.email.clone())); let totp_machine = handle_error!(TOTP::new(
Algorithm::SHA1,
6,
1,
30,
handle_error!(secret.to_bytes()),
Some("Trifid".to_string()),
user.email.clone()
));
let is_valid = handle_error!(totp_machine.check_current(&req.code)); let is_valid = handle_error!(totp_machine.check_current(&req.code));
if is_valid { found_valid_code = true; chosen_auther = Some(totp_auther); break; } if is_valid {
found_valid_code = true;
chosen_auther = Some(totp_auther);
break;
}
} }
} }
if !found_valid_code { if !found_valid_code {
err!(StatusCode::UNAUTHORIZED, make_err!("ERR_UNAUTHORIZED", "unauthorized")); err!(
StatusCode::UNAUTHORIZED,
make_err!("ERR_UNAUTHORIZED", "unauthorized")
);
} }
handle_error!(diesel::update(&(chosen_auther.unwrap())).set(totp_authenticators::dsl::last_seen_at.eq(SystemTime::now())).execute(&mut conn).await); handle_error!(
diesel::update(&(chosen_auther.unwrap()))
.set(totp_authenticators::dsl::last_seen_at.eq(SystemTime::now()))
.execute(&mut conn)
.await
);
// issue auth token // issue auth token
@ -74,12 +107,10 @@ pub async fn totp_req(req: Json<TotpAuthReq>, state: Data<AppState>, req_info: H
.await .await
); );
ok!( ok!(TotpAuthResp {
TotpAuthResp {
data: TotpAuthRespData { data: TotpAuthRespData {
auth_token: new_token.id.clone() auth_token: new_token.id.clone()
}, },
metadata: TotpAuthRespMeta {} metadata: TotpAuthRespMeta {}
} })
)
} }

View File

@ -5,10 +5,10 @@ use crate::{randid, AppState};
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
use actix_web::post; use actix_web::post;
use actix_web::web::{Data, Json}; use actix_web::web::{Data, Json};
use diesel::result::OptionalExtension;
use diesel::QueryDsl; use diesel::QueryDsl;
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use diesel::result::OptionalExtension;
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
#[derive(Deserialize)] #[derive(Deserialize)]
@ -35,11 +35,14 @@ pub async fn verify_link_req(
req: Json<VerifyLinkReq>, req: Json<VerifyLinkReq>,
state: Data<AppState>, state: Data<AppState>,
) -> JsonAPIResponse<VerifyLinkResp> { ) -> JsonAPIResponse<VerifyLinkResp> {
let mut conn = handle_error!(state.pool.get().await); let mut conn = handle_error!(state.pool.get().await);
let token = match handle_error!(magic_links::table.find(&req.magic_link_token).first::<MagicLink>(&mut conn).await.optional()) { let token = match handle_error!(magic_links::table
.find(&req.magic_link_token)
.first::<MagicLink>(&mut conn)
.await
.optional())
{
Some(t) => t, Some(t) => t,
None => { None => {
err!( err!(

View File

@ -1,5 +1,8 @@
use crate::models::TotpAuthenticator; use crate::models::TotpAuthenticator;
use crate::models::User;
use crate::response::JsonAPIResponse; use crate::response::JsonAPIResponse;
use crate::schema::totp_authenticators;
use crate::schema::users;
use crate::{auth, enforce, randid, AppState}; use crate::{auth, enforce, randid, AppState};
use actix_web::web::{Data, Json}; use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest}; use actix_web::{post, HttpRequest};
@ -8,11 +11,8 @@ use diesel::QueryDsl;
use diesel::SelectableHelper; use diesel::SelectableHelper;
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use totp_rs::{Algorithm, Secret, TOTP};
use crate::schema::totp_authenticators;
use crate::schema::users;
use crate::models::User;
use std::time::SystemTime; use std::time::SystemTime;
use totp_rs::{Algorithm, Secret, TOTP};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct TotpAuthenticatorReq {} pub struct TotpAuthenticatorReq {}
@ -45,11 +45,24 @@ pub async fn create_totp_auth_req(
let auth_info = auth!(req_info, conn); let auth_info = auth!(req_info, conn);
let session_token = enforce!(sess auth_info); let session_token = enforce!(sess auth_info);
let user = handle_error!(users::table.find(&session_token.user_id).first::<User>(&mut conn).await); let user = handle_error!(
users::table
.find(&session_token.user_id)
.first::<User>(&mut conn)
.await
);
let secret = Secret::generate_secret(); let secret = Secret::generate_secret();
let totp = handle_error!(TOTP::new(Algorithm::SHA1, 6, 1, 30, handle_error!(secret.to_bytes()), Some("Trifid".to_string()), user.email.clone())); let totp = handle_error!(TOTP::new(
Algorithm::SHA1,
6,
1,
30,
handle_error!(secret.to_bytes()),
Some("Trifid".to_string()),
user.email.clone()
));
let new_totp_authenticator = TotpAuthenticator { let new_totp_authenticator = TotpAuthenticator {
id: randid!(id "totp"), id: randid!(id "totp"),
@ -58,7 +71,7 @@ pub async fn create_totp_auth_req(
verified: false, verified: false,
name: "".to_string(), name: "".to_string(),
created_at: SystemTime::now(), created_at: SystemTime::now(),
last_seen_at: SystemTime::now() last_seen_at: SystemTime::now(),
}; };
handle_error!( handle_error!(
@ -77,3 +90,6 @@ pub async fn create_totp_auth_req(
metadata: TotpAuthRespMeta {} metadata: TotpAuthRespMeta {}
}) })
} }
// TODO: Get, Edit, Delete
// All of these API endpoints are the same... they could probably be automated...

View File

@ -1,21 +1,21 @@
use std::time::{Duration, SystemTime};
use actix_web::http::StatusCode;
use actix_web::{HttpRequest, post};
use actix_web::web::{Data, Json};
use serde::{Deserialize, Serialize};
use crate::{AppState, auth, enforce, randid};
use crate::response::JsonAPIResponse;
use diesel::{QueryDsl, ExpressionMethods, SelectableHelper, OptionalExtension};
use diesel_async::RunQueryDsl;
use totp_rs::{Algorithm, Secret, TOTP};
use crate::schema::{auth_tokens, totp_authenticators, users};
use crate::models::{AuthToken, TotpAuthenticator, User}; use crate::models::{AuthToken, TotpAuthenticator, User};
use crate::response::JsonAPIResponse;
use crate::schema::{auth_tokens, totp_authenticators, users};
use crate::{auth, enforce, randid, AppState};
use actix_web::http::StatusCode;
use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest};
use diesel::{ExpressionMethods, OptionalExtension, QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime};
use totp_rs::{Algorithm, Secret, TOTP};
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct VerifyTotpAuthReq { pub struct VerifyTotpAuthReq {
#[serde(rename = "totpToken")] #[serde(rename = "totpToken")]
pub totp_token: String, pub totp_token: String,
pub code: String pub code: String,
} }
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
@ -34,14 +34,28 @@ pub struct TotpAuthResp {
} }
#[post("/v1/verify-totp-authenticator")] #[post("/v1/verify-totp-authenticator")]
pub async fn verify_totp_req(req: Json<VerifyTotpAuthReq>, state: Data<AppState>, req_info: HttpRequest) -> JsonAPIResponse<TotpAuthResp> { pub async fn verify_totp_req(
req: Json<VerifyTotpAuthReq>,
state: Data<AppState>,
req_info: HttpRequest,
) -> JsonAPIResponse<TotpAuthResp> {
let mut conn = handle_error!(state.pool.get().await); let mut conn = handle_error!(state.pool.get().await);
let auth_info = auth!(req_info, conn); let auth_info = auth!(req_info, conn);
let session_token = enforce!(sess auth_info); let session_token = enforce!(sess auth_info);
let user = handle_error!(users::table.find(&session_token.user_id).first::<User>(&mut conn).await); let user = handle_error!(
users::table
.find(&session_token.user_id)
.first::<User>(&mut conn)
.await
);
let authenticator = match handle_error!(totp_authenticators::table.find(&req.totp_token).first::<TotpAuthenticator>(&mut conn).await.optional()) { let authenticator = match handle_error!(totp_authenticators::table
.find(&req.totp_token)
.first::<TotpAuthenticator>(&mut conn)
.await
.optional())
{
Some(t) => t, Some(t) => t,
None => { None => {
err!( err!(
@ -67,14 +81,33 @@ pub async fn verify_totp_req(req: Json<VerifyTotpAuthReq>, state: Data<AppState>
} }
let secret = Secret::Encoded(authenticator.secret.clone()); let secret = Secret::Encoded(authenticator.secret.clone());
let totp_machine = handle_error!(TOTP::new(Algorithm::SHA1, 6, 1, 30, handle_error!(secret.to_bytes()), Some("Trifid".to_string()), user.email.clone())); let totp_machine = handle_error!(TOTP::new(
Algorithm::SHA1,
6,
1,
30,
handle_error!(secret.to_bytes()),
Some("Trifid".to_string()),
user.email.clone()
));
let is_valid = handle_error!(totp_machine.check_current(&req.code)); let is_valid = handle_error!(totp_machine.check_current(&req.code));
if !is_valid { if !is_valid {
err!(StatusCode::UNAUTHORIZED, make_err!("ERR_UNAUTHORIZED", "unauthorized")); err!(
StatusCode::UNAUTHORIZED,
make_err!("ERR_UNAUTHORIZED", "unauthorized")
);
} }
handle_error!(diesel::update(&authenticator).set((totp_authenticators::dsl::verified.eq(true), totp_authenticators::dsl::last_seen_at.eq(SystemTime::now()))).execute(&mut conn).await); handle_error!(
diesel::update(&authenticator)
.set((
totp_authenticators::dsl::verified.eq(true),
totp_authenticators::dsl::last_seen_at.eq(SystemTime::now())
))
.execute(&mut conn)
.await
);
// issue auth token // issue auth token
@ -92,12 +125,10 @@ pub async fn verify_totp_req(req: Json<VerifyTotpAuthReq>, state: Data<AppState>
.await .await
); );
ok!( ok!(TotpAuthResp {
TotpAuthResp {
data: TotpAuthRespData { data: TotpAuthRespData {
auth_token: new_token.id.clone() auth_token: new_token.id.clone()
}, },
metadata: TotpAuthRespMeta {} metadata: TotpAuthRespMeta {}
} })
)
} }