diff --git a/Cargo.lock b/Cargo.lock index d2384c5..fc1592a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2139,6 +2139,19 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_yaml" +version = "0.9.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f82e6c8c047aa50a7328632d067bcae6ef38772a79e28daf32f735e0e4f3dd10" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha1" version = "0.10.5" @@ -2472,6 +2485,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "serde_yaml", "sha2", "simple_logger", "tar", @@ -2889,6 +2903,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "unsafe-libyaml" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad2024452afd3874bf539695e04af6732ba06517424dbf958fdb16a01f3bef6c" + [[package]] name = "url" version = "2.3.1" diff --git a/tfclient/Cargo.toml b/tfclient/Cargo.toml index 7dce9aa..bd16e92 100644 --- a/tfclient/Cargo.toml +++ b/tfclient/Cargo.toml @@ -25,6 +25,7 @@ chrono = "0.4.24" ipnet = "2.7.1" base64-serde = "0.7.0" dnapi-rs = { version = "0.1.6", path = "../dnapi-rs" } +serde_yaml = "0.9.19" [build-dependencies] serde = { version = "1.0.157", features = ["derive"] } diff --git a/tfclient/src/config.rs b/tfclient/src/config.rs index 0ada404..ed92faa 100644 --- a/tfclient/src/config.rs +++ b/tfclient/src/config.rs @@ -15,7 +15,7 @@ use crate::dirs::{get_cdata_dir, get_cdata_file, get_config_dir, get_config_file pub const DEFAULT_PORT: u16 = 8157; fn default_port() -> u16 { DEFAULT_PORT } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct TFClientConfig { #[serde(default = "default_port")] pub listen_port: u16 @@ -23,7 +23,7 @@ pub struct TFClientConfig { #[derive(Serialize, Deserialize, Clone)] pub struct TFClientData { - pub dh_privkey: Option<[u8; 32]>, + pub dh_privkey: Option>, pub creds: Option, pub meta: Option } @@ -98,7 +98,7 @@ pub fn save_cdata(instance: &str, data: TFClientData) -> Result<(), Box } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigPki { pub ca: String, pub cert: String, @@ -154,7 +154,7 @@ pub struct NebulaConfigPki { pub disconnect_invalid: bool } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigLighthouse { #[serde(default = "bool_false")] #[serde(skip_serializing_if = "is_bool_false")] @@ -178,7 +178,7 @@ pub struct NebulaConfigLighthouse { pub local_allow_list: HashMap, // `interfaces` is not supported } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigLighthouseDns { pub host: Ipv4Addr, #[serde(default = "u16_53")] @@ -186,7 +186,7 @@ pub struct NebulaConfigLighthouseDns { pub port: u16 } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigListen { #[serde(default = "ipv4_0000")] #[serde(skip_serializing_if = "is_ipv4_0000")] @@ -203,7 +203,7 @@ pub struct NebulaConfigListen { pub write_buffer: Option } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigPunchy { #[serde(default = "bool_false")] #[serde(skip_serializing_if = "is_bool_false")] @@ -216,7 +216,7 @@ pub struct NebulaConfigPunchy { pub delay: String } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub enum NebulaConfigCipher { #[serde(rename = "aes")] Aes, @@ -224,7 +224,7 @@ pub enum NebulaConfigCipher { ChaChaPoly } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigRelay { #[serde(default = "empty_vec")] #[serde(skip_serializing_if = "is_empty_vec")] @@ -237,7 +237,7 @@ pub struct NebulaConfigRelay { pub use_relays: bool } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigTun { #[serde(default = "bool_false")] #[serde(skip_serializing_if = "is_bool_false")] @@ -264,13 +264,13 @@ pub struct NebulaConfigTun { pub unsafe_routes: Vec } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigTunRouteOverride { pub mtu: u64, pub route: Ipv4Net } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigTunUnsafeRoute { pub route: Ipv4Net, pub via: Ipv4Addr, @@ -282,7 +282,7 @@ pub struct NebulaConfigTunUnsafeRoute { pub metric: i64 } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigLogging { #[serde(default = "loglevel_info")] #[serde(skip_serializing_if = "is_loglevel_info")] @@ -298,7 +298,7 @@ pub struct NebulaConfigLogging { pub timestamp_format: String } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub enum NebulaConfigLoggingLevel { #[serde(rename = "panic")] Panic, @@ -314,7 +314,7 @@ pub enum NebulaConfigLoggingLevel { Debug } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub enum NebulaConfigLoggingFormat { #[serde(rename = "json")] Json, @@ -322,7 +322,7 @@ pub enum NebulaConfigLoggingFormat { Text } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigSshd { #[serde(default = "bool_false")] #[serde(skip_serializing_if = "is_bool_false")] @@ -334,7 +334,7 @@ pub struct NebulaConfigSshd { pub authorized_users: Vec } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigSshdAuthorizedUser { pub user: String, #[serde(default = "empty_vec")] @@ -342,7 +342,7 @@ pub struct NebulaConfigSshdAuthorizedUser { pub keys: Vec } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] #[serde(tag = "type")] pub enum NebulaConfigStats { #[serde(rename = "graphite")] @@ -351,7 +351,7 @@ pub enum NebulaConfigStats { Prometheus(NebulaConfigStatsPrometheus) } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigStatsGraphite { #[serde(default = "string_nebula")] #[serde(skip_serializing_if = "is_string_nebula")] @@ -369,7 +369,7 @@ pub struct NebulaConfigStatsGraphite { pub lighthouse_metrics: bool } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub enum NebulaConfigStatsGraphiteProtocol { #[serde(rename = "tcp")] Tcp, @@ -377,7 +377,7 @@ pub enum NebulaConfigStatsGraphiteProtocol { Udp } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigStatsPrometheus { pub listen: String, pub path: String, @@ -396,7 +396,7 @@ pub struct NebulaConfigStatsPrometheus { pub lighthouse_metrics: bool } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigFirewall { #[serde(default = "none")] #[serde(skip_serializing_if = "is_none")] @@ -411,7 +411,7 @@ pub struct NebulaConfigFirewall { pub outbound: Option>, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigFirewallConntrack { #[serde(default = "string_12m")] #[serde(skip_serializing_if = "is_string_12m")] @@ -424,7 +424,7 @@ pub struct NebulaConfigFirewallConntrack { pub default_timeout: String } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct NebulaConfigFirewallRule { #[serde(default = "none")] #[serde(skip_serializing_if = "is_none")] diff --git a/tfclient/src/daemon.rs b/tfclient/src/daemon.rs index 7d912e0..6f59054 100644 --- a/tfclient/src/daemon.rs +++ b/tfclient/src/daemon.rs @@ -91,7 +91,7 @@ pub fn daemon_main(name: String, server: String) { let transmitter_nebula = transmitter.clone(); let name_nebula = name.clone(); let nebula_thread = thread::spawn(move || { - nebulaworker_main(config_nebula, name, transmitter_nebula, rx_nebula); + nebulaworker_main(config_nebula, name_nebula, transmitter_nebula, rx_nebula); }); info!("Starting timer thread..."); @@ -101,8 +101,9 @@ pub fn daemon_main(name: String, server: String) { }); info!("Starting socket worker thread..."); + let name_socket = name.clone(); let socket_thread = thread::spawn(move || { - socketworker_main(config, name.clone(), transmitter, rx_socket); + socketworker_main(config, name_socket, transmitter, rx_socket); }); info!("Waiting for socket thread to exit..."); diff --git a/tfclient/src/nebulaworker.rs b/tfclient/src/nebulaworker.rs index 7cb685a..eadc28a 100644 --- a/tfclient/src/nebulaworker.rs +++ b/tfclient/src/nebulaworker.rs @@ -1,15 +1,36 @@ // Code to handle the nebula worker +use std::error::Error; +use std::fs; use std::sync::mpsc::{Receiver, TryRecvError}; -use log::{error, info}; -use crate::config::{load_cdata, save_cdata, TFClientConfig}; +use log::{debug, error, info}; +use crate::config::{load_cdata, NebulaConfig, save_cdata, TFClientConfig}; use crate::daemon::ThreadMessageSender; +use crate::dirs::get_nebulaconfig_file; +use crate::embedded_nebula::run_embedded_nebula; pub enum NebulaWorkerMessage { Shutdown, ConfigUpdated } +fn insert_private_key(instance: &str) -> Result<(), Box> { + let cdata = load_cdata(instance)?; + let key = cdata.dh_privkey.ok_or("Missing private key")?; + + let config_str = fs::read_to_string(get_nebulaconfig_file(instance).ok_or("Could not get config file location")?)?; + let mut config: NebulaConfig = serde_yaml::from_str(&config_str)?; + + config.pki.key = String::from_utf8(key)?; + + debug!("inserted private key into config: {:?}", config); + + let config_str = serde_yaml::to_string(&config)?; + fs::write(get_nebulaconfig_file(instance).ok_or("Could not get config file location")?, config_str)?; + + Ok(()) +} + pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter: ThreadMessageSender, rx: Receiver) { let cdata = match load_cdata(&instance) { Ok(data) => data, @@ -20,6 +41,28 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter } }; + info!("fixing config..."); + match insert_private_key(&instance) { + Ok(_) => { + info!("config fixed (private-key embedded)"); + }, + Err(e) => { + error!("unable to fix config: {}", e); + error!("nebula thread exiting with error"); + return; + } + } + info!("starting nebula child..."); + let mut child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) { + Ok(c) => c, + Err(e) => { + error!("unable to start embedded nebula binary: {}", e); + error!("nebula thread exiting with error"); + return; + } + }; + info!("nebula process started"); + // dont need to save it, because we do not, in any circumstance, write to it loop { match rx.try_recv() { @@ -27,10 +70,50 @@ pub fn nebulaworker_main(_config: TFClientConfig, instance: String, _transmitter match msg { NebulaWorkerMessage::Shutdown => { info!("recv on command socket: shutdown, stopping"); + info!("shutting down nebula binary"); + match child.kill() { + Ok(_) => { + debug!("nebula process exited"); + }, + Err(e) => { + error!("nebula process already exited: {}", e); + } + } + info!("nebula shut down"); break; }, NebulaWorkerMessage::ConfigUpdated => { - info!("our configuration has been updated - reloading"); + info!("our configuration has been updated - restarting"); + debug!("killing existing process"); + match child.kill() { + Ok(_) => { + debug!("nebula process exited"); + }, + Err(e) => { + error!("nebula process already exited: {}", e); + } + } + debug!("fixing config..."); + match insert_private_key(&instance) { + Ok(_) => { + debug!("config fixed (private-key embedded)"); + }, + Err(e) => { + error!("unable to fix config: {}", e); + error!("nebula thread exiting with error"); + return; + } + } + debug!("restarting nebula process"); + child = match run_embedded_nebula(&["-config".to_string(), get_nebulaconfig_file(&instance).unwrap().to_str().unwrap().to_string()]) { + Ok(c) => c, + Err(e) => { + error!("unable to start embedded nebula binary: {}", e); + error!("nebula thread exiting with error"); + return; + } + }; + debug!("nebula process restarted"); } } },