use rocket::http::Status; use rocket::{Request}; use rocket::request::{FromRequest, Outcome}; use crate::tokens::{validate_auth_token, validate_session_token}; pub struct PartialUserInfo { pub user_id: i32, pub created_at: i64, pub email: String, pub has_totp_auth: bool, pub session_id: String, pub auth_id: Option } #[derive(Debug)] pub enum AuthenticationError { MissingToken, InvalidToken(usize), DatabaseError, RequiresTOTP } #[rocket::async_trait] impl<'r> FromRequest<'r> for PartialUserInfo { type Error = AuthenticationError; async fn from_request(req: &'r Request<'_>) -> Outcome { let headers = req.headers(); // make sure the bearer token exists if let Some(authorization) = headers.get_one("Authorization") { // parse bearer token let components = authorization.split(' ').collect::>(); if components.len() != 2 && components.len() != 3 { return Outcome::Failure((Status::Unauthorized, AuthenticationError::MissingToken)); } if components[0] != "Bearer" { return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(0))); } if components.len() == 2 && !components[1].starts_with("st-") { return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(1))); } let st: String; let user_id: i64; let at: Option; match &components[1][..3] { "st-" => { // validate session token st = components[1].to_string(); match validate_session_token(st.clone(), req.rocket().state().unwrap()).await { Ok(uid) => user_id = uid, Err(_) => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(2))) } }, _ => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(3))) } if components.len() == 3 { match &components[2][..3] { "at-" => { // validate auth token at = Some(components[2].to_string()); match validate_auth_token(at.clone().unwrap().clone(), st.clone(), req.rocket().state().unwrap()).await { Ok(_) => (), Err(_) => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(4))) } }, _ => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken(5))) } } else { at = None; } // this user is 100% valid and authenticated, fetch their info let user = match sqlx::query!("SELECT * FROM users WHERE id = $1", user_id.clone() as i32).fetch_one(req.rocket().state().unwrap()).await { Ok(u) => u, Err(_) => return Outcome::Failure((Status::InternalServerError, AuthenticationError::DatabaseError)) }; Outcome::Success(PartialUserInfo { user_id: user_id as i32, created_at: user.created_on as i64, email: user.email, has_totp_auth: at.is_some(), session_id: st, auth_id: at, }) } else { Outcome::Failure((Status::Unauthorized, AuthenticationError::MissingToken)) } } } pub struct TOTPAuthenticatedUserInfo { pub user_id: i32, pub created_at: i64, pub email: String, } #[rocket::async_trait] impl<'r> FromRequest<'r> for TOTPAuthenticatedUserInfo { type Error = AuthenticationError; async fn from_request(request: &'r Request<'_>) -> Outcome { let userinfo = PartialUserInfo::from_request(request).await; match userinfo { Outcome::Failure(e) => Outcome::Failure(e), Outcome::Forward(f) => Outcome::Forward(f), Outcome::Success(s) => { if s.has_totp_auth { Outcome::Success(Self { user_id: s.user_id, created_at: s.created_at, email: s.email, }) } else { Outcome::Failure((Status::Unauthorized, AuthenticationError::RequiresTOTP)) } } } } }