// trifid-api, an open source reimplementation of the Defined Networking nebula management server.
// Copyright (C) 2023 c0repwn3r
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use actix_web::HttpRequest;
use std::error::Error;
use crate::timers::expired;
use crate::tokens::get_token_type;
use sea_orm::{ColumnTrait, Condition, DatabaseConnection, EntityTrait, QueryFilter};
use trifid_api_entities::entity::api_key;
use trifid_api_entities::entity::api_key_scope;
use trifid_api_entities::entity::user;
use trifid_api_entities::entity::{auth_token, session_token};
/// The credential a request was authenticated with, together with the data
/// looked up for it.
pub enum TokenInfo {
/// A valid session token; this is what `enforce_session` returns on success.
SessionToken(SessionTokenInfo),
/// A valid auth token presented alongside a valid session token; this is what
/// `enforce_2fa` returns on success.
AuthToken(AuthTokenInfo),
/// A valid API token; this is what `enforce_api_token` returns on success.
ApiToken(ApiTokenInfo),
/// No token.
/// NOTE(review): never constructed by the functions in this part of the file;
/// presumably produced elsewhere — confirm.
NotPresent,
}
/// Data describing a session token that passed validation.
pub struct SessionTokenInfo {
/// ID of the `session_token` row; this equals the token string that was
/// matched against the `Id` column.
pub token: String,
/// The user row the session token refers to.
pub user: SessionTokenUser,
/// Copied from the row's `expires_on` value.
/// NOTE(review): presumably a Unix timestamp; confirm the unit against
/// `timers::expired`.
pub expires_at: i64,
}
/// The subset of a user row that is exposed alongside a session token.
pub struct SessionTokenUser {
/// The user's `id` as stored in the `user` table.
pub id: String,
/// The user's `email` as stored in the `user` table.
pub email: String,
}
/// Data describing an API token that passed validation.
pub struct ApiTokenInfo {
    /// Every scope that was found for the token, as owned strings. This is the
    /// full list loaded from the database, not only the scopes a caller
    /// required.
    pub scopes: Vec<String>,
    /// The `organization` value of the matching `api_key` row.
    pub organization: String,
}
/// Data describing an auth token that passed validation.
pub struct AuthTokenInfo {
/// ID of the `auth_token` row; this equals the token string that was matched
/// against the `Id` column.
pub token: String,
/// The session token that was validated in the same request.
pub session_info: SessionTokenInfo,
}
pub async fn enforce_session(
req: &HttpRequest,
db: &DatabaseConnection,
) -> Result> {
let header = req
.headers()
.get("Authorization")
.ok_or("Missing authorization header")?;
let authorization = header.to_str()?;
let authorization_split: Vec<&str> = authorization.split(' ').collect();
if authorization_split[0] != "Bearer" {
return Err("Not a bearer token".into());
}
let tokens = &authorization_split[1..];
let sess_token = tokens
.iter()
.find(|i| get_token_type(i).unwrap_or("n-sess") == "sess")
.copied()
.ok_or("Missing session token")?;
let token: session_token::Model = session_token::Entity::find()
.filter(session_token::Column::Id.eq(sess_token))
.one(db)
.await?
.ok_or("Invalid session token")?;
if expired(token.expires_on as u64) {
return Err("Token expired".into());
}
let user: user::Model = user::Entity::find()
.filter(user::Column::Id.eq(token.user))
.one(db)
.await?
.ok_or("Session token has a nonexistent user")?;
Ok(TokenInfo::SessionToken(SessionTokenInfo {
token: token.id,
user: SessionTokenUser {
id: user.id,
email: user.email,
},
expires_at: token.expires_on,
}))
}
pub async fn enforce_2fa(
req: &HttpRequest,
db: &DatabaseConnection,
) -> Result> {
let session_data = match enforce_session(req, db).await? {
TokenInfo::SessionToken(i) => i,
_ => unreachable!(),
};
let header = req
.headers()
.get("Authorization")
.ok_or("Missing authorization header")?;
let authorization = header.to_str()?;
let authorization_split: Vec<&str> = authorization.split(' ').collect();
if authorization_split[0] != "Bearer" {
return Err("Not a bearer token".into());
}
let tokens = &authorization_split[1..];
let auth_token = tokens
.iter()
.find(|i| get_token_type(i).unwrap_or("n-auth") == "auth")
.copied()
.ok_or("Missing auth token")?;
let token: auth_token::Model = auth_token::Entity::find()
.filter(auth_token::Column::Id.eq(auth_token))
.one(db)
.await?
.ok_or("Invalid session token")?;
if expired(token.expires_on as u64) {
return Err("Token expired".into());
}
Ok(TokenInfo::AuthToken(AuthTokenInfo {
token: token.id,
session_info: session_data,
}))
}
pub async fn enforce_api_token(
req: &HttpRequest,
scopes: &[&str],
db: &DatabaseConnection,
) -> Result> {
let header = req
.headers()
.get("Authorization")
.ok_or("Missing authorization header")?;
let authorization = header.to_str()?;
let authorization_split: Vec<&str> = authorization.split(' ').collect();
if authorization_split[0] != "Bearer" {
return Err("Not a bearer token".into());
}
let tokens = &authorization_split[1..];
let api_token = tokens
.iter()
.find(|i| get_token_type(i).unwrap_or("n-tfkey") == "tfkey")
.copied()
.ok_or("Missing api token")?;
// API tokens are special and have a different form than other keys.
// They follow the form:
// tfkey-[ID]-[TOKEN]
let api_token_split: Vec<&str> = api_token.split('-').collect();
if api_token_split.len() != 3 {
return Err("API token is missing key".into());
}
let token_id = format!("{}-{}", api_token_split[0], api_token_split[1]);
let token_key = api_token_split[2].to_string();
let token: api_key::Model = api_key::Entity::find()
.filter(
Condition::all()
.add(api_key::Column::Id.eq(token_id))
.add(api_key::Column::Key.eq(token_key)),
)
.one(db)
.await?
.ok_or("Invalid api token")?;
let token_scopes: Vec = api_key_scope::Entity::find()
.filter(api_key_scope::Column::ApiKey.eq(api_token))
.all(db)
.await?;
let token_scopes: Vec<&str> = token_scopes.iter().map(|i| i.scope.as_str()).collect();
for scope in scopes {
if !token_scopes.contains(scope) {
return Err(format!("API token is missing scope {}", scope).into());
}
}
Ok(TokenInfo::ApiToken(ApiTokenInfo {
scopes: token_scopes.iter().map(|i| i.to_string()).collect(),
organization: token.organization,
}))
}