0.3 error fixes, refmt
/ build (push) Successful in 49s Details
/ build_x64 (push) Successful in 2m6s Details
/ build_arm64 (push) Successful in 2m40s Details
/ build_win64 (push) Successful in 2m36s Details

This commit is contained in:
core 2023-11-23 15:23:52 -05:00
parent 2a5a2bb910
commit 646340b637
Signed by: core
GPG Key ID: FDBF740DADDCEECF
41 changed files with 1441 additions and 878 deletions

View File

@ -254,7 +254,10 @@ impl Client {
ed_privkey.verify(b64_msg_bytes, &Signature::from_slice(&signature)?)?;
debug!("signature valid via clientside check");
debug!("signed with key: {:x?}", ed_privkey.verifying_key().as_bytes());
debug!(
"signed with key: {:x?}",
ed_privkey.verifying_key().as_bytes()
);
let body = RequestV1 {
version: 1,

View File

@ -0,0 +1,18 @@
# API Support
This document is valid for: `trifid-api 0.3.0`.
This document is only useful for developers, and it lists what endpoint versions are currently supported by trifid-api. This is subject to change at any time according to SemVer constraints.
Endpoint types:
- **Documented** is an endpoint available in the official documentation
- **Reverse-engineered** is an endpoint that was reverse-engineered
| Endpoint Name | Version | Endpoint | Type | Added In |
|---------------------------|---------|------------------------------------|--------------------|----------------|
| Signup | v1 | POST /v1/signup | Reverse-engineered | 0.3.0/79b1765e |
| Get Magic Link | v1 | POST /v1/auth/magic-link | Reverse-engineered | 0.3.0/52049947 |
| Verify Magic Link | v1 | POST /v1/auth/verify-magic-link | Reverse-engineered | 0.3.0/51b6d3a8 |
| Create TOTP Authenticator | v1 | POST /v1/totp-authenticators | Reverse-engineered | 0.3.0/4180bdd1 |
| Verify TOTP Authenticator | v1 | POST /v1/verify-totp-authenticator | Reverse-engineered | 0.3.0/19332e51 |
| Authenticate with TOTP | v1 | POST /v1/auth/totp | Reverse-engineered | 0.3.0/19332e51 |

View File

@ -1,7 +1,7 @@
use std::{env, process};
use std::path::PathBuf;
use bindgen::CargoCallbacks;
use std::path::Path;
use std::path::PathBuf;
use std::{env, process};
fn get_cargo_target_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?);
@ -20,7 +20,6 @@ fn get_cargo_target_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Erro
}
fn main() {
// Find compiler:
// 1. GOC
// 2. /usr/local/go/bin/go
@ -49,7 +48,14 @@ fn main() {
let out = out_path.join(out_file);
let mut command = process::Command::new(compiler);
command.args(["build", "-buildmode", link_type().as_str(), "-o", out.display().to_string().as_str(), "main.go"]);
command.args([
"build",
"-buildmode",
link_type().as_str(),
"-o",
out.display().to_string().as_str(),
"main.go",
]);
command.env("CGO_ENABLED", "1");
command.env("CC", c_compiler.path());
command.env("GOARCH", goarch());
@ -68,7 +74,10 @@ fn main() {
copy_if_windows();
print_link();
println!("cargo:rustc-link-search=native={}", env::var("OUT_DIR").unwrap());
println!(
"cargo:rustc-link-search=native={}",
env::var("OUT_DIR").unwrap()
);
//let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
@ -85,7 +94,6 @@ fn main() {
.generate()
.expect("Error generating CFFI bindings");
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
@ -125,8 +133,9 @@ fn goarch() -> String {
"powerpc64" => "ppc64",
"arm" => "arm",
"aarch64" => "arm64",
arch => panic!("unsupported architecture {arch}")
}.to_string()
arch => panic!("unsupported architecture {arch}"),
}
.to_string()
}
fn goos() -> String {
match env::var("CARGO_CFG_TARGET_OS").unwrap().as_str() {
@ -139,8 +148,9 @@ fn goos() -> String {
"dragonfly" => "dragonfly",
"openbsd" => "openbsd",
"netbsd" => "netbsd",
os => panic!("unsupported operating system {os}")
}.to_string()
os => panic!("unsupported operating system {os}"),
}
.to_string()
}
fn print_link() {

View File

@ -25,7 +25,6 @@
#![deny(clippy::missing_panics_doc)]
#![deny(clippy::missing_safety_doc)]
#[allow(non_upper_case_globals)]
#[allow(non_camel_case_types)]
#[allow(non_snake_case)]
@ -36,12 +35,11 @@ pub mod generated {
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
use generated::GoString;
use std::error::Error;
use std::ffi::{c_char, CString};
use std::fmt::{Display, Formatter};
use std::path::{Path};
use generated::GoString;
use std::path::Path;
impl From<&str> for GoString {
#[allow(clippy::cast_possible_wrap)]
@ -51,7 +49,7 @@ impl From<&str> for GoString {
let ptr = c_str.as_ptr();
let go_string = GoString {
p: ptr,
n: c_str.as_bytes().len() as isize
n: c_str.as_bytes().len() as isize,
};
go_string
}
@ -73,14 +71,18 @@ impl NebulaInstance {
/// # Panics
/// This function will panic if memory is corrupted while communicating with Go.
pub fn new(config_path: &Path, config_test: bool) -> Result<Self, Box<dyn Error>> {
let mut config_path_bytes = unsafe { config_path.display().to_string().as_bytes_mut().to_vec() };
let mut config_path_bytes =
unsafe { config_path.display().to_string().as_bytes_mut().to_vec() };
config_path_bytes.push(0u8);
let config_test_u8 = u8::from(config_test);
let res;
unsafe {
res = generated::NebulaSetup(config_path_bytes.as_mut_ptr().cast::<c_char>(), config_test_u8);
res = generated::NebulaSetup(
config_path_bytes.as_mut_ptr().cast::<c_char>(),
config_test_u8,
);
}
let res = cstring_to_string(res);
@ -194,18 +196,18 @@ pub enum NebulaError {
/// Returned by nebula when the TUN/TAP device already exists
DeviceOrResourceBusy {
/// The complete error string returned by the Nebula wrapper
error_str: String
error_str: String,
},
/// An unknown error that the error parser couldn't figure out how to parse.
Unknown {
/// The complete error string returned by the Nebula wrapper
error_str: String
error_str: String,
},
/// Occurs if you call a function before NebulaSetup has been called
NebulaNotSetup {
/// The complete error string returned by the Nebula wrapper
error_str: String
}
error_str: String,
},
}
impl Display for NebulaError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
@ -223,11 +225,17 @@ impl NebulaError {
#[must_use]
pub fn from_string(string: &str) -> Self {
if string.starts_with("device or resource busy") {
Self::DeviceOrResourceBusy { error_str: string.to_string() }
Self::DeviceOrResourceBusy {
error_str: string.to_string(),
}
} else if string.starts_with("NebulaSetup has not yet been called") {
Self::NebulaNotSetup { error_str: string.to_string() }
Self::NebulaNotSetup {
error_str: string.to_string(),
}
} else {
Self::Unknown { error_str: string.to_string() }
Self::Unknown {
error_str: string.to_string(),
}
}
}
}

View File

@ -1,38 +1,48 @@
use crate::api::APIErrorResponse;
use crate::AccountCommands;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::AccountCommands;
use crate::api::APIErrorResponse;
pub async fn account_main(command: AccountCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command {
AccountCommands::Create { email } => create_account(email, server).await,
AccountCommands::MagicLink { magic_link_token } => auth_magic_link(magic_link_token, server).await,
AccountCommands::MagicLink { magic_link_token } => {
auth_magic_link(magic_link_token, server).await
}
AccountCommands::MfaSetup {} => create_mfa_authenticator(server).await,
AccountCommands::MfaSetupFinish {code, token} => finish_mfa_authenticator(token, code, server).await,
AccountCommands::Mfa {code} => mfa_auth(code, server).await,
AccountCommands::Login { email } => login_account(email, server).await
AccountCommands::MfaSetupFinish { code, token } => {
finish_mfa_authenticator(token, code, server).await
}
AccountCommands::Mfa { code } => mfa_auth(code, server).await,
AccountCommands::Login { email } => login_account(email, server).await,
}
}
#[derive(Serialize)]
pub struct CreateAccountBody {
pub email: String
pub email: String,
}
pub async fn create_account(email: String, server: Url) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
let res = client.post(server.join("/v1/signup")?).json(&CreateAccountBody { email }).send().await?;
let res = client
.post(server.join("/v1/signup")?)
.json(&CreateAccountBody { email })
.send()
.await?;
if res.status().is_success() {
println!("Account created successfully, check your email.");
println!("Finish creating your account with 'tfcli account magic-link --magic-link-token [magic-link-token]'.");
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating account: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error creating account: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -42,21 +52,27 @@ pub async fn create_account(email: String, server: Url) -> Result<(), Box<dyn Er
#[derive(Serialize)]
pub struct LoginAccountBody {
pub email: String
pub email: String,
}
pub async fn login_account(email: String, server: Url) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
let res = client.post(server.join("/v1/auth/magic-link")?).json(&LoginAccountBody { email }).send().await?;
let res = client
.post(server.join("/v1/auth/magic-link")?)
.json(&LoginAccountBody { email })
.send()
.await?;
if res.status().is_success() {
println!("Magic link sent, check your email.");
println!("Finish creating your account with 'tfcli account magic-link --magic-link-token [magic-link-token]'.");
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error logging in: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error logging in: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -64,26 +80,31 @@ pub async fn login_account(email: String, server: Url) -> Result<(), Box<dyn Err
Ok(())
}
#[derive(Serialize)]
pub struct MagicLinkBody {
#[serde(rename = "magicLinkToken")]
pub magic_link_token: String
pub magic_link_token: String,
}
#[derive(Deserialize)]
pub struct MagicLinkSuccess {
pub data: MagicLinkSuccessBody
pub data: MagicLinkSuccessBody,
}
#[derive(Deserialize)]
pub struct MagicLinkSuccessBody {
#[serde(rename = "sessionToken")]
pub session_token: String
pub session_token: String,
}
pub async fn auth_magic_link(magic_token: String, server: Url) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
let res = client.post(server.join("/v1/auth/verify-magic-link")?).json(&MagicLinkBody { magic_link_token: magic_token }).send().await?;
let res = client
.post(server.join("/v1/auth/verify-magic-link")?)
.json(&MagicLinkBody {
magic_link_token: magic_token,
})
.send()
.await?;
if res.status().is_success() {
let resp: MagicLinkSuccess = res.json().await?;
@ -97,7 +118,10 @@ pub async fn auth_magic_link(magic_token: String, server: Url) -> Result<(), Box
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error getting session token: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error getting session token: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -135,14 +159,14 @@ pub struct WhoamiResponseMetadata {}
#[derive(Deserialize)]
pub struct CreateMfaResponse {
pub data: CreateMfaResponseData
pub data: CreateMfaResponseData,
}
#[derive(Deserialize)]
pub struct CreateMfaResponseData {
#[serde(rename = "totpToken")]
pub totp_token: String,
pub secret: String,
pub url: String
pub url: String,
}
pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>> {
@ -153,14 +177,25 @@ pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>>
let session_token = fs::read_to_string(&token_store)?;
// do we have mfa already?
let whoami: WhoamiResponse = client.get(server.join("/v2/whoami")?).bearer_auth(&session_token).send().await?.json().await?;
let whoami: WhoamiResponse = client
.get(server.join("/v2/whoami")?)
.bearer_auth(&session_token)
.send()
.await?
.json()
.await?;
if whoami.data.actor.has_totp_authenticator {
eprintln!("[error] user already has a totp authenticator, cannot add another one");
std::process::exit(1);
}
let res = client.post(server.join("/v1/totp-authenticators")?).bearer_auth(&session_token).body("{}").send().await?;
let res = client
.post(server.join("/v1/totp-authenticators")?)
.bearer_auth(&session_token)
.body("{}")
.send()
.await?;
if res.status().is_success() {
let resp: CreateMfaResponse = res.json().await?;
@ -169,14 +204,23 @@ pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>>
println!("To complete setup, you'll need a TOTP-compatible app, such as Google Authenticator or Authy.");
println!("Scan the following code with your authenticator app:");
qr2term::print_qr(resp.data.url)?;
println!("Alternatively, enter the following secret into your authenticator app: '{}'", resp.data.secret);
println!(
"Alternatively, enter the following secret into your authenticator app: '{}'",
resp.data.secret
);
println!("Once done, enable TOTP by running the following command with the code shown on your authenticator app:");
println!("tfcli account mfa-setup-finish --token {} --code [CODE IN AUTHENTICATOR]", resp.data.totp_token);
println!(
"tfcli account mfa-setup-finish --token {} --code [CODE IN AUTHENTICATOR]",
resp.data.totp_token
);
println!("This code will expire in 10 minutes.");
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error adding MFA to account: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error adding MFA to account: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -188,27 +232,39 @@ pub async fn create_mfa_authenticator(server: Url) -> Result<(), Box<dyn Error>>
pub struct MfaVerifyBody {
#[serde(rename = "totpToken")]
pub totp_token: String,
pub code: String
pub code: String,
}
#[derive(Deserialize)]
pub struct MFASuccess {
pub data: MFASuccessBody
pub data: MFASuccessBody,
}
#[derive(Deserialize)]
pub struct MFASuccessBody {
#[serde(rename = "authToken")]
pub auth_token: String
pub auth_token: String,
}
pub async fn finish_mfa_authenticator(token: String, code: String, server: Url) -> Result<(), Box<dyn Error>> {
pub async fn finish_mfa_authenticator(
token: String,
code: String,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
// load session token
let token_store = dirs::config_dir().unwrap().join("tfcli-session.token");
let session_token = fs::read_to_string(&token_store)?;
let res = client.post(server.join("/v1/verify-totp-authenticators")?).json(&MfaVerifyBody {totp_token: token, code }).bearer_auth(session_token).send().await?;
let res = client
.post(server.join("/v1/verify-totp-authenticators")?)
.json(&MfaVerifyBody {
totp_token: token,
code,
})
.bearer_auth(session_token)
.send()
.await?;
if res.status().is_success() {
let resp: MFASuccess = res.json().await?;
@ -222,7 +278,10 @@ pub async fn finish_mfa_authenticator(token: String, code: String, server: Url)
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error verifying MFA code: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error verifying MFA code: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -232,7 +291,7 @@ pub async fn finish_mfa_authenticator(token: String, code: String, server: Url)
#[derive(Serialize)]
pub struct MfaAuthBody {
pub code: String
pub code: String,
}
pub async fn mfa_auth(code: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -242,7 +301,12 @@ pub async fn mfa_auth(code: String, server: Url) -> Result<(), Box<dyn Error>> {
let token_store = dirs::config_dir().unwrap().join("tfcli-session.token");
let session_token = fs::read_to_string(&token_store)?;
let res = client.post(server.join("/v1/auth/totp")?).json(&MfaAuthBody { code }).bearer_auth(session_token).send().await?;
let res = client
.post(server.join("/v1/auth/totp")?)
.json(&MfaAuthBody { code })
.bearer_auth(session_token)
.send()
.await?;
if res.status().is_success() {
let resp: MFASuccess = res.json().await?;
@ -256,7 +320,10 @@ pub async fn mfa_auth(code: String, server: Url) -> Result<(), Box<dyn Error>> {
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error verifying MFA code: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error verifying MFA code: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}

View File

@ -2,11 +2,11 @@ use serde::Deserialize;
#[derive(Deserialize)]
pub struct APIErrorResponse {
pub errors: Vec<APIError>
pub errors: Vec<APIError>,
}
#[derive(Deserialize)]
pub struct APIError {
pub code: String,
pub message: String,
pub path: Option<String>
pub path: Option<String>,
}

View File

@ -1,34 +1,68 @@
use crate::api::APIErrorResponse;
use crate::{HostCommands, HostOverrideCommands};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4};
use serde::{Deserialize, Serialize};
use url::{Url};
use crate::api::APIErrorResponse;
use crate::{HostCommands, HostOverrideCommands};
use url::Url;
pub async fn host_main(command: HostCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command {
HostCommands::List {} => list_hosts(server).await,
HostCommands::Create { name, network_id, role_id, ip_address, listen_port, lighthouse, relay, static_address } => create_host(name, network_id, role_id, ip_address, listen_port, lighthouse, relay, static_address, server).await,
HostCommands::Create {
name,
network_id,
role_id,
ip_address,
listen_port,
lighthouse,
relay,
static_address,
} => {
create_host(
name,
network_id,
role_id,
ip_address,
listen_port,
lighthouse,
relay,
static_address,
server,
)
.await
}
HostCommands::Lookup { id } => get_host(id, server).await,
HostCommands::Delete { id } => delete_host(id, server).await,
HostCommands::Update { id, listen_port, static_address, name, ip, role } => update_host(id, listen_port, static_address, name, ip, role, server).await,
HostCommands::Update {
id,
listen_port,
static_address,
name,
ip,
role,
} => update_host(id, listen_port, static_address, name, ip, role, server).await,
HostCommands::Block { id } => block_host(id, server).await,
HostCommands::Enroll { id } => enroll_host(id, server).await,
HostCommands::Overrides { command } => match command {
HostOverrideCommands::List { id } => list_overrides(id, server).await,
HostOverrideCommands::Set { id, key, boolean, string, numeric } => set_override(id, key, boolean, numeric, string, server).await,
HostOverrideCommands::Unset { id, key } => unset_override(id, key, server).await
}
HostOverrideCommands::Set {
id,
key,
boolean,
string,
numeric,
} => set_override(id, key, boolean, numeric, string, server).await,
HostOverrideCommands::Unset { id, key } => unset_override(id, key, server).await,
},
}
}
#[derive(Deserialize)]
pub struct HostListResp {
pub data: Vec<Host>
pub data: Vec<Host>,
}
#[derive(Serialize, Deserialize)]
pub struct HostMetadata {
#[serde(rename = "lastSeenAt")]
@ -77,7 +111,11 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join("/v1/hosts?pageSize=5000")?).bearer_auth(token).send().await?;
let res = client
.get(server.join("/v1/hosts?pageSize=5000")?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let resp: HostListResp = res.json().await?;
@ -88,14 +126,33 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", "));
println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } );
println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available);
println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at);
println!();
}
@ -106,7 +163,10 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing hosts: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error listing hosts: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -114,7 +174,6 @@ pub async fn list_hosts(server: Url) -> Result<(), Box<dyn Error>> {
Ok(())
}
#[derive(Serialize, Deserialize)]
pub struct HostCreateBody {
pub name: String,
@ -134,11 +193,9 @@ pub struct HostCreateBody {
pub static_addresses: Vec<SocketAddrV4>,
}
#[derive(Serialize, Deserialize)]
pub struct HostGetMetadata {}
#[derive(Serialize, Deserialize)]
pub struct HostGetResponse {
pub data: Host,
@ -146,7 +203,17 @@ pub struct HostGetResponse {
}
#[allow(clippy::too_many_arguments)]
pub async fn create_host(name: String, network_id: String, role_id: String, ip_address: Ipv4Addr, listen_port: Option<u16>, lighthouse: bool, relay: bool, static_address: Option<SocketAddrV4>, server: Url) -> Result<(), Box<dyn Error>> {
pub async fn create_host(
name: String,
network_id: String,
role_id: String,
ip_address: Ipv4Addr,
listen_port: Option<u16>,
lighthouse: bool,
relay: bool,
static_address: Option<SocketAddrV4>,
server: Url,
) -> Result<(), Box<dyn Error>> {
if lighthouse && relay {
eprintln!("[error] Error creating host: a host cannot be both a lighthouse and a relay at the same time");
std::process::exit(1);
@ -172,7 +239,9 @@ pub async fn create_host(name: String, network_id: String, role_id: String, ip_a
let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join("/v1/hosts")?).json(&HostCreateBody {
let res = client
.post(server.join("/v1/hosts")?)
.json(&HostCreateBody {
name,
network_id,
role_id,
@ -181,7 +250,10 @@ pub async fn create_host(name: String, network_id: String, role_id: String, ip_a
is_lighthouse: lighthouse,
is_relay: relay,
static_addresses: static_address.map_or(vec![], |u| vec![u]),
}).bearer_auth(token).send().await?;
})
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let host: Host = res.json::<HostGetResponse>().await?.data;
@ -191,21 +263,42 @@ pub async fn create_host(name: String, network_id: String, role_id: String, ip_a
println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", "));
println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } );
println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available);
println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at);
println!();
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating host: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error creating host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -224,7 +317,11 @@ pub async fn get_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}", id))?).bearer_auth(token).send().await?;
let res = client
.get(server.join(&format!("/v1/hosts/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let host: Host = res.json::<HostGetResponse>().await?.data;
@ -234,21 +331,42 @@ pub async fn get_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", "));
println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } );
println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available);
println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at);
println!();
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing hosts: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error listing hosts: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -267,14 +385,21 @@ pub async fn delete_host(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token);
let res = client.delete(server.join(&format!("/v1/hosts/{}", id))?).bearer_auth(token).send().await?;
let res = client
.delete(server.join(&format!("/v1/hosts/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
println!("Host removed");
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error removing host: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error removing host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -290,10 +415,18 @@ pub struct HostUpdateBody {
pub static_addresses: Vec<SocketAddrV4>,
pub name: Option<String>,
pub ip: Option<Ipv4Addr>,
pub role: Option<String>
pub role: Option<String>,
}
pub async fn update_host(id: String, listen_port: Option<u16>, static_address: Option<SocketAddrV4>, name: Option<String>, ip: Option<Ipv4Addr>, role: Option<String>, server: Url) -> Result<(), Box<dyn Error>> {
pub async fn update_host(
id: String,
listen_port: Option<u16>,
static_address: Option<SocketAddrV4>,
name: Option<String>,
ip: Option<Ipv4Addr>,
role: Option<String>,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
// load session token
@ -304,13 +437,18 @@ pub async fn update_host(id: String, listen_port: Option<u16>, static_address: O
let token = format!("{} {}", session_token, auth_token);
let res = client.put(server.join(&format!("/v1/hosts/{}?extension=extended_hosts", id))?).json(&HostUpdateBody {
let res = client
.put(server.join(&format!("/v1/hosts/{}?extension=extended_hosts", id))?)
.json(&HostUpdateBody {
listen_port: listen_port.unwrap_or(0),
static_addresses: static_address.map_or_else(Vec::new, |u| vec![u]),
name,
ip,
role
}).bearer_auth(token).send().await?;
role,
})
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let host: Host = res.json::<HostGetResponse>().await?.data;
@ -320,21 +458,42 @@ pub async fn update_host(id: String, listen_port: Option<u16>, static_address: O
println!(" Network: {}", host.network_id);
println!(" Role: {}", host.role_id);
println!(" IP Address: {}", host.ip_address);
println!(" Static Addresses: {}", host.static_addresses.iter().map(|u| u.to_string()).collect::<Vec<_>>().join(", "));
println!(
" Static Addresses: {}",
host.static_addresses
.iter()
.map(|u| u.to_string())
.collect::<Vec<_>>()
.join(", ")
);
println!(" Listen Port: {}", host.listen_port);
println!(" Type: {}", if host.is_lighthouse { "Lighthouse" } else if host.is_relay { "Relay" } else { "Host" } );
println!(
" Type: {}",
if host.is_lighthouse {
"Lighthouse"
} else if host.is_relay {
"Relay"
} else {
"Host"
}
);
println!(" Blocked: {}", host.is_blocked);
println!(" Last Seen: {}", host.metadata.last_seen_at);
println!(" Client Version: {}", host.metadata.version);
println!(" Platform: {}", host.metadata.platform);
println!("Client Update Available: {}", host.metadata.update_available);
println!(
"Client Update Available: {}",
host.metadata.update_available
);
println!(" Created: {}", host.created_at);
println!();
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error updating host: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error updating host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -342,7 +501,6 @@ pub async fn update_host(id: String, listen_port: Option<u16>, static_address: O
Ok(())
}
#[derive(Serialize, Deserialize)]
pub struct EnrollmentCodeResponseMetadata {}
@ -376,18 +534,29 @@ pub async fn enroll_host(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join(&format!("/v1/hosts/{}/enrollment-code", id))?).header("content-length", 0).bearer_auth(token).send().await?;
let res = client
.post(server.join(&format!("/v1/hosts/{}/enrollment-code", id))?)
.header("content-length", 0)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let resp: EnrollmentResponse = res.json().await?;
println!("Enrollment code generated. Enroll the host with the following code: {}", resp.data.enrollment_code.code);
println!(
"Enrollment code generated. Enroll the host with the following code: {}",
resp.data.enrollment_code.code
);
println!("This code will be valid for {} seconds, at which point you will need to generate a new code", resp.data.enrollment_code.lifetime_seconds);
println!("If this host is blocked, a successful re-enrollment will unblock it.");
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error blocking host: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error blocking host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -406,14 +575,22 @@ pub async fn block_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join(&format!("/v1/hosts/{}/block", id))?).header("Content-Length", "0").bearer_auth(token).send().await?;
let res = client
.post(server.join(&format!("/v1/hosts/{}/block", id))?)
.header("Content-Length", "0")
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
println!("Host blocked. To unblock it, re-enroll the host.");
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error blocking host: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error blocking host: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -423,18 +600,18 @@ pub async fn block_host(id: String, server: Url) -> Result<(), Box<dyn Error>> {
#[derive(Serialize, Deserialize)]
pub struct HostConfigOverrideResponse {
pub data: HostConfigOverrideData
pub data: HostConfigOverrideData,
}
#[derive(Serialize, Deserialize)]
pub struct HostConfigOverrideData {
pub overrides: Vec<HostConfigOverrideDataOverride>
pub overrides: Vec<HostConfigOverrideDataOverride>,
}
#[derive(Serialize, Deserialize)]
pub struct HostConfigOverrideDataOverride {
pub key: String,
pub value: HostConfigOverrideDataOverrideValue
pub value: HostConfigOverrideDataOverrideValue,
}
#[derive(Serialize, Deserialize)]
@ -442,7 +619,7 @@ pub struct HostConfigOverrideDataOverride {
pub enum HostConfigOverrideDataOverrideValue {
Boolean(bool),
Numeric(i64),
Other(String)
Other(String),
}
pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -456,18 +633,25 @@ pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token).send().await?;
let res = client
.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?;
for c_override in &resp.data.overrides {
println!(" Key: {}", c_override.key);
println!("Value: {}", match &c_override.value {
println!(
"Value: {}",
match &c_override.value {
HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v),
HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v),
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v)
});
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v),
}
);
}
if resp.data.overrides.is_empty() {
@ -476,7 +660,10 @@ pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error looking up config overrides: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error looking up config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -486,14 +673,24 @@ pub async fn list_overrides(id: String, server: Url) -> Result<(), Box<dyn Error
#[derive(Serialize, Deserialize)]
pub struct SetOverrideRequest {
pub overrides: Vec<HostConfigOverrideDataOverride>
pub overrides: Vec<HostConfigOverrideDataOverride>,
}
pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeric: Option<i64>, other: Option<String>, server: Url) -> Result<(), Box<dyn Error>> {
pub async fn set_override(
id: String,
key: String,
boolean: Option<bool>,
numeric: Option<i64>,
other: Option<String>,
server: Url,
) -> Result<(), Box<dyn Error>> {
if boolean.is_none() && numeric.is_none() && other.is_none() {
eprintln!("[error] no value provided: you must provide at least --boolean, --numeric, or --string");
std::process::exit(1);
} else if boolean.is_some() && numeric.is_some() || boolean.is_some() && other.is_some() || numeric.is_some() && other.is_some() {
} else if boolean.is_some() && numeric.is_some()
|| boolean.is_some() && other.is_some()
|| numeric.is_some() && other.is_some()
{
eprintln!("[error] multiple values provided: you must provide only one of --boolean, --numeric, or --string");
std::process::exit(1);
}
@ -520,7 +717,11 @@ pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeri
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).send().await?;
let res = client
.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token.clone())
.send()
.await?;
if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?;
@ -533,25 +734,28 @@ pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeri
}
}
others.push(HostConfigOverrideDataOverride {
key,
value: val,
});
others.push(HostConfigOverrideDataOverride { key, value: val });
let res = client.put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).json(&SetOverrideRequest {
overrides: others,
}).send().await?;
let res = client
.put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token.clone())
.json(&SetOverrideRequest { overrides: others })
.send()
.await?;
if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?;
for c_override in &resp.data.overrides {
println!(" Key: {}", c_override.key);
println!("Value: {}", match &c_override.value {
println!(
"Value: {}",
match &c_override.value {
HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v),
HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v),
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v)
});
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v),
}
);
}
if resp.data.overrides.is_empty() {
@ -562,14 +766,20 @@ pub async fn set_override(id: String, key: String, boolean: Option<bool>, numeri
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error setting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error setting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error setting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error setting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -588,7 +798,11 @@ pub async fn unset_override(id: String, key: String, server: Url) -> Result<(),
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).send().await?;
let res = client
.get(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token.clone())
.send()
.await?;
if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?;
@ -601,20 +815,26 @@ pub async fn unset_override(id: String, key: String, server: Url) -> Result<(),
}
}
let res = client.put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?).bearer_auth(token.clone()).json(&SetOverrideRequest {
overrides: others,
}).send().await?;
let res = client
.put(server.join(&format!("/v1/hosts/{}/config-overrides", id))?)
.bearer_auth(token.clone())
.json(&SetOverrideRequest { overrides: others })
.send()
.await?;
if res.status().is_success() {
let resp: HostConfigOverrideResponse = res.json().await?;
for c_override in &resp.data.overrides {
println!(" Key: {}", c_override.key);
println!("Value: {}", match &c_override.value {
println!(
"Value: {}",
match &c_override.value {
HostConfigOverrideDataOverrideValue::Boolean(v) => format!("bool:{}", v),
HostConfigOverrideDataOverrideValue::Numeric(v) => format!("numeric:{}", v),
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v)
});
HostConfigOverrideDataOverrideValue::Other(v) => format!("string:{}", v),
}
);
}
if resp.data.overrides.is_empty() {
@ -625,14 +845,20 @@ pub async fn unset_override(id: String, key: String, server: Url) -> Result<(),
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error unsetting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error unsetting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error unsetting config overrides: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error unsetting config overrides: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}

