// Code to handle the nebula worker use std::error::Error; use std::{io, thread}; use std::io::{BufRead, BufReader, BufWriter, Write}; use std::net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream}; use std::sync::mpsc::{Receiver, TryRecvError}; use log::{debug, error, info, trace, warn}; use serde::{Deserialize, Serialize}; use crate::apiworker::APIWorkerMessage; use crate::config::{load_cdata, TFClientConfig}; use crate::daemon::ThreadMessageSender; use crate::nebulaworker::NebulaWorkerMessage; use crate::timerworker::TimerWorkerMessage; pub enum SocketWorkerMessage { Shutdown } pub fn socketworker_main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSender, rx: Receiver) { info!("socketworker_main called, entering realmain"); match _main(config, instance, transmitter, rx) { Ok(_) => (), Err(e) => { error!("Error in socket thread: {}", e); } }; } fn _main(config: TFClientConfig, instance: String, transmitter: ThreadMessageSender, rx: Receiver) -> Result<(), Box> { let listener = TcpListener::bind(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?; listener.set_nonblocking(true)?; loop { match listener.accept() { Ok(stream) => { let transmitter_clone = transmitter.clone(); let config_clone = config.clone(); let instance_clone = instance.clone(); thread::spawn(|| { match handle_stream(stream, transmitter_clone, config_clone, instance_clone) { Ok(_) => (), Err(e) => { error!("Error in client thread: {}", e); } } }); }, Err(e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) => { Err(e)?; } } match rx.try_recv() { Ok(msg) => { match msg { SocketWorkerMessage::Shutdown => { info!("recv on command socket: shutdown, stopping"); break; } } }, Err(e) => { match e { TryRecvError::Empty => {} TryRecvError::Disconnected => { error!("socketworker command socket disconnected, shutting down to prevent orphaning"); break; } } } } } Ok(()) } fn handle_stream(stream: (TcpStream, SocketAddr), transmitter: ThreadMessageSender, config: TFClientConfig, instance: String) -> Result<(), io::Error> { info!("Incoming client"); match handle_client(stream.0, transmitter, config, instance) { Ok(()) => (), Err(e) if e.kind() == io::ErrorKind::TimedOut => { warn!("Client timed out, connection aborted"); }, Err(e) if e.kind() == io::ErrorKind::NotConnected => { warn!("Client connection severed"); }, Err(e) if e.kind() == io::ErrorKind::BrokenPipe => { warn!("Client connection returned error: broken pipe"); }, Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => { warn!("Client aborted connection"); }, Err(e) => { error!("Error in client handler: {}", e); return Err(e); } }; Ok(()) } fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender, config: TFClientConfig, instance: String) -> Result<(), io::Error> { info!("Handling connection from {}", stream.peer_addr()?); let mut client = Client { state: ClientState::WaitHello, reader: BufReader::new(&stream), writer: BufWriter::new(&stream), stream: &stream, config, instance, }; loop { let mut command = String::new(); client.reader.read_line(&mut command)?; let command: JsonMessage = serde_json::from_str(&command)?; trace!("recv {:?} from {}", command, client.stream.peer_addr()?); let should_disconnect; match client.state { ClientState::WaitHello => { should_disconnect = waithello_handle(&mut client, &transmitter, command)?; } ClientState::SentHello => { should_disconnect = senthello_handle(&mut client, &transmitter, command)?; } } if should_disconnect { break; } } // Gracefully close the connection stream.shutdown(Shutdown::Both)?; Ok(()) } struct Client<'a> { state: ClientState, reader: BufReader<&'a TcpStream>, writer: BufWriter<&'a TcpStream>, stream: &'a TcpStream, config: TFClientConfig, instance: String } fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, command: JsonMessage) -> Result { trace!("state: WaitHello, handing with waithello_handle"); let mut should_disconnect = false; match command { JsonMessage::Hello { version } => { if version != JSON_API_VERSION { should_disconnect = true; client.stream.write_all(&ctob(JsonMessage::Goodbye { reason: DisconnectReason::UnsupportedVersion { expected: JSON_API_VERSION, got: version } }))?; } client.stream.write_all(&ctob(JsonMessage::Hello { version: JSON_API_VERSION }))?; client.state = ClientState::SentHello; trace!("setting state to SentHello"); }, JsonMessage::Goodbye { reason } => { info!("Client sent disconnect: {:?}", reason); should_disconnect = true; }, _ => { debug!("message type unexpected in WaitHello state"); should_disconnect = true; client.stream.write_all(&ctob(JsonMessage::Goodbye { reason: DisconnectReason::UnexpectedMessageType, }))?; } } Ok(should_disconnect) } fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, command: JsonMessage) -> Result { trace!("state: SentHello, handing with senthello_handle"); let mut should_disconnect = false; match command { JsonMessage::Goodbye { reason } => { info!("Client sent disconnect: {:?}", reason); should_disconnect = true; }, JsonMessage::Shutdown {} => { info!("Requested to shutdown by local control socket. Sending shutdown message to threads"); match transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) { Ok(_) => (), Err(e) => { error!("Error sending shutdown message to nebula worker thread: {}", e); } } match transmitter.api_thread.send(APIWorkerMessage::Shutdown) { Ok(_) => (), Err(e) => { error!("Error sending shutdown message to api worker thread: {}", e); } } match transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) { Ok(_) => (), Err(e) => { error!("Error sending shutdown message to socket worker thread: {}", e); } } match transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) { Ok(_) => (), Err(e) => { error!("Error sending shutdown message to timer worker thread: {}", e); } } }, JsonMessage::GetHostID {} => { let data = match load_cdata(&client.instance) { Ok(d) => d, Err(e) => { error!("Error loading cdata: {}", e); panic!("{}", e); // TODO: Find a better way of handling this } }; client.stream.write_all(&ctob(JsonMessage::HostID { has_id: data.host_id.is_some(), id: data.host_id }))?; }, JsonMessage::Enroll { code } => { info!("Client sent enroll with code {}", code); info!("Sending enroll request to apiworker"); transmitter.api_thread.send(APIWorkerMessage::Enroll { code }).unwrap(); } _ => { debug!("message type unexpected in SentHello state"); should_disconnect = true; client.stream.write_all(&ctob(JsonMessage::Goodbye { reason: DisconnectReason::UnexpectedMessageType, }))?; } } Ok(should_disconnect) } pub fn ctob(command: JsonMessage) -> Vec { let command_str = serde_json::to_string(&command).unwrap() + "\n"; command_str.into_bytes() } enum ClientState { WaitHello, SentHello } pub const JSON_API_VERSION: i32 = 1; #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "method")] pub enum JsonMessage { #[serde(rename = "hello")] Hello { version: i32 }, #[serde(rename = "goodbye")] Goodbye { reason: DisconnectReason }, #[serde(rename = "shutdown")] Shutdown {}, #[serde(rename = "get_host_id")] GetHostID {}, #[serde(rename = "host_id")] HostID { has_id: bool, id: Option }, #[serde(rename = "enroll")] Enroll { code: String } } #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type")] pub enum DisconnectReason { #[serde(rename = "unsupported_version")] UnsupportedVersion { expected: i32, got: i32 }, #[serde(rename = "unexpected_message_type")] UnexpectedMessageType, #[serde(rename = "done")] Done }