// Code to handle the nebula worker use std::error::Error; use std::{io, thread}; use std::io::{BufRead, BufReader, BufWriter, Read, Write}; use std::net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream}; use std::sync::mpsc::{Receiver, TryRecvError}; use log::{debug, error, info, trace, warn}; use serde::{Deserialize, Serialize}; use crate::apiworker::APIWorkerMessage; use crate::config::TFClientConfig; use crate::daemon::ThreadMessageSender; use crate::nebulaworker::NebulaWorkerMessage; pub enum SocketWorkerMessage { Shutdown } pub fn socketworker_main(config: TFClientConfig, transmitter: ThreadMessageSender, rx: Receiver) { info!("socketworker_main called, entering realmain"); match _main(config, transmitter, rx) { Ok(_) => (), Err(e) => { error!("Error in socket thread: {}", e); } }; } fn _main(config: TFClientConfig, transmitter: ThreadMessageSender, rx: Receiver) -> Result<(), Box> { let listener = TcpListener::bind(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), config.listen_port))?; listener.set_nonblocking(true)?; loop { match listener.accept() { Ok(stream) => { let transmitter_clone = transmitter.clone(); thread::spawn(|| { match handle_stream(stream, transmitter_clone) { Ok(_) => (), Err(e) => { error!("Error in client thread: {}", e); } } }); }, Err(e) if e.kind() == io::ErrorKind::WouldBlock => (), Err(e) => { Err(e)?; } } match rx.try_recv() { Ok(msg) => { match msg { SocketWorkerMessage::Shutdown => { info!("recv on command socket: shutdown, stopping"); break; } } }, Err(e) => { match e { TryRecvError::Empty => {} TryRecvError::Disconnected => { error!("socketworker command socket disconnected, shutting down to prevent orphaning"); break; } } } } } Ok(()) } fn handle_stream(stream: (TcpStream, SocketAddr), transmitter: ThreadMessageSender) -> Result<(), io::Error> { info!("Incoming client"); match handle_client(stream.0, transmitter) { Ok(()) => (), Err(e) if e.kind() == io::ErrorKind::TimedOut => { warn!("Client timed out, connection aborted"); }, Err(e) if e.kind() == io::ErrorKind::NotConnected => { warn!("Client connection severed"); }, Err(e) if e.kind() == io::ErrorKind::BrokenPipe => { warn!("Client connection returned error: broken pipe"); }, Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => { warn!("Client aborted connection"); }, Err(e) => { error!("Error in client handler: {}", e); return Err(e); } }; Ok(()) } fn handle_client(stream: TcpStream, transmitter: ThreadMessageSender) -> Result<(), io::Error> { info!("Handling connection from {}", stream.peer_addr()?); let mut client = Client { state: ClientState::WaitHello, reader: BufReader::new(&stream), writer: BufWriter::new(&stream), stream: &stream }; loop { let mut command = String::new(); client.reader.read_line(&mut command)?; let command: JsonMessage = serde_json::from_str(&command)?; trace!("recv {:?} from {}", command, client.stream.peer_addr()?); let should_disconnect; match client.state { ClientState::WaitHello => { should_disconnect = waithello_handle(&mut client, &transmitter, command)?; } ClientState::SentHello => { should_disconnect = senthello_handle(&mut client, &transmitter, command)?; } } if should_disconnect { break; } } // Gracefully close the connection stream.shutdown(Shutdown::Both)?; Ok(()) } struct Client<'a> { state: ClientState, reader: BufReader<&'a TcpStream>, writer: BufWriter<&'a TcpStream>, stream: &'a TcpStream } fn waithello_handle(client: &mut Client, _transmitter: &ThreadMessageSender, command: JsonMessage) -> Result { trace!("state: WaitHello, handing with waithello_handle"); let mut should_disconnect = false; match command { JsonMessage::Hello { version } => { if version != JSON_API_VERSION { should_disconnect = true; client.stream.write_all(&ctob(JsonMessage::Goodbye { reason: DisconnectReason::UnsupportedVersion { expected: JSON_API_VERSION, got: version } }))?; } client.stream.write_all(&ctob(JsonMessage::Hello { version: JSON_API_VERSION }))?; client.state = ClientState::SentHello; trace!("setting state to SentHello"); }, JsonMessage::Goodbye { reason } => { info!("Client sent disconnect: {:?}", reason); should_disconnect = true; }, _ => { debug!("message type unexpected in WaitHello state"); should_disconnect = true; client.stream.write_all(&ctob(JsonMessage::Goodbye { reason: DisconnectReason::UnexpectedMessageType, }))?; } } Ok(should_disconnect) } fn senthello_handle(client: &mut Client, transmitter: &ThreadMessageSender, command: JsonMessage) -> Result { trace!("state: SentHello, handing with senthello_handle"); let mut should_disconnect = false; match command { JsonMessage::Goodbye { reason } => { info!("Client sent disconnect: {:?}", reason); should_disconnect = true; }, JsonMessage::Shutdown {} => { info!("Requested to shutdown by local control socket. Sending shutdown message to threads"); match transmitter.nebula_thread.send(NebulaWorkerMessage::Shutdown) { Ok(_) => (), Err(e) => { error!("Error sending shutdown message to nebula worker thread: {}", e); } } match transmitter.api_thread.send(APIWorkerMessage::Shutdown) { Ok(_) => (), Err(e) => { error!("Error sending shutdown message to api worker thread: {}", e); } } match transmitter.socket_thread.send(SocketWorkerMessage::Shutdown) { Ok(_) => (), Err(e) => { error!("Error sending shutdown message to socket worker thread: {}", e); } } } _ => { debug!("message type unexpected in SentHello state"); should_disconnect = true; client.stream.write_all(&ctob(JsonMessage::Goodbye { reason: DisconnectReason::UnexpectedMessageType, }))?; } } Ok(should_disconnect) } fn ctob(command: JsonMessage) -> Vec { let command_str = serde_json::to_string(&command).unwrap() + "\n"; command_str.into_bytes() } enum ClientState { WaitHello, SentHello } pub const JSON_API_VERSION: i32 = 1; #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "method")] enum JsonMessage { #[serde(rename = "hello")] Hello { version: i32 }, #[serde(rename = "goodbye")] Goodbye { reason: DisconnectReason }, #[serde(rename = "shutdown")] Shutdown {} } #[derive(Serialize, Deserialize, Debug)] enum DisconnectReason { #[serde(rename = "unsupported_version")] UnsupportedVersion { expected: i32, got: i32 }, #[serde(rename = "unexpected_message_type")] UnexpectedMessageType, #[serde(rename = "done")] Done }