View File

@ -1,21 +1,21 @@
use std::error::Error;
use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4};
use clap::{Parser, Subcommand};
use ipnet::Ipv4Net;
use url::Url;
use crate::account::account_main;
use crate::host::host_main;
use crate::network::network_main;
use crate::org::org_main;
use crate::role::role_main;
use clap::{Parser, Subcommand};
use ipnet::Ipv4Net;
use std::error::Error;
use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4};
use url::Url;
pub mod account;
pub mod api;
pub mod host;
pub mod network;
pub mod org;
pub mod role;
pub mod host;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@ -25,7 +25,7 @@ pub struct Args {
command: Commands,
#[clap(short, long, env = "TFCLI_SERVER")]
/// The base URL of your trifid-api-old instance. Defaults to the value in $XDG_CONFIG_HOME/tfcli-server-url.conf or the TFCLI_SERVER environment variable.
server: Option<Url>
server: Option<Url>,
}
#[derive(Subcommand, Debug)]
@ -33,28 +33,28 @@ pub enum Commands {
/// Manage your trifid account
Account {
#[command(subcommand)]
command: AccountCommands
command: AccountCommands,
},
/// Manage the networks associated with your trifid account
Network {
#[command(subcommand)]
command: NetworkCommands
command: NetworkCommands,
},
/// Manage the organization associated with your trifid account
Org {
#[command(subcommand)]
command: OrgCommands
command: OrgCommands,
},
/// Manage the roles associated with your trifid organization
Role {
#[command(subcommand)]
command: RoleCommands
command: RoleCommands,
},
/// Manage the hosts associated with your trifid network
Host {
#[command(subcommand)]
command: HostCommands
}
command: HostCommands,
},
}
#[derive(Subcommand, Debug)]
@ -62,17 +62,17 @@ pub enum AccountCommands {
/// Create a new trifid account on the designated server
Create {
#[clap(short, long)]
email: String
email: String,
},
/// Log into an existing account on the designated server
Login {
#[clap(short, long)]
email: String
email: String,
},
/// Log in to your account with a magic-link token acquired via email or the trifid-api-old logs.
MagicLink {
#[clap(short, long)]
magic_link_token: String
magic_link_token: String,
},
/// Create a new TOTP authenticator on this account to enable authorizing with 2fa and performing all management tasks.
MfaSetup {},
@ -81,13 +81,13 @@ pub enum AccountCommands {
#[clap(short, long)]
code: String,
#[clap(short, long)]
token: String
token: String,
},
/// Create a new short-lived authentication token by inputting the code shown on your authenticator app.
Mfa {
#[clap(short, long)]
code: String
}
code: String,
},
}
#[derive(Subcommand, Debug)]
@ -97,8 +97,8 @@ pub enum NetworkCommands {
/// Lookup a specific network by ID.
Lookup {
#[clap(short, long)]
id: String
}
id: String,
},
}
#[derive(Subcommand, Debug)]
@ -106,8 +106,8 @@ pub enum OrgCommands {
/// Create an organization on your trifid-api-old server. NOTE: This command ONLY works on trifid-api-old servers. It will NOT work on original DN servers.
Create {
#[clap(short, long)]
cidr: Ipv4Net
}
cidr: Ipv4Net,
},
}
#[derive(Subcommand, Debug)]
@ -120,19 +120,19 @@ pub enum RoleCommands {
description: String,
/// A JSON string containing the firewall rules to add to this host
#[clap(short, long)]
rules_json: String
rules_json: String,
},
/// List all roles attached to your organization
List {},
/// Lookup a specific role by it's ID
Lookup {
#[clap(short, long)]
id: String
id: String,
},
/// Delete a specific role by it's ID
Delete {
#[clap(short, long)]
id: String
id: String,
},
/// Update a specific role by it's ID. Warning: any data not provided in this update will be removed - include all data you wish to remain
Update {
@ -142,8 +142,8 @@ pub enum RoleCommands {
description: String,
/// A JSON string containing the firewall rules to add to this host
#[clap(short, long)]
rules_json: String
}
rules_json: String,
},
}
#[derive(Subcommand, Debug)]
@ -165,19 +165,19 @@ pub enum HostCommands {
#[clap(short = 'R', long)]
relay: bool,
#[clap(short, long)]
static_address: Option<SocketAddrV4>
static_address: Option<SocketAddrV4>,
},
/// List all hosts on your network
List {},
/// Lookup a specific host by it's ID
Lookup {
#[clap(short, long)]
id: String
id: String,
},
/// Delete a specific host by it's ID
Delete {
#[clap(short, long)]
id: String
id: String,
},
/// Update a specific host by it's ID, changing the listen port and static addresses, as well as the name, ip and role. The name, ip and role updates will only work on trifid-api-old compatible servers.
Update {
@ -192,23 +192,23 @@ pub enum HostCommands {
#[clap(short, long)]
role: Option<String>,
#[clap(short = 'I', long)]
ip: Option<Ipv4Addr>
ip: Option<Ipv4Addr>,
},
/// Blocks the specified host from the network
Block {
#[clap(short, long)]
id: String
id: String,
},
/// Enroll or re-enroll the host by generating an enrollment code
Enroll {
#[clap(short, long)]
id: String
id: String,
},
/// Manage config overrides set on the host
Overrides {
#[command(subcommand)]
command: HostOverrideCommands
}
command: HostOverrideCommands,
},
}
#[derive(Subcommand, Debug)]
@ -216,7 +216,7 @@ pub enum HostOverrideCommands {
/// List the config overrides set on the host
List {
#[clap(short, long)]
id: String
id: String,
},
/// Set a config override on the host
Set {
@ -229,15 +229,15 @@ pub enum HostOverrideCommands {
#[clap(short, long)]
numeric: Option<i64>,
#[clap(short, long)]
string: Option<String>
string: Option<String>,
},
/// Unset a config override on the host
Unset {
#[clap(short, long)]
id: String,
#[clap(short, long)]
key: String
}
key: String,
},
}
#[tokio::main]
@ -270,7 +270,10 @@ async fn main2() -> Result<(), Box<dyn Error>> {
let url = match Url::parse(&url_s) {
Ok(u) => u,
Err(e) => {
eprintln!("[error] unable to parse the URL in {}", srv_url_file.display());
eprintln!(
"[error] unable to parse the URL in {}",
srv_url_file.display()
);
eprintln!("[error] urlparse returned error '{}'", e);
eprintln!("[error] please correct the error and try again");
std::process::exit(1);
@ -284,6 +287,6 @@ async fn main2() -> Result<(), Box<dyn Error>> {
Commands::Network { command } => network_main(command, server).await,
Commands::Org { command } => org_main(command, server).await,
Commands::Role { command } => role_main(command, server).await,
Commands::Host { command } => host_main(command, server).await
Commands::Host { command } => host_main(command, server).await,
}
}

View File

@ -1,20 +1,20 @@
use std::error::Error;
use std::fs;
use serde::Deserialize;
use url::Url;
use crate::api::APIErrorResponse;
use crate::NetworkCommands;
use serde::Deserialize;
use std::error::Error;
use std::fs;
use url::Url;
pub async fn network_main(command: NetworkCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command {
NetworkCommands::List {} => list_networks(server).await,
NetworkCommands::Lookup {id} => get_network(id, server).await
NetworkCommands::Lookup { id } => get_network(id, server).await,
}
}
#[derive(Deserialize)]
pub struct NetworkListResp {
pub data: Vec<Network>
pub data: Vec<Network>,
}
#[derive(Deserialize)]
@ -29,7 +29,7 @@ pub struct Network {
pub created_at: String,
#[serde(rename = "lighthousesAsRelays")]
pub lighthouses_as_relays: bool,
pub name: String
pub name: String,
}
pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
@ -43,7 +43,11 @@ pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join("/v1/networks")?).bearer_auth(token).send().await?;
let res = client
.get(server.join("/v1/networks")?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let resp: NetworkListResp = res.json().await?;
@ -65,7 +69,10 @@ pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing networks: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error listing networks: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -75,7 +82,7 @@ pub async fn list_networks(server: Url) -> Result<(), Box<dyn Error>> {
#[derive(Deserialize)]
pub struct NetworkGetResponse {
pub data: Network
pub data: Network,
}
pub async fn get_network(id: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -89,7 +96,11 @@ pub async fn get_network(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/networks/{}", id))?).bearer_auth(token).send().await?;
let res = client
.get(server.join(&format!("/v1/networks/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let network: Network = res.json::<NetworkGetResponse>().await?.data;
@ -101,11 +112,13 @@ pub async fn get_network(id: String, server: Url) -> Result<(), Box<dyn Error>>
println!("Dedicated Relays: {}", !network.lighthouses_as_relays);
println!(" Name: {}", network.name);
println!(" Created At: {}", network.created_at);
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing networks: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error listing networks: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}

View File

@ -1,10 +1,10 @@
use std::error::Error;
use std::fs;
use crate::api::APIErrorResponse;
use crate::OrgCommands;
use ipnet::Ipv4Net;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use url::Url;
use crate::OrgCommands;
use crate::api::APIErrorResponse;
pub async fn org_main(command: OrgCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command {
@ -14,14 +14,14 @@ pub async fn org_main(command: OrgCommands, server: Url) -> Result<(), Box<dyn E
#[derive(Serialize)]
pub struct CreateOrgBody {
pub cidr: String
pub cidr: String,
}
#[derive(Deserialize)]
pub struct OrgCreateResponse {
pub organization: String,
pub ca: String,
pub network: String
pub network: String,
}
pub async fn create_org(cidr: Ipv4Net, server: Url) -> Result<(), Box<dyn Error>> {
@ -34,7 +34,14 @@ pub async fn create_org(cidr: Ipv4Net, server: Url) -> Result<(), Box<dyn Error>
let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join("/v1/organization")?).json(&CreateOrgBody { cidr: cidr.to_string() }).bearer_auth(token).send().await?;
let res = client
.post(server.join("/v1/organization")?)
.json(&CreateOrgBody {
cidr: cidr.to_string(),
})
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let resp: OrgCreateResponse = res.json().await?;
@ -45,7 +52,10 @@ pub async fn create_org(cidr: Ipv4Net, server: Url) -> Result<(), Box<dyn Error>
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating org: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error creating org: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}

View File

@ -1,23 +1,31 @@
use crate::api::APIErrorResponse;
use crate::RoleCommands;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::api::APIErrorResponse;
use crate::{RoleCommands};
pub async fn role_main(command: RoleCommands, server: Url) -> Result<(), Box<dyn Error>> {
match command {
RoleCommands::List {} => list_roles(server).await,
RoleCommands::Lookup {id} => get_role(id, server).await,
RoleCommands::Create { name, description, rules_json } => create_role(name, description, rules_json, server).await,
RoleCommands::Lookup { id } => get_role(id, server).await,
RoleCommands::Create {
name,
description,
rules_json,
} => create_role(name, description, rules_json, server).await,
RoleCommands::Delete { id } => delete_role(id, server).await,
RoleCommands::Update { id, description, rules_json } => update_role(id, description, rules_json, server).await
RoleCommands::Update {
id,
description,
rules_json,
} => update_role(id, description, rules_json, server).await,
}
}
#[derive(Deserialize)]
pub struct RoleListResp {
pub data: Vec<Role>
pub data: Vec<Role>,
}
#[derive(Deserialize)]
@ -30,7 +38,7 @@ pub struct Role {
#[serde(rename = "createdAt")]
pub created_at: String,
#[serde(rename = "modifiedAt")]
pub modified_at: String
pub modified_at: String,
}
#[derive(Deserialize, Serialize)]
pub struct RoleFirewallRule {
@ -39,12 +47,12 @@ pub struct RoleFirewallRule {
#[serde(rename = "allowedRoleID")]
pub allowed_role_id: Option<String>,
#[serde(rename = "portRange")]
pub port_range: Option<RoleFirewallRulePortRange>
pub port_range: Option<RoleFirewallRulePortRange>,
}
#[derive(Deserialize, Serialize)]
pub struct RoleFirewallRulePortRange {
pub from: u16,
pub to: u16
pub to: u16,
}
pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
@ -58,7 +66,11 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join("/v1/roles")?).bearer_auth(token).send().await?;
let res = client
.get(server.join("/v1/roles")?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let resp: RoleListResp = res.json().await?;
@ -68,9 +80,21 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
println!(" Description: {}", role.description);
for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string()));
println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() });
println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
}
println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at);
@ -82,7 +106,10 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing roles: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error listing roles: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -92,7 +119,7 @@ pub async fn list_roles(server: Url) -> Result<(), Box<dyn Error>> {
#[derive(Deserialize)]
pub struct RoleGetResponse {
pub data: Role
pub data: Role,
}
pub async fn get_role(id: String, server: Url) -> Result<(), Box<dyn Error>> {
@ -106,7 +133,11 @@ pub async fn get_role(id: String, server: Url) -> Result<(), Box<dyn Error>> {
let token = format!("{} {}", session_token, auth_token);
let res = client.get(server.join(&format!("/v1/roles/{}", id))?).bearer_auth(token).send().await?;
let res = client
.get(server.join(&format!("/v1/roles/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let role: Role = res.json::<RoleGetResponse>().await?.data;
@ -115,17 +146,31 @@ pub async fn get_role(id: String, server: Url) -> Result<(), Box<dyn Error>> {
println!(" Description: {}", role.description);
for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string()));
println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() });
println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
}
println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at);
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error listing roles: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error listing roles: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -138,10 +183,15 @@ pub struct RoleCreateBody {
pub name: String,
pub description: String,
#[serde(rename = "firewallRules")]
pub firewall_rules: Vec<RoleFirewallRule>
pub firewall_rules: Vec<RoleFirewallRule>,
}
pub async fn create_role(name: String, description: String, rules_json: String, server: Url) -> Result<(), Box<dyn Error>> {
pub async fn create_role(
name: String,
description: String,
rules_json: String,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
let rules: Vec<RoleFirewallRule> = match serde_json::from_str(&rules_json) {
@ -160,11 +210,16 @@ pub async fn create_role(name: String, description: String, rules_json: String,
let token = format!("{} {}", session_token, auth_token);
let res = client.post(server.join("/v1/roles")?).json(&RoleCreateBody {
let res = client
.post(server.join("/v1/roles")?)
.json(&RoleCreateBody {
name,
description,
firewall_rules: rules,
}).bearer_auth(token).send().await?;
})
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let role: Role = res.json::<RoleGetResponse>().await?.data;
@ -173,17 +228,31 @@ pub async fn create_role(name: String, description: String, rules_json: String,
println!(" Description: {}", role.description);
for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string()));
println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() });
println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
}
println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at);
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error creating role: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error creating role: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -195,10 +264,15 @@ pub async fn create_role(name: String, description: String, rules_json: String,
pub struct RoleUpdateBody {
pub description: String,
#[serde(rename = "firewallRules")]
pub firewall_rules: Vec<RoleFirewallRule>
pub firewall_rules: Vec<RoleFirewallRule>,
}
pub async fn update_role(id: String, description: String, rules_json: String, server: Url) -> Result<(), Box<dyn Error>> {
pub async fn update_role(
id: String,
description: String,
rules_json: String,
server: Url,
) -> Result<(), Box<dyn Error>> {
let client = reqwest::Client::new();
let rules: Vec<RoleFirewallRule> = match serde_json::from_str(&rules_json) {
@ -217,10 +291,15 @@ pub async fn update_role(id: String, description: String, rules_json: String, se
let token = format!("{} {}", session_token, auth_token);
let res = client.put(server.join(&format!("/v1/roles/{}", id))?).json(&RoleUpdateBody {
let res = client
.put(server.join(&format!("/v1/roles/{}", id))?)
.json(&RoleUpdateBody {
description,
firewall_rules: rules,
}).bearer_auth(token).send().await?;
})
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
let role: Role = res.json::<RoleGetResponse>().await?.data;
@ -229,17 +308,31 @@ pub async fn update_role(id: String, description: String, rules_json: String, se
println!(" Description: {}", role.description);
for rule in &role.firewall_rules {
println!("Rule Description: {}", rule.description);
println!(" Allowed Role: {}", rule.allowed_role_id.as_ref().unwrap_or(&"All roles".to_string()));
println!(
" Allowed Role: {}",
rule.allowed_role_id
.as_ref()
.unwrap_or(&"All roles".to_string())
);
println!(" Protocol: {}", rule.protocol);
println!(" Port Range: {}", if let Some(pr) = rule.port_range.as_ref() { format!("{}-{}", pr.from, pr.to) } else { "Any".to_string() });
println!(
" Port Range: {}",
if let Some(pr) = rule.port_range.as_ref() {
format!("{}-{}", pr.from, pr.to)
} else {
"Any".to_string()
}
);
}
println!(" Created: {}", role.created_at);
println!(" Updated: {}", role.modified_at);
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error updating role: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error updating role: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}
@ -258,14 +351,21 @@ pub async fn delete_role(id: String, server: Url) -> Result<(), Box<dyn Error>>
let token = format!("{} {}", session_token, auth_token);
let res = client.delete(server.join(&format!("/v1/roles/{}", id))?).bearer_auth(token).send().await?;
let res = client
.delete(server.join(&format!("/v1/roles/{}", id))?)
.bearer_auth(token)
.send()
.await?;
if res.status().is_success() {
println!("Role removed");
} else {
let resp: APIErrorResponse = res.json().await?;
eprintln!("[error] Error removing role: {} {}", resp.errors[0].code, resp.errors[0].message);
eprintln!(
"[error] Error removing role: {} {}",
resp.errors[0].code, resp.errors[0].message
);
std::process::exit(1);
}

View File

@ -112,10 +112,7 @@ pub fn apiworker_main(
cdata.creds = Some(creds);
cdata.dh_privkey = Some(dh_privkey);
match fs::write(
nebula_yml(&instance),
config,
) {
match fs::write(nebula_yml(&instance), config) {
Ok(_) => (),
Err(e) => {
error!("unable to save nebula config: {}", e);
@ -175,10 +172,7 @@ pub fn apiworker_main(
}
};
match fs::write(
nebula_yml(&instance),
config,
) {
match fs::write(nebula_yml(&instance), config) {
Ok(_) => (),
Err(e) => {
error!("unable to save nebula config: {}", e);

View File

@ -4,11 +4,11 @@ use std::error::Error;
use std::fs;
use std::net::{Ipv4Addr, SocketAddrV4};
use crate::dirs::{config_dir, tfclient_toml, tfdata_toml};
use dnapi_rs::client_blocking::EnrollMeta;
use dnapi_rs::credentials::Credentials;
use log::{debug, info};
use serde::{Deserialize, Serialize};
use crate::dirs::{config_dir, tfclient_toml, tfdata_toml};
pub const DEFAULT_PORT: u16 = 8157;
fn default_port() -> u16 {
@ -39,10 +39,7 @@ pub fn create_config(instance: &str) -> Result<(), Box<dyn Error>> {
disable_automatic_config_updates: false,
};
let config_str = toml::to_string(&config)?;
fs::write(
tfclient_toml(instance),
config_str,
)?;
fs::write(tfclient_toml(instance), config_str)?;
Ok(())
}
@ -72,10 +69,7 @@ pub fn create_cdata(instance: &str) -> Result<(), Box<dyn Error>> {
meta: None,
};
let config_str = toml::to_string(&config)?;
fs::write(
tfdata_toml(instance),
config_str,
)?;
fs::write(tfdata_toml(instance), config_str)?;
Ok(())
}

View File

@ -6,7 +6,7 @@ use std::thread;
use crate::apiworker::{apiworker_main, APIWorkerMessage};
use crate::config::load_config;
use crate::nebulaworker::{NebulaWorkerMessage, nebulaworker_main};
use crate::nebulaworker::{nebulaworker_main, NebulaWorkerMessage};
use crate::socketworker::{socketworker_main, SocketWorkerMessage};
use crate::timerworker::{timer_main, TimerWorkerMessage};
use crate::util::{check_server_url, shutdown};
@ -83,7 +83,11 @@ pub fn daemon_main(name: String, server: String) {
let timer_transmitter = transmitter;
let timer_thread = thread::spawn(move || {
timer_main(timer_transmitter, rx_timer, config.disable_automatic_config_updates);
timer_main(
timer_transmitter,
rx_timer,
config.disable_automatic_config_updates,
);
});
info!("Waiting for timer thread to exit...");
match timer_thread.join() {
@ -95,7 +99,6 @@ pub fn daemon_main(name: String, server: String) {
}
info!("Timer thread exited");
info!("Waiting for socket thread to exit...");
match socket_thread.join() {
Ok(_) => (),

View File

@ -31,13 +31,19 @@ pub fn config_dir(instance: &str) -> PathBuf {
}
pub fn tfclient_toml(instance: &str) -> PathBuf {
config_base().join(format!("{}/", instance)).join("tfclient.toml")
config_base()
.join(format!("{}/", instance))
.join("tfclient.toml")
}
pub fn tfdata_toml(instance: &str) -> PathBuf {
config_base().join(format!("{}/", instance)).join("tfdata.toml")
config_base()
.join(format!("{}/", instance))
.join("tfdata.toml")
}
pub fn nebula_yml(instance: &str) -> PathBuf {
config_base().join(format!("{}/", instance)).join("nebula.yml")
config_base()
.join(format!("{}/", instance))
.join("nebula.yml")
}

View File

@ -44,7 +44,6 @@ struct Cli {
#[derive(Subcommand)]
enum Commands {
/// Run the tfclient daemon in the foreground
Run {
#[clap(short, long, default_value = "tfclient")]

View File

@ -2,13 +2,13 @@
use crate::config::{load_cdata, NebulaConfig, TFClientConfig};
use crate::daemon::ThreadMessageSender;
use crate::dirs::{nebula_yml};
use crate::dirs::nebula_yml;
use crate::util::shutdown;
use log::{debug, error, info};
use nebula_ffi::NebulaInstance;
use std::error::Error;
use std::fs;
use std::sync::mpsc::Receiver;
use nebula_ffi::NebulaInstance;
use crate::util::shutdown;
pub enum NebulaWorkerMessage {
Shutdown,
@ -23,9 +23,7 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
let cdata = load_cdata(instance)?;
let key = cdata.dh_privkey.ok_or("Missing private key")?;
let config_str = fs::read_to_string(
nebula_yml(instance),
)?;
let config_str = fs::read_to_string(nebula_yml(instance))?;
let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?;
config.pki.key = Some(String::from_utf8(key)?);
@ -33,10 +31,7 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
debug!("inserted private key into config: {:?}", config);
let config_str = serde_yaml::to_string(&config)?;
fs::write(
nebula_yml(instance),
config_str,
)?;
fs::write(nebula_yml(instance), config_str)?;
Ok(())
}
@ -74,7 +69,8 @@ pub fn nebulaworker_main(
error!("not enrolled, cannot start nebula");
} else {
info!("setting up nebula...");
nebula = Some(match NebulaInstance::new(nebula_yml(&instance).as_path(), false) {
nebula = Some(
match NebulaInstance::new(nebula_yml(&instance).as_path(), false) {
Ok(i) => {
info!("nebula setup");
info!("starting nebula...");
@ -89,15 +85,15 @@ pub fn nebulaworker_main(
}
i
},
}
Err(e) => {
error!("error setting up Nebula: {}", e);
error!("nebula thread exiting with error");
shutdown(&transmitter);
return;
}
});
},
);
info!("nebula process started");
}
@ -157,7 +153,8 @@ pub fn nebulaworker_main(
} else {
debug!("detected enrollment, starting nebula for the first time");
info!("setting up nebula...");
nebula = Some(match NebulaInstance::new(nebula_yml(&instance).as_path(), false) {
nebula = Some(
match NebulaInstance::new(nebula_yml(&instance).as_path(), false) {
Ok(i) => {
info!("nebula setup");
info!("starting nebula...");
@ -172,15 +169,15 @@ pub fn nebulaworker_main(
}
i
},
}
Err(e) => {
error!("error setting up Nebula: {}", e);
error!("nebula thread exiting with error");
shutdown(&transmitter);
return;
}
});
},
);
info!("nebula process started");
}

View File

@ -4,7 +4,7 @@
use crate::config::{load_cdata, NebulaConfig, TFClientConfig};
use crate::daemon::ThreadMessageSender;
use crate::dirs::{nebula_yml};
use crate::dirs::nebula_yml;
use log::{debug, error, info};
use std::error::Error;
use std::fs;
@ -23,9 +23,7 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
let cdata = load_cdata(instance)?;
let key = cdata.dh_privkey.ok_or("Missing private key")?;
let config_str = fs::read_to_string(
nebula_yml(instance),
)?;
let config_str = fs::read_to_string(nebula_yml(instance))?;
let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?;
config.pki.key = Some(String::from_utf8(key)?);
@ -33,25 +31,26 @@ fn insert_private_key(instance: &str) -> Result<(), Box<dyn Error>> {
debug!("inserted private key into config: {:?}", config);
let config_str = serde_yaml::to_string(&config)?;
fs::write(
nebula_yml(instance),
config_str,
)?;
fs::write(nebula_yml(instance), config_str)?;
Ok(())
}
pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter: ThreadMessageSender, rx: Receiver<NebulaWorkerMessage>) {
pub fn nebulaworker_main(
_config: TFClientConfig,
instance: String,
_transmitter: ThreadMessageSender,
rx: Receiver<NebulaWorkerMessage>,
) {
loop {
match rx.recv() {
Ok(msg) => match msg {
NebulaWorkerMessage::WakeUp => {
continue;
},
}
NebulaWorkerMessage::Shutdown => {
break;
},
}
NebulaWorkerMessage::ConfigUpdated => {
info!("our configuration has been updated - reloading");

View File

@ -12,7 +12,11 @@ pub enum TimerWorkerMessage {
Shutdown,
}
pub fn timer_main(tx: ThreadMessageSender, rx: Receiver<TimerWorkerMessage>, disable_config_updates: bool) {
pub fn timer_main(
tx: ThreadMessageSender,
rx: Receiver<TimerWorkerMessage>,
disable_config_updates: bool,
) {
let mut api_reload_timer = SystemTime::now().add(Duration::from_secs(60));
loop {

View File

@ -1,12 +1,12 @@
use log::{error, warn};
use sha2::Digest;
use sha2::Sha256;
use url::Url;
use crate::apiworker::APIWorkerMessage;
use crate::daemon::ThreadMessageSender;
use crate::nebulaworker::NebulaWorkerMessage;
use crate::socketworker::SocketWorkerMessage;
use crate::timerworker::TimerWorkerMessage;
use log::{error, warn};
use sha2::Digest;
use sha2::Sha256;
use url::Url;
pub fn sha256(bytes: &[u8]) -> String {
let mut hasher = Sha256::new();
@ -51,10 +51,7 @@ pub fn shutdown(transmitter: &ThreadMessageSender) {
);
}
}
match transmitter
.api_thread
.send(APIWorkerMessage::Shutdown)
{
match transmitter.api_thread.send(APIWorkerMessage::Shutdown) {
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to api worker thread: {}", e);
@ -72,10 +69,7 @@ pub fn shutdown(transmitter: &ThreadMessageSender) {
);
}
}
match transmitter
.timer_thread
.send(TimerWorkerMessage::Shutdown)
{
match transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) {
Ok(_) => (),
Err(e) => {
error!(

View File

@ -15,10 +15,13 @@ use crate::crypto::{decrypt_with_nonce, get_cipher_from_config};
use crate::AppState;
use ed25519_dalek::SigningKey;
use ipnet::Ipv4Net;
use log::{error};
use log::error;
use sea_orm::{ColumnTrait, Condition, EntityTrait, QueryFilter};
use serde_yaml::{Mapping, Value};
use trifid_api_entities::entity::{firewall_rule, host, host_config_override, host_static_address, keystore_entry, network, organization, signing_ca};
use trifid_api_entities::entity::{
firewall_rule, host, host_config_override, host_static_address, keystore_entry, network,
organization, signing_ca,
};
use trifid_pki::cert::{
deserialize_ed25519_private, deserialize_nebula_certificate_from_pem, NebulaCertificate,
NebulaCertificateDetails,
@ -36,7 +39,7 @@ pub struct CodegenRequiredInfo {
pub lighthouse_ips: Vec<Ipv4Addr>,
pub blocked_hosts: Vec<String>,
pub firewall_rules: Vec<NebulaConfigFirewallRule>,
pub config_overrides: Vec<(String, String)>
pub config_overrides: Vec<(String, String)>,
}
pub async fn generate_config(
@ -90,14 +93,20 @@ pub async fn generate_config(
let mut blocked_hosts_fingerprints = vec![];
for host in &info.blocked_hosts {
// check if the host exists
if host::Entity::find().filter(host::Column::Id.eq(host)).one(&db.conn).await?.is_some() {
if host::Entity::find()
.filter(host::Column::Id.eq(host))
.one(&db.conn)
.await?
.is_some()
{
// pull all of their certs ever and block them
let host_entries = keystore_entry::Entity::find().filter(keystore_entry::Column::Host.eq(host)).all(&db.conn).await?;
let host_entries = keystore_entry::Entity::find()
.filter(keystore_entry::Column::Host.eq(host))
.all(&db.conn)
.await?;
for entry in &host_entries {
// decode the cert
let cert = deserialize_nebula_certificate_from_pem(&entry.certificate)?;
@ -209,11 +218,18 @@ pub async fn generate_config(
let mut current_val = &mut value;
for key_iter in &key_split[..key_split.len()-1] {
current_val = current_val.as_mapping_mut().unwrap().entry(Value::String(key_iter.to_string())).or_insert(Value::Mapping(Mapping::new()));
for key_iter in &key_split[..key_split.len() - 1] {
current_val = current_val
.as_mapping_mut()
.unwrap()
.entry(Value::String(key_iter.to_string()))
.or_insert(Value::Mapping(Mapping::new()));
}
current_val.as_mapping_mut().unwrap().insert(Value::String(key_split[key_split.len()-1].to_string()), serde_yaml::from_str(kv_value)?);
current_val.as_mapping_mut().unwrap().insert(
Value::String(key_split[key_split.len() - 1].to_string()),
serde_yaml::from_str(kv_value)?,
);
}
let config_str_merged = serde_yaml::to_string(&value)?;
@ -338,10 +354,10 @@ pub async fn collect_info<'a>(
let best_ca = best_ca.unwrap();
// pull our host's config overrides
let config_overrides = host_config_overrides.iter().map(|u| {
(u.key.clone(), u.value.clone())
}).collect();
let config_overrides = host_config_overrides
.iter()
.map(|u| (u.key.clone(), u.value.clone()))
.collect();
// pull our role's firewall rules
let firewall_rules = trifid_api_entities::entity::firewall_rule::Entity::find()
@ -386,6 +402,6 @@ pub async fn collect_info<'a>(
lighthouse_ips: lighthouses,
blocked_hosts,
firewall_rules,
config_overrides
config_overrides,
})
}

View File

@ -77,7 +77,7 @@ pub struct TrifidConfigServer {
#[serde(default = "socketaddr_8080")]
pub bind: SocketAddr,
#[serde(default = "default_workers")]
pub workers: usize
pub workers: usize,
}
#[derive(Serialize, Deserialize, Debug)]
@ -733,4 +733,6 @@ fn is_none<T>(o: &Option<T>) -> bool {
o.is_none()
}
fn default_workers() -> usize { 32 }
fn default_workers() -> usize {
32
}

View File

@ -14,6 +14,10 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use crate::config::CONFIG;
use crate::error::{APIError, APIErrorsResponse};
use crate::tokens::random_id_no_id;
use actix_cors::Cors;
use actix_request_identifier::RequestIdentifier;
use actix_web::{
web::{Data, JsonConfig},
@ -23,10 +27,6 @@ use log::{info, Level};
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use std::error::Error;
use std::time::Duration;
use actix_cors::Cors;
use crate::config::CONFIG;
use crate::error::{APIError, APIErrorsResponse};
use crate::tokens::random_id_no_id;
use trifid_api_migration::{Migrator, MigratorTrait};
pub mod auth_tokens;
@ -37,10 +37,10 @@ pub mod cursor;
pub mod error;
//pub mod legacy_keystore; // TODO- Remove
pub mod magic_link;
pub mod response;
pub mod routes;
pub mod timers;
pub mod tokens;
pub mod response;
pub struct AppState {
pub conn: DatabaseConnection,

View File

@ -1,9 +1,9 @@
use std::fmt::{Display, Formatter};
use actix_web::{HttpRequest, HttpResponse, Responder, ResponseError};
use actix_web::body::EitherBody;
use actix_web::web::Json;
use actix_web::{HttpRequest, HttpResponse, Responder, ResponseError};
use log::error;
use sea_orm::DbErr;
use std::fmt::{Display, Formatter};
use crate::error::{APIError, APIErrorsResponse};
@ -30,13 +30,15 @@ impl Responder for ErrResponse {
impl From<DbErr> for ErrResponse {
fn from(value: DbErr) -> Self {
error!("database error: {}", value);
Self(APIErrorsResponse { errors: vec![
APIError {
Self(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error performing the database query. Please try again later.".to_string(),
message:
"There was an error performing the database query. Please try again later."
.to_string(),
path: None,
}
] })
}],
})
}
}

View File

@ -10,18 +10,20 @@ use dnapi_rs::message::{
};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use log::{error, warn};
use sea_orm::{ActiveModelTrait, EntityTrait};
use std::clone::Clone;
use std::time::{SystemTime, UNIX_EPOCH};
use sea_orm::{ActiveModelTrait, EntityTrait};
use trifid_pki::cert::{deserialize_ed25519_public, deserialize_nebula_certificate_from_pem, deserialize_x25519_public};
use trifid_pki::cert::{
deserialize_ed25519_public, deserialize_nebula_certificate_from_pem, deserialize_x25519_public,
};
use trifid_api_entities::entity::{host, keystore_entry, keystore_host};
use crate::error::APIErrorsResponse;
use sea_orm::{ColumnTrait, QueryFilter, IntoActiveModel};
use sea_orm::ActiveValue::Set;
use crate::AppState;
use crate::config::NebulaConfig;
use crate::error::APIErrorsResponse;
use crate::tokens::random_id;
use crate::AppState;
use sea_orm::ActiveValue::Set;
use sea_orm::{ColumnTrait, IntoActiveModel, QueryFilter};
use trifid_api_entities::entity::{host, keystore_entry, keystore_host};
#[post("/v1/dnclient")]
pub async fn dnclient(
@ -109,7 +111,8 @@ pub async fn dnclient(
log::debug!("{:x?}", keystore_data.client_signing_key);
let key = VerifyingKey::from_bytes(&keystore_data.client_signing_key.try_into().unwrap()).unwrap();
let key =
VerifyingKey::from_bytes(&keystore_data.client_signing_key.try_into().unwrap()).unwrap();
if key.verify(req.message.as_bytes(), &signature).is_err() {
// Be intentionally vague as the message is invalid.
@ -176,9 +179,7 @@ pub async fn dnclient(
return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![crate::error::APIError {
code: "ERR_NOT_FOUND".to_string(),
message:
"resource not found"
.to_string(),
message: "resource not found".to_string(),
path: None,
}],
});
@ -241,11 +242,15 @@ pub async fn dnclient(
c1.pki.cert = c0.pki.cert.clone();
if c0 == c1 {
// its just the cert. deserialize both and check if any details have changed
let cert0 = deserialize_nebula_certificate_from_pem(c0.pki.cert.as_bytes()).expect("generated an invalid certificate");
let mut cert1 = deserialize_nebula_certificate_from_pem(c1.pki.cert.as_bytes()).expect("generated an invalid certificate");
let cert0 = deserialize_nebula_certificate_from_pem(c0.pki.cert.as_bytes())
.expect("generated an invalid certificate");
let mut cert1 = deserialize_nebula_certificate_from_pem(c1.pki.cert.as_bytes())
.expect("generated an invalid certificate");
cert1.details.not_before = cert0.details.not_before;
cert1.details.not_after = cert0.details.not_after;
if cert0.serialize_to_pem().expect("generated invalid cert") == cert1.serialize_to_pem().expect("generated invalid cert") {
if cert0.serialize_to_pem().expect("generated invalid cert")
== cert1.serialize_to_pem().expect("generated invalid cert")
{
// fake news! its fine actually
config_is_different = false;
}
@ -256,7 +261,10 @@ pub async fn dnclient(
let config_update_avail = config_is_different || req.counter < keystore_header.counter as u32;
host_am.last_out_of_date = Set(config_update_avail);
host_am.last_seen_at = Set(SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64);
host_am.last_seen_at = Set(SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64);
match host_am.update(&db.conn).await {
Ok(_) => (),
@ -391,7 +399,7 @@ pub async fn dnclient(
client_dh_key: dh_pubkey,
client_signing_key: ed_pubkey,
config: cfg_str.clone(),
signing_key: keystore_data.signing_key.clone()
signing_key: keystore_data.signing_key.clone(),
};
match ks_entry_model.into_active_model().insert(&db.conn).await {
@ -425,7 +433,8 @@ pub async fn dnclient(
}
}
let signing_key = SigningKey::from_bytes(&keystore_data.signing_key.try_into().unwrap());
let signing_key =
SigningKey::from_bytes(&keystore_data.signing_key.try_into().unwrap());
// get the signing key that the client last trusted based on its current config version
let msg = DoUpdateResponse {

View File

@ -76,7 +76,9 @@ use serde::{Deserialize, Serialize};
use std::net::{Ipv4Addr, SocketAddrV4};
use std::str::FromStr;
use std::time::{SystemTime, UNIX_EPOCH};
use trifid_api_entities::entity::{host, host_config_override, host_static_address, network, organization, role};
use trifid_api_entities::entity::{
host, host_config_override, host_static_address, network, organization, role,
};
#[derive(Serialize, Deserialize)]
pub struct ListHostsRequestOpts {
@ -577,13 +579,11 @@ pub async fn create_hosts_request(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()),
}
],
}],
});
}
};
@ -592,25 +592,21 @@ pub async fn create_hosts_request(
net_id = net.id;
} else {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()),
}
],
}],
});
}
if net_id != req.network_id {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()),
}
],
}],
});
}
@ -640,7 +636,7 @@ pub async fn create_hosts_request(
code: "ERR_INVALID_VALUE".to_string(),
message: "lighthouse hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()),
}]
}],
});
} else if req.listen_port == 0 && req.is_relay {
return HttpResponse::BadRequest().json(APIErrorsResponse {
@ -648,7 +644,7 @@ pub async fn create_hosts_request(
code: "ERR_INVALID_VALUE".to_string(),
message: "relay hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()),
}]
}],
});
}
@ -662,26 +658,24 @@ pub async fn create_hosts_request(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(),
message:
"There was an error validating the request. Please try again later."
.to_string(),
path: Some("role".to_string()),
}
],
}],
});
}
};
if roles.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("role".to_string()),
}
],
}],
});
}
}
@ -695,26 +689,23 @@ pub async fn create_hosts_request(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(),
message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("name".to_string()),
}
],
}],
});
}
};
if !matching_hostname.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(),
path: Some("name".to_string()),
}
],
}],
});
}
@ -727,26 +718,23 @@ pub async fn create_hosts_request(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(),
message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("ipAddress".to_string()),
}
],
}],
});
}
};
if !matching_ip.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(),
path: Some("ipAddress".to_string()),
}
],
}],
});
}
@ -1006,9 +994,7 @@ pub async fn get_host(id: Path<String>, req_info: HttpRequest, db: Data<AppState
return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(),
message:
"resource not found"
.to_string(),
message: "resource not found".to_string(),
path: None,
}],
});
@ -1243,9 +1229,7 @@ pub async fn delete_host(
return HttpResponse::Unauthorized().json(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(),
message:
"resource not found"
.to_string(),
message: "resource not found".to_string(),
path: None,
}],
});
@ -1500,9 +1484,7 @@ pub async fn edit_host(
return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(),
message:
"resource not found"
.to_string(),
message: "resource not found".to_string(),
path: None,
}],
});
@ -1827,9 +1809,7 @@ pub async fn block_host(
return HttpResponse::NotFound().json(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(),
message:
"resource not found"
.to_string(),
message: "resource not found".to_string(),
path: None,
}],
});
@ -2094,9 +2074,7 @@ pub async fn enroll_host(
return HttpResponse::Unauthorized().json(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_NOT_FOUND".to_string(),
message:
"resource not found"
.to_string(),
message: "resource not found".to_string(),
path: None,
}],
});
@ -2263,13 +2241,11 @@ pub async fn create_host_and_enrollment_code(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()),
}
],
}],
});
}
};
@ -2288,13 +2264,11 @@ pub async fn create_host_and_enrollment_code(
if net_id != req.network_id {
return HttpResponse::Unauthorized().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("networkID".to_string()),
}
],
}],
});
}
@ -2324,7 +2298,7 @@ pub async fn create_host_and_enrollment_code(
code: "ERR_INVALID_VALUE".to_string(),
message: "lighthouse hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()),
}]
}],
});
} else if req.listen_port == 0 && req.is_relay {
return HttpResponse::BadRequest().json(APIErrorsResponse {
@ -2332,7 +2306,7 @@ pub async fn create_host_and_enrollment_code(
code: "ERR_INVALID_VALUE".to_string(),
message: "relay hosts must specify a static listen port".to_string(),
path: Some("listenPort".to_string()),
}]
}],
});
}
@ -2346,26 +2320,24 @@ pub async fn create_host_and_enrollment_code(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(),
message:
"There was an error validating the request. Please try again later."
.to_string(),
path: Some("role".to_string()),
}
],
}],
});
}
};
if roles.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_REFERENCE".to_string(),
message: "referenced value is invalid (perhaps it does not exist?)".to_string(),
path: Some("role".to_string()),
}
],
}],
});
}
}
@ -2379,26 +2351,23 @@ pub async fn create_host_and_enrollment_code(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(),
message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("name".to_string()),
}
],
}],
});
}
};
if !matching_hostname.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(),
path: Some("name".to_string()),
}
],
}],
});
}
@ -2411,26 +2380,23 @@ pub async fn create_host_and_enrollment_code(
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error validating the request. Please try again later.".to_string(),
message: "There was an error validating the request. Please try again later."
.to_string(),
path: Some("ipAddress".to_string()),
}
],
}],
});
}
};
if !matching_ip.is_empty() {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(),
path: Some("ipAddress".to_string()),
}
],
}],
});
}
@ -2588,7 +2554,11 @@ pub enum HostConfigOverrideDataOverrideValue {
}
#[get("/v1/hosts/{host_id}/config-overrides")]
pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
pub async fn get_host_overrides(
id: Path<String>,
req_info: HttpRequest,
db: Data<AppState>,
) -> HttpResponse {
// For this endpoint, you either need to be a fully authenticated user OR a token with hosts:read
let session_info = enforce_2fa(&req_info, &db.conn)
.await
@ -2752,7 +2722,11 @@ pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Dat
});
}
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find().filter(host_config_override::Column::Host.eq(host.id)).all(&db.conn).await {
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Host.eq(host.id))
.all(&db.conn)
.await
{
Ok(h) => h,
Err(e) => {
error!("Database error: {}", e);
@ -2767,7 +2741,9 @@ pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Dat
}
};
let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides.iter().map(|u| {
let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides
.iter()
.map(|u| {
let val;
if u.value == "true" || u.value == "false" {
val = HostConfigOverrideDataOverrideValue::Boolean(u.value == "true");
@ -2780,12 +2756,11 @@ pub async fn get_host_overrides(id: Path<String>, req_info: HttpRequest, db: Dat
key: u.key.clone(),
value: val,
}
}).collect();
})
.collect();
HttpResponse::Ok().json(HostConfigOverrideResponse {
data: HostConfigOverrideData {
overrides,
},
data: HostConfigOverrideData { overrides },
})
}
@ -2795,7 +2770,12 @@ pub struct UpdateOverridesRequest {
}
#[put("/v1/hosts/{host_id}/config-overrides")]
pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRequest>, req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
pub async fn update_host_overrides(
id: Path<String>,
req: Json<UpdateOverridesRequest>,
req_info: HttpRequest,
db: Data<AppState>,
) -> HttpResponse {
// For this endpoint, you either need to be a fully authenticated user OR a token with hosts:read
let session_info = enforce_2fa(&req_info, &db.conn)
.await
@ -2959,7 +2939,11 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
});
}
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find().filter(host_config_override::Column::Host.eq(&host.id)).all(&db.conn).await {
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Host.eq(&host.id))
.all(&db.conn)
.await
{
Ok(h) => h,
Err(e) => {
error!("Database error: {}", e);
@ -2982,7 +2966,8 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error with the database query. Please try again later."
message:
"There was an error with the database query. Please try again later."
.to_string(),
path: None,
}],
@ -3009,7 +2994,8 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message: "There was an error with the database query. Please try again later."
message:
"There was an error with the database query. Please try again later."
.to_string(),
path: None,
}],
@ -3018,7 +3004,11 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
}
}
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find().filter(host_config_override::Column::Host.eq(&host.id)).all(&db.conn).await {
let config_overrides = match trifid_api_entities::entity::host_config_override::Entity::find()
.filter(host_config_override::Column::Host.eq(&host.id))
.all(&db.conn)
.await
{
Ok(h) => h,
Err(e) => {
error!("Database error: {}", e);
@ -3033,11 +3023,18 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
}
};
let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides.iter().map(|u| {
let overrides: Vec<HostConfigOverrideDataOverride> = config_overrides
.iter()
.map(|u| {
let val;
if u.value == "true" || u.value == "false" {
val = HostConfigOverrideDataOverrideValue::Boolean(u.value == "true");
} else if u.value.chars().all(|c| c.is_numeric()) || u.value.starts_with('-') && u.value.chars().collect::<Vec<_>>()[1..].iter().all(|c| c.is_numeric()) {
} else if u.value.chars().all(|c| c.is_numeric())
|| u.value.starts_with('-')
&& u.value.chars().collect::<Vec<_>>()[1..]
.iter()
.all(|c| c.is_numeric())
{
val = HostConfigOverrideDataOverrideValue::Numeric(u.value.parse().unwrap());
} else {
val = HostConfigOverrideDataOverrideValue::Other(u.value.clone());
@ -3046,11 +3043,10 @@ pub async fn update_host_overrides(id: Path<String>, req: Json<UpdateOverridesRe
key: u.key.clone(),
value: val,
}
}).collect();
})
.collect();
HttpResponse::Ok().json(HostConfigOverrideResponse {
data: HostConfigOverrideData {
overrides,
},
data: HostConfigOverrideData { overrides },
})
}

View File

@ -238,27 +238,23 @@ pub async fn create_role_request(
if role.is_some() {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_DUPLICATE_VALUE".to_string(),
message: "value already exists".to_string(),
path: Some("name".to_string())
}
]
})
path: Some("name".to_string()),
}],
});
}
for (id, rule) in req.firewall_rules.iter().enumerate() {
if let Some(pr) = &rule.port_range {
if pr.from < pr.to {
return HttpResponse::BadRequest().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_INVALID_VALUE".to_string(),
message: "from must be less than or equal to to".to_string(),
path: Some(format!("firewallRules[{}].portRange", id))
}
]
path: Some(format!("firewallRules[{}].portRange", id)),
}],
});
}
}

