// trifid/trifid-api/src/auth.rs
//
// (source listing metadata: 128 lines, 4.6 KiB, Rust)

use rocket::http::Status;
use rocket::{Request};
use rocket::request::{FromRequest, Outcome};
use crate::tokens::{validate_auth_token, validate_session_token};
/// Identity of a caller authenticated by a session token (`st-…`), and
/// optionally also by an auth token (`at-…`), taken from the `Authorization`
/// header by this module's `FromRequest` guard.
pub struct PartialUserInfo {
    /// The user's `id` in the `users` table.
    pub user_id: i32,
    /// The user's `created_on` column, cast to `i64`.
    pub created_at: i64,
    /// The user's email address as stored in `users`.
    pub email: String,
    /// `true` exactly when the request also carried an auth (`at-`) token
    /// that passed validation.
    pub has_totp_auth: bool,
    /// The raw session token (`st-…`) as it appeared in the header.
    pub session_id: String,
    /// The raw auth token (`at-…`), if one was supplied and validated.
    pub auth_id: Option<String>,
}
/// Reasons the authentication request guards in this module reject a request.
#[derive(Debug)]
pub enum AuthenticationError {
    /// No usable `Authorization` header: it is absent, or it does not split
    /// into 2 or 3 space-separated parts.
    MissingToken,
    /// The header is malformed or a token failed validation. The payload is a
    /// numeric code identifying which check rejected the request.
    InvalidToken(usize),
    /// An internal error occurred while loading the authenticated user.
    DatabaseError,
    /// The guard requires a validated auth (`at-`) token, but the request
    /// authenticated with a session token only.
    RequiresTOTP,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for PartialUserInfo {
type Error = AuthenticationError;
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = req.headers();
// make sure the bearer token exists
if let Some(authorization) = headers.get_one("Authorization") {
// parse bearer token
let components = authorization.split(' ').collect::<Vec<&str>>();
if components.len() != 2 && components.len() != 3 {
return Outcome::Failure((Status::Unauthorized, AuthenticationError::MissingToken));
}
if components[0] != "Bearer" {
return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(0)));
}
if components.len() == 2 && !components[1].starts_with("st-") {
return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(1)));
}
let st: String;
let user_id: i64;
let at: Option<String>;
match &components[1][..3] {
"st-" => {
// validate session token
st = components[1].to_string();
match validate_session_token(st.clone(), req.rocket().state().unwrap()).await {
Ok(uid) => user_id = uid,
Err(_) => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(2)))
}
},
_ => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(3)))
}
if components.len() == 3 {
match &components[2][..3] {
"at-" => {
// validate auth token
at = Some(components[2].to_string());
match validate_auth_token(at.clone().unwrap().clone(), st.clone(), req.rocket().state().unwrap()).await {
Ok(_) => (),
Err(_) => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(4)))
}
},
_ => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(5)))
}
} else {
at = None;
}
// this user is 100% valid and authenticated, fetch their info
let user = match sqlx::query!("SELECT * FROM users WHERE id = $1", user_id.clone() as i32).fetch_one(req.rocket().state().unwrap()).await {
Ok(u) => u,
Err(_) => return Outcome::Failure((Status::InternalServerError, AuthenticationError::DatabaseError))
};
Outcome::Success(PartialUserInfo {
user_id: user_id as i32,
created_at: user.created_on as i64,
email: user.email,
has_totp_auth: at.is_some(),
session_id: st,
auth_id: at,
})
} else {
Outcome::Failure((Status::Unauthorized, AuthenticationError::MissingToken))
}
}
}
/// Identity of a caller who presented both a valid session token and a valid
/// auth (`at-`) token. This module's `FromRequest` guard produces it and
/// rejects session-only requests.
pub struct TOTPAuthenticatedUserInfo {
    /// The user's `id` in the `users` table.
    pub user_id: i32,
    /// The user's `created_on` column, cast to `i64`.
    pub created_at: i64,
    /// The user's email address as stored in `users`.
    pub email: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for TOTPAuthenticatedUserInfo {
    type Error = AuthenticationError;

    /// Runs the `PartialUserInfo` guard and additionally requires that the
    /// request carried a validated auth (`at-`) token.
    ///
    /// Failures and forwards from the inner guard pass through unchanged. A
    /// request authenticated by session token only fails with
    /// `RequiresTOTP` (401).
    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let partial = match PartialUserInfo::from_request(request).await {
            Outcome::Success(partial) => partial,
            Outcome::Failure(failure) => return Outcome::Failure(failure),
            Outcome::Forward(forward) => return Outcome::Forward(forward),
        };

        // Session-only authentication is not enough for this guard.
        if !partial.has_totp_auth {
            return Outcome::Failure((Status::Unauthorized, AuthenticationError::RequiresTOTP));
        }

        Outcome::Success(Self {
            user_id: partial.user_id,
            created_at: partial.created_at,
            email: partial.email,
        })
    }
}