// trifid-api, an open source reimplementation of the Defined Networking nebula management server. // Copyright (C) 2023 c0repwn3r // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see . use actix_web::HttpRequest; use std::error::Error; use crate::timers::expired; use crate::tokens::get_token_type; use sea_orm::{ColumnTrait, Condition, DatabaseConnection, EntityTrait, QueryFilter}; use trifid_api_entities::entity::api_key; use trifid_api_entities::entity::api_key_scope; use trifid_api_entities::entity::user; use trifid_api_entities::entity::{auth_token, session_token}; pub enum TokenInfo { SessionToken(SessionTokenInfo), AuthToken(AuthTokenInfo), ApiToken(ApiTokenInfo), NotPresent, } pub struct SessionTokenInfo { pub token: String, pub user: SessionTokenUser, pub expires_at: i64, } pub struct SessionTokenUser { pub id: String, pub email: String, } pub struct ApiTokenInfo { pub scopes: Vec, pub organization: String, } pub struct AuthTokenInfo { pub token: String, pub session_info: SessionTokenInfo, } pub async fn enforce_session( req: &HttpRequest, db: &DatabaseConnection, ) -> Result> { let header = req .headers() .get("Authorization") .ok_or("Missing authorization header")?; let authorization = header.to_str()?; let authorization_split: Vec<&str> = authorization.split(' ').collect(); if authorization_split[0] != "Bearer" { return Err("Not a bearer token".into()); } let tokens = &authorization_split[1..]; let sess_token = tokens .iter() .find(|i| get_token_type(i).unwrap_or("n-sess") == "sess") .copied() .ok_or("Missing session token")?; let token: session_token::Model = session_token::Entity::find() .filter(session_token::Column::Id.eq(sess_token)) .one(db) .await? .ok_or("Invalid session token")?; if expired(token.expires_on as u64) { return Err("Token expired".into()); } let user: user::Model = user::Entity::find() .filter(user::Column::Id.eq(token.user)) .one(db) .await? .ok_or("Session token has a nonexistent user")?; Ok(TokenInfo::SessionToken(SessionTokenInfo { token: token.id, user: SessionTokenUser { id: user.id, email: user.email, }, expires_at: token.expires_on, })) } pub async fn enforce_2fa( req: &HttpRequest, db: &DatabaseConnection, ) -> Result> { let session_data = match enforce_session(req, db).await? { TokenInfo::SessionToken(i) => i, _ => unreachable!(), }; let header = req .headers() .get("Authorization") .ok_or("Missing authorization header")?; let authorization = header.to_str()?; let authorization_split: Vec<&str> = authorization.split(' ').collect(); if authorization_split[0] != "Bearer" { return Err("Not a bearer token".into()); } let tokens = &authorization_split[1..]; let auth_token = tokens .iter() .find(|i| get_token_type(i).unwrap_or("n-auth") == "auth") .copied() .ok_or("Missing auth token")?; let token: auth_token::Model = auth_token::Entity::find() .filter(auth_token::Column::Id.eq(auth_token)) .one(db) .await? .ok_or("Invalid session token")?; if expired(token.expires_on as u64) { return Err("Token expired".into()); } Ok(TokenInfo::AuthToken(AuthTokenInfo { token: token.id, session_info: session_data, })) } pub async fn enforce_api_token( req: &HttpRequest, scopes: &[&str], db: &DatabaseConnection, ) -> Result> { let header = req .headers() .get("Authorization") .ok_or("Missing authorization header")?; let authorization = header.to_str()?; let authorization_split: Vec<&str> = authorization.split(' ').collect(); if authorization_split[0] != "Bearer" { return Err("Not a bearer token".into()); } let tokens = &authorization_split[1..]; let api_token = tokens .iter() .find(|i| get_token_type(i).unwrap_or("n-tfkey") == "tfkey") .copied() .ok_or("Missing api token")?; // API tokens are special and have a different form than other keys. // They follow the form: // tfkey-[ID]-[TOKEN] let api_token_split: Vec<&str> = api_token.split('-').collect(); if api_token_split.len() != 3 { return Err("API token is missing key".into()); } let token_id = format!("{}-{}", api_token_split[0], api_token_split[1]); let token_key = api_token_split[2].to_string(); let token: api_key::Model = api_key::Entity::find() .filter( Condition::all() .add(api_key::Column::Id.eq(token_id)) .add(api_key::Column::Key.eq(token_key)), ) .one(db) .await? .ok_or("Invalid api token")?; let token_scopes: Vec = api_key_scope::Entity::find() .filter(api_key_scope::Column::ApiKey.eq(api_token)) .all(db) .await?; let token_scopes: Vec<&str> = token_scopes.iter().map(|i| i.scope.as_str()).collect(); for scope in scopes { if !token_scopes.contains(scope) { return Err(format!("API token is missing scope {}", scope).into()); } } Ok(TokenInfo::ApiToken(ApiTokenInfo { scopes: token_scopes.iter().map(|i| i.to_string()).collect(), organization: token.organization, })) }