124 lines
4.5 KiB
Rust
124 lines
4.5 KiB
Rust
|
use rocket::http::Status;
|
||
|
use rocket::{Request, State};
|
||
|
use rocket::request::{FromRequest, Outcome};
|
||
|
use sqlx::PgPool;
|
||
|
use crate::tokens::{validate_auth_token, validate_session_token};
|
||
|
|
||
|
/// A user authenticated by a session token (`st-…`); `hasTotp` records whether a
/// validated auth token (`at-…`) was presented alongside it.
///
/// Produced by this type's [`FromRequest`] implementation.
pub struct PartialUserInfo {
    /// The user's id, as used to look up the row in the `users` table.
    pub user_id: i32,
    /// The user's `created_on` column, cast to `i64`.
    pub created_at: i64,
    /// The user's e-mail address.
    pub email: String,
    /// `true` when the request carried a validated `at-` token in addition to the
    /// session token. `TOTPAuthenticatedUserInfo` only admits users with this set.
    // NOTE(review): presumably the `at-` token marks completed TOTP verification —
    // confirm against `crate::tokens`.
    // The camelCase name is kept for compatibility with existing callers; the
    // allow silences the `non_snake_case` warning it would otherwise emit.
    #[allow(non_snake_case)]
    pub hasTotp: bool,
}
|
||
|
|
||
|
/// Reasons the authentication request guards in this module reject a request.
///
/// Fieldless, so it is cheap to copy and compare (useful in tests and when matching
/// on a guard's failure).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthenticationError {
    /// No `Authorization` header was sent, or its value did not split into the
    /// expected number of space-separated parts.
    MissingToken,
    /// The header is malformed (wrong scheme or token prefix), or a token was
    /// rejected during validation.
    InvalidToken,
    /// A database operation needed to authenticate the request failed.
    DatabaseError,
    /// The request was authenticated without a validated `at-` token, but the guard
    /// requires one.
    RequiresTOTP,
}
|
||
|
|
||
|
#[rocket::async_trait]
|
||
|
impl<'r> FromRequest<'r> for PartialUserInfo {
|
||
|
type Error = AuthenticationError;
|
||
|
|
||
|
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||
|
let headers = req.headers();
|
||
|
|
||
|
// make sure the bearer token exists
|
||
|
if let Some(authorization) = headers.get_one("Authorization") {
|
||
|
// parse bearer token
|
||
|
let components = authorization.split(' ').collect::<Vec<&str>>();
|
||
|
|
||
|
if components.len() != 2 || components.len() != 3 {
|
||
|
return Outcome::Failure((Status::Unauthorized, AuthenticationError::MissingToken));
|
||
|
}
|
||
|
|
||
|
if components[0] != "Bearer" {
|
||
|
return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken));
|
||
|
}
|
||
|
|
||
|
if components.len() == 2 && components[1].starts_with("st-") {
|
||
|
return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken));
|
||
|
}
|
||
|
|
||
|
let st: String;
|
||
|
let user_id: i64;
|
||
|
let at: Option<String>;
|
||
|
|
||
|
match &components[1][..3] {
|
||
|
"st-" => {
|
||
|
// validate session token
|
||
|
st = components[1].to_string();
|
||
|
match validate_session_token(st.clone(), req.rocket().state().unwrap()).await {
|
||
|
Ok(uid) => user_id = uid,
|
||
|
Err(_) => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken))
|
||
|
}
|
||
|
},
|
||
|
_ => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken))
|
||
|
}
|
||
|
|
||
|
if components.len() == 3 {
|
||
|
match &components[2][..3] {
|
||
|
"at-" => {
|
||
|
// validate auth token
|
||
|
at = Some(components[2].to_string());
|
||
|
match validate_auth_token(at.clone().unwrap().clone(), st.clone(), req.rocket().state().unwrap()).await {
|
||
|
Ok(_) => (),
|
||
|
Err(_) => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken))
|
||
|
}
|
||
|
},
|
||
|
_ => return Outcome::Failure((Status::Unauthorized, AuthenticationError::InvalidToken))
|
||
|
}
|
||
|
} else {
|
||
|
at = None;
|
||
|
}
|
||
|
|
||
|
// this user is 100% valid and authenticated, fetch their info
|
||
|
|
||
|
let user = match sqlx::query!("SELECT * FROM users WHERE id = $1", user_id.clone() as i32).fetch_one(req.rocket().state().unwrap()).await {
|
||
|
Ok(u) => u,
|
||
|
Err(_) => return Outcome::Failure((Status::InternalServerError, AuthenticationError::DatabaseError))
|
||
|
};
|
||
|
|
||
|
Outcome::Success(PartialUserInfo {
|
||
|
user_id: user_id as i32,
|
||
|
created_at: user.created_on as i64,
|
||
|
email: user.email,
|
||
|
hasTotp: at.is_some(),
|
||
|
})
|
||
|
} else {
|
||
|
Outcome::Failure((Status::Unauthorized, AuthenticationError::MissingToken))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// A user whose request carried both a valid session token and a validated `at-`
/// auth token.
///
/// Obtained through this type's [`FromRequest`] implementation, which fails with
/// `AuthenticationError::RequiresTOTP` when no auth token was validated.
pub struct TOTPAuthenticatedUserInfo {
    /// The user's id, as used to look up the row in the `users` table.
    pub user_id: i32,
    /// The user's `created_on` column, cast to `i64`.
    pub created_at: i64,
    /// The user's e-mail address.
    pub email: String,
}
|
||
|
#[rocket::async_trait]
|
||
|
impl<'r> FromRequest<'r> for TOTPAuthenticatedUserInfo {
|
||
|
type Error = AuthenticationError;
|
||
|
|
||
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||
|
let userinfo = PartialUserInfo::from_request(request).await;
|
||
|
match userinfo {
|
||
|
Outcome::Failure(e) => Outcome::Failure(e),
|
||
|
Outcome::Forward(f) => Outcome::Forward(f),
|
||
|
Outcome::Success(s) => {
|
||
|
if s.hasTotp {
|
||
|
Outcome::Success(Self {
|
||
|
user_id: s.user_id,
|
||
|
created_at: s.created_at,
|
||
|
email: s.email,
|
||
|
})
|
||
|
} else {
|
||
|
Outcome::Failure((Status::Unauthorized, AuthenticationError::RequiresTOTP))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|