View File

@ -4,18 +4,20 @@ use actix_web::{post, HttpRequest, HttpResponse, Responder};
use dnapi_rs::message::{
APIError, EnrollRequest, EnrollResponse, EnrollResponseData, EnrollResponseDataOrg,
};
use ed25519_dalek::{SigningKey};
use ed25519_dalek::SigningKey;
use log::{debug, error};
use rand::rngs::OsRng;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter};
use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, QueryFilter,
};
use crate::codegen::{collect_info, generate_config};
use crate::response::ErrResponse;
use crate::AppState;
use trifid_api_entities::entity::{host_enrollment_code, keystore_entry, keystore_host};
use trifid_pki::cert::{
deserialize_ed25519_public, deserialize_x25519_public, serialize_ed25519_public,
};
use crate::response::ErrResponse;
use crate::timers::expired;
use crate::tokens::random_id;
@ -111,7 +113,8 @@ pub async fn enroll(
Ok(_) => (),
Err(e) => {
error!("database error: {}", e);
return Ok(HttpResponse::InternalServerError().json(EnrollResponse::Error {
return Ok(
HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_DB_ERROR".to_string(),
message:
@ -119,20 +122,23 @@ pub async fn enroll(
.to_string(),
path: None,
}],
}));
}),
);
}
}
let info = match collect_info(&db, &enroll_info.host, &dh_pubkey).await {
Ok(i) => i,
Err(e) => {
return Ok(HttpResponse::InternalServerError().json(EnrollResponse::Error {
return Ok(
HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: e.to_string(),
path: None,
}],
}));
}),
);
}
};
@ -141,25 +147,34 @@ pub async fn enroll(
Ok(cfg) => cfg,
Err(e) => {
error!("error generating configuration: {}", e);
return Ok(HttpResponse::InternalServerError().json(EnrollResponse::Error {
return Ok(
HttpResponse::InternalServerError().json(EnrollResponse::Error {
errors: vec![APIError {
code: "ERR_CFG_GENERATION_ERROR".to_string(),
message: "There was an error generating the host configuration.".to_string(),
message: "There was an error generating the host configuration."
.to_string(),
path: None,
}],
}));
}),
);
}
};
// delete all entries in the keystore for this host
let entries = keystore_entry::Entity::find().filter(keystore_entry::Column::Host.eq(&enroll_info.host)).all(&db.conn).await?;
let entries = keystore_entry::Entity::find()
.filter(keystore_entry::Column::Host.eq(&enroll_info.host))
.all(&db.conn)
.await?;
for entry in entries {
entry.delete(&db.conn).await?;
}
let host_info = keystore_host::Entity::find().filter(keystore_host::Column::Id.eq(&enroll_info.host)).one(&db.conn).await?;
let host_info = keystore_host::Entity::find()
.filter(keystore_host::Column::Id.eq(&enroll_info.host))
.one(&db.conn)
.await?;
if let Some(old_host) = host_info {
old_host.delete(&db.conn).await?;
@ -183,7 +198,7 @@ pub async fn enroll(
let host_header = keystore_host::Model {
id: enroll_info.host.clone(),
counter: 1
counter: 1,
};
let entry = keystore_entry::Model {
id: random_id("ksentry"),
@ -193,7 +208,7 @@ pub async fn enroll(
client_dh_key: dh_pubkey,
client_signing_key: ed_pubkey,
config: cfg.clone(),
signing_key: key.to_bytes().to_vec()
signing_key: key.to_bytes().to_vec(),
};
host_header.into_active_model().insert(&db.conn).await?;

View File

@ -19,17 +19,17 @@
// This endpoint is considered done. No major features should be added or removed, unless it fixes bugs.
// This endpoint requires the `definednetworking` extension to be enabled to be used.
use actix_web::{get, HttpRequest, HttpResponse};
use crate::auth_tokens::{enforce_2fa, enforce_session, TokenInfo};
use crate::error::{APIError, APIErrorsResponse};
use crate::timers::TIME_FORMAT;
use crate::AppState;
use actix_web::web::Data;
use actix_web::{get, HttpRequest, HttpResponse};
use chrono::{TimeZone, Utc};
use log::error;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use serde::{Deserialize, Serialize};
use trifid_api_entities::entity::{organization, totp_authenticator};
use crate::AppState;
use crate::auth_tokens::{enforce_2fa, enforce_session, TokenInfo};
use crate::error::{APIError, APIErrorsResponse};
use crate::timers::TIME_FORMAT;
#[derive(Serialize, Deserialize)]
pub struct WhoamiResponse {
@ -72,13 +72,11 @@ pub async fn whoami(req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
Err(e) => {
error!("database error: {}", e);
return HttpResponse::InternalServerError().json(APIErrorsResponse {
errors: vec![
APIError {
errors: vec![APIError {
code: "ERR_UNAUTHORIZED".to_string(),
message: "Your authentication token is invalid.".to_string(),
path: None,
}
],
}],
});
}
}
@ -155,7 +153,9 @@ pub async fn whoami(req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
};
HttpResponse::Ok().json(WhoamiResponse {
data: WhoamiResponseData { actor_type: "user".to_string(), actor: WhoamiResponseDataActor {
data: WhoamiResponseData {
actor_type: "user".to_string(),
actor: WhoamiResponseDataActor {
id: user.id,
organization_id: org,
email: user.email,
@ -165,7 +165,8 @@ pub async fn whoami(req_info: HttpRequest, db: Data<AppState>) -> HttpResponse {
.format(TIME_FORMAT)
.to_string(),
has_totp_authenticator: has_totp,
} },
},
},
metadata: WhoamiResponseMetadata {},
})
}

View File

@ -6,22 +6,20 @@ pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager.create_table(
manager
.create_table(
Table::create()
.table(KeystoreHost::Table)
.col(
ColumnDef::new(KeystoreHost::Id)
.string()
.not_null()
.primary_key()
.primary_key(),
)
.col(
ColumnDef::new(KeystoreHost::Counter)
.integer()
.not_null()
.col(ColumnDef::new(KeystoreHost::Counter).integer().not_null())
.to_owned(),
)
.to_owned()
).await
.await
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {

View File

@ -1,5 +1,5 @@
use sea_orm_migration::prelude::*;
use crate::m20230427_170037_create_table_hosts::Host;
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
@ -7,54 +7,43 @@ pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager.create_table(
manager
.create_table(
Table::create()
.table(KeystoreEntry::Table)
.col(
ColumnDef::new(KeystoreEntry::Id)
.string()
.not_null()
.primary_key()
)
.col(
ColumnDef::new(KeystoreEntry::Host)
.string()
.not_null()
)
.col(
ColumnDef::new(KeystoreEntry::Counter)
.integer()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(KeystoreEntry::Host).string().not_null())
.col(ColumnDef::new(KeystoreEntry::Counter).integer().not_null())
.col(
ColumnDef::new(KeystoreEntry::SigningKey)
.binary()
.not_null()
.not_null(),
)
.col(
ColumnDef::new(KeystoreEntry::ClientSigningKey)
.binary()
.not_null()
.not_null(),
)
.col(
ColumnDef::new(KeystoreEntry::ClientDHKey)
.binary()
.not_null()
)
.col(
ColumnDef::new(KeystoreEntry::Config)
.string()
.not_null()
.not_null(),
)
.col(ColumnDef::new(KeystoreEntry::Config).string().not_null())
.col(
ColumnDef::new(KeystoreEntry::Certificate)
.binary()
.not_null()
.not_null(),
)
.foreign_key(
ForeignKey::create()
.from(KeystoreEntry::Table, KeystoreEntry::Host)
.to(Host::Table, Host::Id)
.to(Host::Table, Host::Id),
)
.index(
Index::create()
@ -62,10 +51,11 @@ impl MigrationTrait for Migration {
.table(KeystoreEntry::Table)
.col(KeystoreEntry::Host)
.col(KeystoreEntry::Counter)
.unique()
.unique(),
)
.to_owned()
).await
.to_owned(),
)
.await
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
@ -86,5 +76,5 @@ enum KeystoreEntry {
ClientSigningKey,
ClientDHKey,
Config,
Certificate
Certificate,
}

View File

@ -42,5 +42,5 @@ pub struct ConfigEmail {
pub struct ConfigTokens {
pub magic_link_expiry_seconds: u64,
pub session_token_expiry_seconds: u64,
pub auth_token_expiry_seconds: u64
pub auth_token_expiry_seconds: u64,
}

View File

@ -2,6 +2,11 @@ use actix_web::error::{JsonPayloadError, PayloadError};
use serde::Serialize;
use std::fmt::{Display, Formatter};
#[derive(Serialize, Debug)]
pub struct APIErrorsResponse {
pub errors: Vec<APIErrorResponse>,
}
#[derive(Serialize, Debug)]
pub struct APIErrorResponse {
pub code: String,

View File

@ -46,11 +46,11 @@ pub struct TotpAuthenticator {
pub verified: bool,
pub name: String,
pub created_at: SystemTime,
pub last_seen_at: SystemTime
pub last_seen_at: SystemTime,
}
#[derive(
Queryable, Selectable, Insertable, Identifiable, Associations, Debug, PartialEq, Clone,
Queryable, Selectable, Insertable, Identifiable, Associations, Debug, PartialEq, Clone,
)]
#[diesel(belongs_to(User))]
#[diesel(table_name = crate::schema::auth_tokens)]

View File

@ -1,4 +1,4 @@
use crate::error::APIErrorResponse;
use crate::error::APIErrorsResponse;
use actix_web::body::BoxBody;
use actix_web::error::JsonPayloadError;
use actix_web::http::StatusCode;
@ -8,7 +8,7 @@ use std::fmt::{Debug, Display, Formatter};
#[derive(Debug)]
pub enum JsonAPIResponse<T: Serialize + Debug> {
Error(StatusCode, APIErrorResponse),
Error(StatusCode, APIErrorsResponse),
Success(StatusCode, T),
}
@ -87,17 +87,21 @@ macro_rules! handle_error {
#[macro_export]
macro_rules! make_err {
($c:expr,$m:expr,$p:expr) => {
$crate::error::APIErrorResponse {
$crate::error::APIErrorsResponse {
errors: vec![$crate::error::APIErrorResponse {
code: $c.to_string(),
message: $m.to_string(),
path: Some($p.to_string()),
}],
}
};
($c:expr,$m:expr) => {
$crate::error::APIErrorResponse {
$crate::error::APIErrorsResponse {
errors: vec![$crate::error::APIErrorResponse {
code: $c.to_string(),
message: $m.to_string(),
path: None,
}],
}
};
}

View File

@ -1,3 +1,3 @@
pub mod magic_link;
pub mod verify_magic_link;
pub mod totp;
pub mod verify_magic_link;

View File

@ -1,19 +1,19 @@
use std::time::{Duration, SystemTime};
use actix_web::http::StatusCode;
use actix_web::{HttpRequest, post};
use actix_web::web::{Data, Json};
use serde::{Deserialize, Serialize};
use crate::{AppState, auth, enforce, randid};
use crate::response::JsonAPIResponse;
use diesel::{QueryDsl, ExpressionMethods, SelectableHelper, BelongingToDsl};
use diesel_async::RunQueryDsl;
use totp_rs::{Algorithm, Secret, TOTP};
use crate::schema::{auth_tokens, users, totp_authenticators};
use crate::models::{AuthToken, TotpAuthenticator, User};
use crate::response::JsonAPIResponse;
use crate::schema::{auth_tokens, totp_authenticators, users};
use crate::{auth, enforce, randid, AppState};
use actix_web::http::StatusCode;
use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest};
use diesel::{BelongingToDsl, ExpressionMethods, QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime};
use totp_rs::{Algorithm, Secret, TOTP};
#[derive(Deserialize, Debug)]
pub struct TotpAuthReq {
pub code: String
pub code: String,
}
#[derive(Serialize, Debug)]
@ -32,14 +32,27 @@ pub struct TotpAuthResp {
}
#[post("/v1/auth/totp")]
pub async fn totp_req(req: Json<TotpAuthReq>, state: Data<AppState>, req_info: HttpRequest) -> JsonAPIResponse<TotpAuthResp> {
pub async fn totp_req(
req: Json<TotpAuthReq>,
state: Data<AppState>,
req_info: HttpRequest,
) -> JsonAPIResponse<TotpAuthResp> {
let mut conn = handle_error!(state.pool.get().await);
let auth_info = auth!(req_info, conn);
let session_token = enforce!(sess auth_info);
let user = handle_error!(users::table.find(&session_token.user_id).first::<User>(&mut conn).await);
let user = handle_error!(
users::table
.find(&session_token.user_id)
.first::<User>(&mut conn)
.await
);
let authenticators: Vec<TotpAuthenticator> = handle_error!(TotpAuthenticator::belonging_to(&user).load::<TotpAuthenticator>(&mut conn).await);
let authenticators: Vec<TotpAuthenticator> = handle_error!(
TotpAuthenticator::belonging_to(&user)
.load::<TotpAuthenticator>(&mut conn)
.await
);
let mut found_valid_code = false;
let mut chosen_auther = None;
@ -47,16 +60,36 @@ pub async fn totp_req(req: Json<TotpAuthReq>, state: Data<AppState>, req_info: H
for totp_auther in authenticators {
if totp_auther.verified {
let secret = Secret::Encoded(totp_auther.secret.clone());
let totp_machine = handle_error!(TOTP::new(Algorithm::SHA1, 6, 1, 30, handle_error!(secret.to_bytes()), Some("Trifid".to_string()), user.email.clone()));
let totp_machine = handle_error!(TOTP::new(
Algorithm::SHA1,
6,
1,
30,
handle_error!(secret.to_bytes()),
Some("Trifid".to_string()),
user.email.clone()
));
let is_valid = handle_error!(totp_machine.check_current(&req.code));
if is_valid { found_valid_code = true; chosen_auther = Some(totp_auther); break; }
if is_valid {
found_valid_code = true;
chosen_auther = Some(totp_auther);
break;
}
}
}
if !found_valid_code {
err!(StatusCode::UNAUTHORIZED, make_err!("ERR_UNAUTHORIZED", "unauthorized"));
err!(
StatusCode::UNAUTHORIZED,
make_err!("ERR_UNAUTHORIZED", "unauthorized")
);
}
handle_error!(diesel::update(&(chosen_auther.unwrap())).set(totp_authenticators::dsl::last_seen_at.eq(SystemTime::now())).execute(&mut conn).await);
handle_error!(
diesel::update(&(chosen_auther.unwrap()))
.set(totp_authenticators::dsl::last_seen_at.eq(SystemTime::now()))
.execute(&mut conn)
.await
);
// issue auth token
@ -74,12 +107,10 @@ pub async fn totp_req(req: Json<TotpAuthReq>, state: Data<AppState>, req_info: H
.await
);
ok!(
TotpAuthResp {
ok!(TotpAuthResp {
data: TotpAuthRespData {
auth_token: new_token.id.clone()
},
metadata: TotpAuthRespMeta {}
}
)
})
}

View File

@ -5,10 +5,10 @@ use crate::{randid, AppState};
use actix_web::http::StatusCode;
use actix_web::post;
use actix_web::web::{Data, Json};
use diesel::result::OptionalExtension;
use diesel::QueryDsl;
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use diesel::result::OptionalExtension;
use std::time::{Duration, SystemTime};
#[derive(Deserialize)]
@ -35,11 +35,14 @@ pub async fn verify_link_req(
req: Json<VerifyLinkReq>,
state: Data<AppState>,
) -> JsonAPIResponse<VerifyLinkResp> {
let mut conn = handle_error!(state.pool.get().await);
let token = match handle_error!(magic_links::table.find(&req.magic_link_token).first::<MagicLink>(&mut conn).await.optional()) {
let token = match handle_error!(magic_links::table
.find(&req.magic_link_token)
.first::<MagicLink>(&mut conn)
.await
.optional())
{
Some(t) => t,
None => {
err!(

View File

@ -1,5 +1,8 @@
use crate::models::TotpAuthenticator;
use crate::models::User;
use crate::response::JsonAPIResponse;
use crate::schema::totp_authenticators;
use crate::schema::users;
use crate::{auth, enforce, randid, AppState};
use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest};
@ -8,11 +11,8 @@ use diesel::QueryDsl;
use diesel::SelectableHelper;
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use totp_rs::{Algorithm, Secret, TOTP};
use crate::schema::totp_authenticators;
use crate::schema::users;
use crate::models::User;
use std::time::SystemTime;
use totp_rs::{Algorithm, Secret, TOTP};
#[derive(Deserialize)]
pub struct TotpAuthenticatorReq {}
@ -45,11 +45,24 @@ pub async fn create_totp_auth_req(
let auth_info = auth!(req_info, conn);
let session_token = enforce!(sess auth_info);
let user = handle_error!(users::table.find(&session_token.user_id).first::<User>(&mut conn).await);
let user = handle_error!(
users::table
.find(&session_token.user_id)
.first::<User>(&mut conn)
.await
);
let secret = Secret::generate_secret();
let totp = handle_error!(TOTP::new(Algorithm::SHA1, 6, 1, 30, handle_error!(secret.to_bytes()), Some("Trifid".to_string()), user.email.clone()));
let totp = handle_error!(TOTP::new(
Algorithm::SHA1,
6,
1,
30,
handle_error!(secret.to_bytes()),
Some("Trifid".to_string()),
user.email.clone()
));
let new_totp_authenticator = TotpAuthenticator {
id: randid!(id "totp"),
@ -58,7 +71,7 @@ pub async fn create_totp_auth_req(
verified: false,
name: "".to_string(),
created_at: SystemTime::now(),
last_seen_at: SystemTime::now()
last_seen_at: SystemTime::now(),
};
handle_error!(
@ -77,3 +90,6 @@ pub async fn create_totp_auth_req(
metadata: TotpAuthRespMeta {}
})
}
// TODO: Get, Edit, Delete
// All of these API endpoints are the same... they could probably be automated...

View File

@ -1,21 +1,21 @@
use std::time::{Duration, SystemTime};
use actix_web::http::StatusCode;
use actix_web::{HttpRequest, post};
use actix_web::web::{Data, Json};
use serde::{Deserialize, Serialize};
use crate::{AppState, auth, enforce, randid};
use crate::response::JsonAPIResponse;
use diesel::{QueryDsl, ExpressionMethods, SelectableHelper, OptionalExtension};
use diesel_async::RunQueryDsl;
use totp_rs::{Algorithm, Secret, TOTP};
use crate::schema::{auth_tokens, totp_authenticators, users};
use crate::models::{AuthToken, TotpAuthenticator, User};
use crate::response::JsonAPIResponse;
use crate::schema::{auth_tokens, totp_authenticators, users};
use crate::{auth, enforce, randid, AppState};
use actix_web::http::StatusCode;
use actix_web::web::{Data, Json};
use actix_web::{post, HttpRequest};
use diesel::{ExpressionMethods, OptionalExtension, QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime};
use totp_rs::{Algorithm, Secret, TOTP};
#[derive(Deserialize, Debug)]
pub struct VerifyTotpAuthReq {
#[serde(rename = "totpToken")]
pub totp_token: String,
pub code: String
pub code: String,
}
#[derive(Serialize, Debug)]
@ -34,14 +34,28 @@ pub struct TotpAuthResp {
}
#[post("/v1/verify-totp-authenticator")]
pub async fn verify_totp_req(req: Json<VerifyTotpAuthReq>, state: Data<AppState>, req_info: HttpRequest) -> JsonAPIResponse<TotpAuthResp> {
pub async fn verify_totp_req(
req: Json<VerifyTotpAuthReq>,
state: Data<AppState>,
req_info: HttpRequest,
) -> JsonAPIResponse<TotpAuthResp> {
let mut conn = handle_error!(state.pool.get().await);
let auth_info = auth!(req_info, conn);
let session_token = enforce!(sess auth_info);
let user = handle_error!(users::table.find(&session_token.user_id).first::<User>(&mut conn).await);
let user = handle_error!(
users::table
.find(&session_token.user_id)
.first::<User>(&mut conn)
.await
);
let authenticator = match handle_error!(totp_authenticators::table.find(&req.totp_token).first::<TotpAuthenticator>(&mut conn).await.optional()) {
let authenticator = match handle_error!(totp_authenticators::table
.find(&req.totp_token)
.first::<TotpAuthenticator>(&mut conn)
.await
.optional())
{
Some(t) => t,
None => {
err!(
@ -67,14 +81,33 @@ pub async fn verify_totp_req(req: Json<VerifyTotpAuthReq>, state: Data<AppState>
}
let secret = Secret::Encoded(authenticator.secret.clone());
let totp_machine = handle_error!(TOTP::new(Algorithm::SHA1, 6, 1, 30, handle_error!(secret.to_bytes()), Some("Trifid".to_string()), user.email.clone()));
let totp_machine = handle_error!(TOTP::new(
Algorithm::SHA1,
6,
1,
30,
handle_error!(secret.to_bytes()),
Some("Trifid".to_string()),
user.email.clone()
));
let is_valid = handle_error!(totp_machine.check_current(&req.code));
if !is_valid {
err!(StatusCode::UNAUTHORIZED, make_err!("ERR_UNAUTHORIZED", "unauthorized"));
err!(
StatusCode::UNAUTHORIZED,
make_err!("ERR_UNAUTHORIZED", "unauthorized")
);
}
handle_error!(diesel::update(&authenticator).set((totp_authenticators::dsl::verified.eq(true), totp_authenticators::dsl::last_seen_at.eq(SystemTime::now()))).execute(&mut conn).await);
handle_error!(
diesel::update(&authenticator)
.set((
totp_authenticators::dsl::verified.eq(true),
totp_authenticators::dsl::last_seen_at.eq(SystemTime::now())
))
.execute(&mut conn)
.await
);
// issue auth token
@ -92,12 +125,10 @@ pub async fn verify_totp_req(req: Json<VerifyTotpAuthReq>, state: Data<AppState>
.await
);
ok!(
TotpAuthResp {
ok!(TotpAuthResp {
data: TotpAuthRespData {
auth_token: new_token.id.clone()
},
metadata: TotpAuthRespMeta {}
}
)
})
}