trifid/tfclient/src/socketworker.rs

348 lines
10 KiB
Rust

// Code to handle the nebula worker
use std::error::Error;
use std::io::{BufRead, BufReader, Write};
use std::net::{IpAddr, Shutdown, SocketAddr, TcpListener, TcpStream};
use std::sync::mpsc::Receiver;
use std::{io, thread};
use crate::apiworker::APIWorkerMessage;
use crate::config::{load_cdata, TFClientConfig};
use crate::daemon::ThreadMessageSender;
use crate::nebulaworker::NebulaWorkerMessage;
use crate::timerworker::TimerWorkerMessage;
use log::{debug, error, info, trace, warn};
use serde::{Deserialize, Serialize};
/// Commands delivered to the socket worker thread over its `std::sync::mpsc` channel.
///
/// Derives `Debug`/`PartialEq` so messages can be logged and compared; the enum is a plain
/// fieldless tag, so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketWorkerMessage {
    /// Stop the worker: it stops accepting connections and its loop returns.
    Shutdown,
    /// Wake the worker so it checks the listener for pending connections again.
    WakeUp,
}
pub fn socketworker_main(
config: TFClientConfig,
instance: String,
transmitter: ThreadMessageSender,
rx: Receiver<SocketWorkerMessage>,
) {
info!("socketworker_main called, entering realmain");
match _main(config, instance, transmitter, rx) {
Ok(_) => (),
Err(e) => {
error!("Error in socket thread: {}", e);
}
};
}
fn _main(
config: TFClientConfig,
instance: String,
transmitter: ThreadMessageSender,
rx: Receiver<SocketWorkerMessage>,
) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(SocketAddr::new(
IpAddr::from([127, 0, 0, 1]),
config.listen_port,
))?;
listener.set_nonblocking(true)?;
loop {
match listener.accept() {
Ok(stream) => {
let transmitter_clone = transmitter.clone();
let config_clone = config.clone();
let instance_clone = instance.clone();
thread::spawn(|| {
match handle_stream(stream, transmitter_clone, config_clone, instance_clone) {
Ok(_) => (),
Err(e) => {
error!("Error in client thread: {}", e);
}
}
});
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
Err(e) => {
Err(e)?;
}
}
match rx.recv() {
Ok(msg) => match msg {
SocketWorkerMessage::Shutdown => {
info!("recv on command socket: shutdown, stopping");
break;
}
SocketWorkerMessage::WakeUp => {
continue;
}
},
Err(e) => {
error!("socketworker command socket errored: {}", e);
break;
}
}
}
Ok(())
}
fn handle_stream(
stream: (TcpStream, SocketAddr),
transmitter: ThreadMessageSender,
config: TFClientConfig,
instance: String,
) -> Result<(), io::Error> {
info!("Incoming client");
match handle_client(stream.0, transmitter, config, instance) {
Ok(()) => (),
Err(e) if e.kind() == io::ErrorKind::TimedOut => {
warn!("Client timed out, connection aborted");
}
Err(e) if e.kind() == io::ErrorKind::NotConnected => {
warn!("Client connection severed");
}
Err(e) if e.kind() == io::ErrorKind::BrokenPipe => {
warn!("Client connection returned error: broken pipe");
}
Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => {
warn!("Client aborted connection");
}
Err(e) => {
error!("Error in client handler: {}", e);
return Err(e);
}
};
Ok(())
}
/// Serves one control-socket client until either side ends the conversation.
///
/// Messages are newline-delimited JSON [`JsonMessage`] values. Each line is dispatched to the
/// handler for the client's current [`ClientState`]. When a handler reports that the
/// connection should close, the socket is shut down in both directions.
///
/// If the client closes its end of the connection (a read of 0 bytes), the session ends with
/// `Ok(())` rather than being reported as a JSON parse failure. Blank lines are ignored.
///
/// # Errors
/// Returns socket I/O errors, and an error (converted from `serde_json::Error`) for any
/// non-blank line that is not a valid message.
fn handle_client(
    stream: TcpStream,
    transmitter: ThreadMessageSender,
    _config: TFClientConfig,
    instance: String,
) -> Result<(), io::Error> {
    // The listener is non-blocking, and on BSD-derived systems and Windows an accepted socket
    // inherits that mode. This handler relies on blocking reads, so force blocking mode.
    stream.set_nonblocking(false)?;
    let peer = stream.peer_addr()?;
    info!("Handling connection from {}", peer);
    let mut client = Client {
        state: ClientState::WaitHello,
        reader: BufReader::new(&stream),
        stream: &stream,
        instance,
    };
    // One line buffer, reused for every message.
    let mut line = String::new();
    loop {
        line.clear();
        if client.reader.read_line(&mut line)? == 0 {
            // Peer closed the connection without sending `goodbye`; nothing left to do.
            info!("Client {} closed the connection", peer);
            return Ok(());
        }
        if line.trim().is_empty() {
            continue;
        }
        let command: JsonMessage = serde_json::from_str(&line)?;
        trace!("recv {:?} from {}", command, peer);
        let should_disconnect = match client.state {
            ClientState::WaitHello => waithello_handle(&mut client, &transmitter, command)?,
            ClientState::SentHello => senthello_handle(&mut client, &transmitter, command)?,
        };
        if should_disconnect {
            break;
        }
    }
    // Gracefully close the connection
    stream.shutdown(Shutdown::Both)?;
    Ok(())
}
/// Per-connection state for one control-socket client.
///
/// `reader` and `stream` both borrow the same `TcpStream`. `reader` buffers incoming lines,
/// and `stream` is used directly for writing replies and querying the peer address.
struct Client<'a> {
    /// Handshake progress; selects which handler processes the next message.
    state: ClientState,
    /// Buffered reader over the client socket, used to read one message per line.
    reader: BufReader<&'a TcpStream>,
    /// The same client socket, used for writing replies.
    stream: &'a TcpStream,
    /// Instance name passed to `load_cdata` when the client asks for its host ID.
    instance: String,
}
/// Handles one message from a client that has not yet completed the `hello` handshake.
///
/// Behavior by message:
/// * `hello` with [`JSON_API_VERSION`]: reply with our own `hello` and move to `SentHello`.
/// * `hello` with any other version: reply `goodbye` (`unsupported_version`) and disconnect,
///   without completing the handshake.
/// * `goodbye`: disconnect without replying.
/// * anything else: reply `goodbye` (`unexpected_message_type`) and disconnect.
///
/// Returns `Ok(true)` when the connection should be closed.
///
/// # Errors
/// Propagates any error from writing to the client socket.
fn waithello_handle(
    client: &mut Client,
    _transmitter: &ThreadMessageSender,
    command: JsonMessage,
) -> Result<bool, io::Error> {
    trace!("state: WaitHello, handling with waithello_handle");
    let should_disconnect = match command {
        JsonMessage::Hello { version } if version != JSON_API_VERSION => {
            // Reject only: also sending `hello` and advancing the state would tell the client
            // both "unsupported" and "accepted".
            client.stream.write_all(&ctob(JsonMessage::Goodbye {
                reason: DisconnectReason::UnsupportedVersion {
                    expected: JSON_API_VERSION,
                    got: version,
                },
            }))?;
            true
        }
        JsonMessage::Hello { .. } => {
            client.stream.write_all(&ctob(JsonMessage::Hello {
                version: JSON_API_VERSION,
            }))?;
            client.state = ClientState::SentHello;
            trace!("setting state to SentHello");
            false
        }
        JsonMessage::Goodbye { reason } => {
            info!("Client sent disconnect: {:?}", reason);
            true
        }
        _ => {
            debug!("message type unexpected in WaitHello state");
            client.stream.write_all(&ctob(JsonMessage::Goodbye {
                reason: DisconnectReason::UnexpectedMessageType,
            }))?;
            true
        }
    };
    Ok(should_disconnect)
}
fn senthello_handle(
client: &mut Client,
transmitter: &ThreadMessageSender,
command: JsonMessage,
) -> Result<bool, io::Error> {
trace!("state: SentHello, handing with senthello_handle");
let mut should_disconnect = false;
match command {
JsonMessage::Goodbye { reason } => {
info!("Client sent disconnect: {:?}", reason);
should_disconnect = true;
}
JsonMessage::Shutdown {} => {
info!("Requested to shutdown by local control socket. Sending shutdown message to threads");
match transmitter
.nebula_thread
.send(NebulaWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!(
"Error sending shutdown message to nebula worker thread: {}",
e
);
}
}
match transmitter.api_thread.send(APIWorkerMessage::Shutdown) {
Ok(_) => (),
Err(e) => {
error!("Error sending shutdown message to api worker thread: {}", e);
}
}
match transmitter
.socket_thread
.send(SocketWorkerMessage::Shutdown)
{
Ok(_) => (),
Err(e) => {
error!(
"Error sending shutdown message to socket worker thread: {}",
e
);
}
}
match transmitter.timer_thread.send(TimerWorkerMessage::Shutdown) {
Ok(_) => (),
Err(e) => {
error!(
"Error sending shutdown message to timer worker thread: {}",
e
);
}
}
}
JsonMessage::GetHostID {} => {
let data = match load_cdata(&client.instance) {
Ok(d) => d,
Err(e) => {
error!("Error loading cdata: {}", e);
panic!("{}", e); // TODO: Find a better way of handling this
}
};
client.stream.write_all(&ctob(JsonMessage::HostID {
has_id: data.creds.is_some(),
id: data.creds.map(|c| c.host_id),
}))?;
}
JsonMessage::Enroll { code } => {
info!("Client sent enroll with code {}", code);
info!("Sending enroll request to apiworker");
transmitter
.api_thread
.send(APIWorkerMessage::Enroll { code })
.unwrap();
}
JsonMessage::Update {} => {
info!("Client sent update request.");
info!("Telling apiworker to update configuration");
transmitter
.api_thread
.send(APIWorkerMessage::Update)
.unwrap();
}
_ => {
debug!("message type unexpected in SentHello state");
should_disconnect = true;
client.stream.write_all(&ctob(JsonMessage::Goodbye {
reason: DisconnectReason::UnexpectedMessageType,
}))?;
}
}
Ok(should_disconnect)
}
/// Encodes `command` as one line of JSON terminated by `\n`, ready to write to a client.
///
/// # Panics
/// Panics if `serde_json` fails to serialize `command`.
pub fn ctob(command: JsonMessage) -> Vec<u8> {
    let mut line = serde_json::to_vec(&command).unwrap();
    line.push(b'\n');
    line
}
/// Handshake progress of a single control-socket connection.
enum ClientState {
    /// No accepted `hello` yet; only `hello` and `goodbye` are meaningful.
    WaitHello,
    /// Our `hello` has been sent; regular commands are processed.
    SentHello,
}
/// Version of the JSON control protocol spoken on this socket.
///
/// Sent in our `hello` and compared against the `version` in a client's `hello`.
pub const JSON_API_VERSION: i32 = 1;
/// A single control-protocol message, exchanged as one line of JSON.
///
/// Internally tagged: the variant name travels in a `"method"` field next to the variant's own
/// fields, e.g. `{"method":"hello","version":1}`. The `rename` attributes below are the wire
/// names and must not change.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "method")]
pub enum JsonMessage {
    /// Handshake message carrying the sender's protocol version.
    #[serde(rename = "hello")]
    Hello { version: i32 },
    /// Sent by either side when closing the connection, with the reason.
    #[serde(rename = "goodbye")]
    Goodbye { reason: DisconnectReason },
    /// Asks the daemon to send a shutdown message to its worker threads.
    #[serde(rename = "shutdown")]
    Shutdown {},
    /// Asks for this instance's host ID; answered with `host_id`.
    #[serde(rename = "get_host_id")]
    GetHostID {},
    /// Reply to `get_host_id`; `id` is `None` when no credentials are stored.
    #[serde(rename = "host_id")]
    HostID { has_id: bool, id: Option<String> },
    /// Enrollment request carrying an enrollment code; forwarded to the API worker.
    #[serde(rename = "enroll")]
    Enroll { code: String },
    /// Asks the API worker to update the configuration.
    #[serde(rename = "update")]
    Update {},
}
/// Why a connection is being closed; carried inside a `goodbye` message.
///
/// Internally tagged by a `"type"` field, e.g. `{"type":"unexpected_message_type"}`. The
/// `rename` attributes below are the wire names and must not change.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
pub enum DisconnectReason {
    /// The peer's `hello` carried protocol version `got`, but `expected` is required.
    #[serde(rename = "unsupported_version")]
    UnsupportedVersion { expected: i32, got: i32 },
    /// A message arrived that is not valid in the connection's current state.
    #[serde(rename = "unexpected_message_type")]
    UnexpectedMessageType,
    /// Not sent by the handlers in this file; presumably a normal, error-free close
    /// (NOTE(review): confirm against client code).
    #[serde(rename = "done")]
    Done,
}