diff --git a/Cargo.lock b/Cargo.lock index 8419086..aecf850 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -424,6 +424,16 @@ dependencies = [ "cipher", ] +[[package]] +name = "ctrlc" +version = "3.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbcf33c2a618cbe41ee43ae6e9f2e48368cd9f9db2896f10167d8d762679f639" +dependencies = [ + "nix", + "windows-sys 0.45.0", +] + [[package]] name = "curve25519-dalek" version = "3.2.0" @@ -1355,6 +1365,18 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nix" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "static_assertions", +] + [[package]] name = "nom" version = "7.1.3" @@ -2299,6 +2321,12 @@ dependencies = [ "loom", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "stringprep" version = "0.1.2" @@ -2393,6 +2421,7 @@ name = "tfclient" version = "0.1.0" dependencies = [ "clap", + "ctrlc", "dirs 5.0.0", "flate2", "hex", diff --git a/tfclient/Cargo.toml b/tfclient/Cargo.toml index 499e6ee..d1590d6 100644 --- a/tfclient/Cargo.toml +++ b/tfclient/Cargo.toml @@ -18,6 +18,7 @@ url = "2.3.1" toml = "0.7.3" serde = { version = "1.0.158", features = ["derive"] } serde_json = "1.0.94" +ctrlc = "3.2.5" [build-dependencies] serde = { version = "1.0.157", features = ["derive"] } diff --git a/tfclient/src/apiworker.rs b/tfclient/src/apiworker.rs index 134c085..1ee773e 100644 --- a/tfclient/src/apiworker.rs +++ b/tfclient/src/apiworker.rs @@ -1,11 +1,32 @@ -use std::sync::mpsc::Receiver; +use std::sync::mpsc::{Receiver, TryRecvError}; +use log::{error, info}; use crate::config::TFClientConfig; use crate::daemon::ThreadMessageSender; pub enum APIWorkerMessage { - + Shutdown } -pub fn apiworker_main(_config: TFClientConfig, _transmitters: ThreadMessageSender, _rx: Receiver) { - +pub fn apiworker_main(_config: TFClientConfig, _transmitters: ThreadMessageSender, rx: Receiver) { + loop { + match rx.try_recv() { + Ok(msg) => { + match msg { + APIWorkerMessage::Shutdown => { + info!("recv on command socket: shutdown, stopping"); + break; + } + } + }, + Err(e) => { + match e { + TryRecvError::Empty => {} + TryRecvError::Disconnected => { + error!("apiworker command socket disconnected, shutting down to prevent orphaning"); + break; + } + } + } + } + } } \ No newline at end of file diff --git a/tfclient/src/config.rs b/tfclient/src/config.rs index 5f82572..d5b10e4 100644 --- a/tfclient/src/config.rs +++ b/tfclient/src/config.rs @@ -10,7 +10,7 @@ fn default_port() -> u16 { DEFAULT_PORT } #[derive(Serialize, Deserialize, Clone)] pub struct TFClientConfig { #[serde(default = "default_port")] - listen_port: u16 + pub listen_port: u16 } pub fn create_config(instance: &str) -> Result<(), Box> { diff --git a/tfclient/src/daemon.rs b/tfclient/src/daemon.rs index c0623d6..29b7aaa 100644 --- a/tfclient/src/daemon.rs +++ b/tfclient/src/daemon.rs @@ -5,6 +5,7 @@ use log::{error, info}; use crate::apiworker::{apiworker_main, APIWorkerMessage}; use crate::config::load_config; +use crate::main; use crate::nebulaworker::{nebulaworker_main, NebulaWorkerMessage}; use crate::socketworker::{socketworker_main, SocketWorkerMessage}; use crate::util::check_server_url; @@ -22,7 +23,7 @@ pub fn daemon_main(name: String, server: String) { } }; - info!("Starting API thread..."); + info!("Creating transmitter"); let (tx_api, rx_api) = mpsc::channel::(); let (tx_socket, rx_socket) = mpsc::channel::(); @@ -34,6 +35,42 @@ pub fn daemon_main(name: String, server: String) { nebula_thread: tx_nebula }; + let mainthread_transmitter = transmitter.clone(); + + info!("Setting signal trap..."); + + match ctrlc::set_handler(move || { + info!("Ctrl-C detected. Stopping threads..."); + match mainthread_transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) { + Ok(_) => (), + Err(e) => { + error!("Error sending shutdown message to nebula worker thread: {}", e); + } + } + match mainthread_transmitter.api_thread.send(APIWorkerMessage::Shutdown) { + Ok(_) => (), + Err(e) => { + error!("Error sending shutdown message to api worker thread: {}", e); + } + } + match mainthread_transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) { + Ok(_) => (), + Err(e) => { + error!("Error sending shutdown message to socket worker thread: {}", e); + } + } + }) { + Ok(_) => (), + Err(e) => { + error!("Unable to set sigtrap: {}", e); + std::process::exit(1); + } + } + + info!("Starting API thread..."); + + + let config_api = config.clone(); let transmitter_api = transmitter.clone(); let api_thread = thread::spawn(move || { @@ -60,6 +97,7 @@ pub fn daemon_main(name: String, server: String) { std::process::exit(1); } } + info!("Socket thread exited"); info!("Waiting for API thread to exit..."); match api_thread.join() { @@ -69,6 +107,7 @@ pub fn daemon_main(name: String, server: String) { std::process::exit(1); } } + info!("API thread exited"); info!("Waiting for Nebula thread to exit..."); match nebula_thread.join() { @@ -78,13 +117,14 @@ pub fn daemon_main(name: String, server: String) { std::process::exit(1); } } + info!("Nebula thread exited"); info!("All threads exited"); } #[derive(Clone)] pub struct ThreadMessageSender { - socket_thread: Sender, - api_thread: Sender, - nebula_thread: Sender + pub socket_thread: Sender, + pub api_thread: Sender, + pub nebula_thread: Sender } \ No newline at end of file diff --git a/tfclient/src/nebulaworker.rs b/tfclient/src/nebulaworker.rs index 5549775..3bef8e6 100644 --- a/tfclient/src/nebulaworker.rs +++ b/tfclient/src/nebulaworker.rs @@ -1,13 +1,34 @@ // Code to handle the nebula worker -use std::sync::mpsc::Receiver; +use std::sync::mpsc::{Receiver, TryRecvError}; +use log::{error, info}; use crate::config::TFClientConfig; use crate::daemon::ThreadMessageSender; pub enum NebulaWorkerMessage { - + Shutdown } -pub fn nebulaworker_main(_config: TFClientConfig, _transmitter: ThreadMessageSender, _rx: Receiver) { - +pub fn nebulaworker_main(_config: TFClientConfig, _transmitter: ThreadMessageSender, rx: Receiver) { + loop { + match rx.try_recv() { + Ok(msg) => { + match msg { + NebulaWorkerMessage::Shutdown => { + info!("recv on command socket: shutdown, stopping"); + break; + } + } + }, + Err(e) => { + match e { + TryRecvError::Empty => {} + TryRecvError::Disconnected => { + error!("nebulaworker command socket disconnected, shutting down to prevent orphaning"); + break; + } + } + } + } + } } \ No newline at end of file diff --git a/tfclient/src/socketworker.rs b/tfclient/src/socketworker.rs index af86f4e..71d0c80 100644 --- a/tfclient/src/socketworker.rs +++ b/tfclient/src/socketworker.rs @@ -1,13 +1,259 @@ // Code to handle the nebula worker -use std::sync::mpsc::Receiver; +use std::error::Error; +use std::{io, thread}; +use std::io::{BufRead, BufReader, BufWriter, Read, Write}; +use std::net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream}; +use std::sync::mpsc::{Receiver, TryRecvError}; +use log::{debug, error, info, trace, warn}; +use serde::{Deserialize, Serialize}; +use crate::apiworker::APIWorkerMessage; use crate::config::TFClientConfig; use crate::daemon::ThreadMessageSender; +use crate::nebulaworker::NebulaWorkerMessage; pub enum SocketWorkerMessage { - + Shutdown } -pub fn socketworker_main(_config: TFClientConfig, _transmitter: ThreadMessageSender, _rx: Receiver) { +pub fn socketworker_main(config: TFClientConfig, transmitter: ThreadMessageSender, rx: Receiver) { + info!("socketworker_main called, entering realmain"); + match _main(config, transmitter, rx) { + Ok(_) => (), + Err(e) => { + error!("Error in socket thread: {}", e); + } + }; +} +fn _main(config: TFClientConfig, transmitter: ThreadMessageSender, rx: Receiver) -> Result<(), Box> { + let listener = TcpListener::bind(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?; + listener.set_nonblocking(true)?; + + loop { + match listener.accept() { + Ok(stream) => { + let transmitter_clone = transmitter.clone(); + thread::spawn(|| { + match handle_stream(stream, transmitter_clone) { + Ok(_) => (), + Err(e) => { + error!("Error in client thread: {}", e); + } + } + }); + }, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => (), + Err(e) => { Err(e)?; } + } + + match rx.try_recv() { + Ok(msg) => { + match msg { + SocketWorkerMessage::Shutdown => { + info!("recv on command socket: shutdown, stopping"); + break; + } + } + }, + Err(e) => { + match e { + TryRecvError::Empty => {} + TryRecvError::Disconnected => { + error!("socketworker command socket disconnected, shutting down to prevent orphaning"); + break; + } + } + } + } + } + + Ok(()) +} + +fn handle_stream(stream: (TcpStream, SocketAddr), transmitter: ThreadMessageSender) -> Result<(), io::Error> { + info!("Incoming client"); + match handle_client(stream.0, transmitter) { + Ok(()) => (), + Err(e) if e.kind() == io::ErrorKind::TimedOut => { + warn!("Client timed out, connection aborted"); + }, + Err(e) if e.kind() == io::ErrorKind::NotConnected => { + warn!("Client connection severed"); + }, + Err(e) if e.kind() == io::ErrorKind::BrokenPipe => { + warn!("Client connection returned error: broken pipe"); + }, + Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => { + warn!("Client aborted connection"); + }, + Err(e) => { + error!("Error in client handler: {}", e); + return Err(e); + } + }; + Ok(()) +} + +fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender) -> Result<(), io::Error> { + info!("Handling connection from {}", stream.peer_addr()?); + + let mut client = Client { + state: ClientState::WaitHello, + reader: BufReader::new(&stream), + writer: BufWriter::new(&stream), + stream: &stream + }; + + loop { + let mut command = String::new(); + client.reader.read_line(&mut command)?; + + let command: JsonMessage = serde_json::from_str(&command)?; + + trace!("recv {:?} from {}", command, client.stream.peer_addr()?); + + let should_disconnect; + + match client.state { + ClientState::WaitHello => { + should_disconnect = waithello_handle(&mut client, &transmitter, command)?; + } + ClientState::SentHello => { + should_disconnect = senthello_handle(&mut client, &transmitter, command)?; + } + } + + if should_disconnect { break; } + } + + // Gracefully close the connection + stream.shutdown(Shutdown::Both)?; + + Ok(()) +} + +struct Client<'a> { + state: ClientState, + reader: BufReader<&'a TcpStream>, + writer: BufWriter<&'a TcpStream>, + stream: &'a TcpStream +} + +fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, command: JsonMessage) -> Result { + trace!("state: WaitHello, handing with waithello_handle"); + let mut should_disconnect = false; + + match command { + JsonMessage::Hello { version } => { + if version != JSON_API_VERSION { + should_disconnect = true; + client.stream.write_all(&ctob(JsonMessage::Goodbye { + reason: DisconnectReason::UnsupportedVersion { + expected: JSON_API_VERSION, + got: version + } + }))?; + } + client.stream.write_all(&ctob(JsonMessage::Hello { + version: JSON_API_VERSION + }))?; + client.state = ClientState::SentHello; + trace!("setting state to SentHello"); + }, + JsonMessage::Goodbye { reason } => { + info!("Client sent disconnect: {:?}", reason); + should_disconnect = true; + }, + _ => { + debug!("message type unexpected in WaitHello state"); + should_disconnect = true; + client.stream.write_all(&ctob(JsonMessage::Goodbye { + reason: DisconnectReason::UnexpectedMessageType, + }))?; + } + } + + Ok(should_disconnect) +} + +fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, command: JsonMessage) -> Result { + trace!("state: SentHello, handing with senthello_handle"); + let mut should_disconnect = false; + + match command { + JsonMessage::Goodbye { reason } => { + info!("Client sent disconnect: {:?}", reason); + should_disconnect = true; + }, + + JsonMessage::Shutdown {} => { + info!("Requested to shutdown by local control socket. Sending shutdown message to threads"); + match transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) { + Ok(_) => (), + Err(e) => { + error!("Error sending shutdown message to nebula worker thread: {}", e); + } + } + match transmitter.api_thread.send(APIWorkerMessage::Shutdown) { + Ok(_) => (), + Err(e) => { + error!("Error sending shutdown message to api worker thread: {}", e); + } + } + match transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) { + Ok(_) => (), + Err(e) => { + error!("Error sending shutdown message to socket worker thread: {}", e); + } + } + } + + _ => { + debug!("message type unexpected in SentHello state"); + should_disconnect = true; + client.stream.write_all(&ctob(JsonMessage::Goodbye { + reason: DisconnectReason::UnexpectedMessageType, + }))?; + } + } + + Ok(should_disconnect) +} + +fn ctob(command: JsonMessage) -> Vec { + let command_str = serde_json::to_string(&command).unwrap() + "\n"; + command_str.into_bytes() +} + +enum ClientState { + WaitHello, + SentHello +} + +pub const JSON_API_VERSION: i32 = 1; + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "method")] +enum JsonMessage { + #[serde(rename = "hello")] + Hello { + version: i32 + }, + #[serde(rename = "goodbye")] + Goodbye { + reason: DisconnectReason + }, + #[serde(rename = "shutdown")] + Shutdown {} +} + +#[derive(Serialize, Deserialize, Debug)] +enum DisconnectReason { + #[serde(rename = "unsupported_version")] + UnsupportedVersion { expected: i32, got: i32 }, + #[serde(rename = "unexpected_message_type")] + UnexpectedMessageType, + #[serde(rename = "done")] + Done } \ No newline at end of file