use std::fmt::Debug; use std::marker::PhantomData; use super::{ApiError, ApiState}; use anyhow::Context; use axum::{ async_trait, body::BoxBody, extract::{FromRequest, FromRequestParts}, http::Request, }; use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken}; use serde::{de::DeserializeOwned, Serialize}; use ssh_key::{Certificate, SshSig}; use tracing::trace; #[derive(Debug, Clone)] pub struct CertificateBody(pub Certificate); // we must implement `FromRequest` (and not `FromRequestParts`) to consume the body #[async_trait] impl FromRequest for CertificateBody where S: Send + Sync, { type Rejection = ApiError; async fn from_request(req: Request, state: &S) -> Result { let body = String::from_request(req, state) .await .context("failed to extract body")?; let cert = Certificate::from_openssh(&body) .with_context(|| format!("failed to parse '{}'", body)) .map_err(ApiError::ParseCertificate)?; trace!(%body, "extracted certificate"); Ok(Self(cert)) } } #[derive(Debug, Clone)] pub struct SignatureBody(pub SshSig); #[async_trait] impl FromRequest for SignatureBody where S: Send + Sync, { type Rejection = ApiError; async fn from_request(req: Request, state: &S) -> Result { let body = String::from_request(req, state) .await .context("failed to extract body")?; let sig = SshSig::from_pem(&body) .with_context(|| format!("failed to parse '{}'", body)) .map_err(ApiError::ParseSignature)?; trace!(%body, "extracted signature"); Ok(Self(sig)) } } pub struct JWTString(String); impl From for JWTString { fn from(s: String) -> Self { Self(s) } } // TODO: be generic over ApiState -> AsRef, AsRef where A: AlgorithmExt #[derive(Debug)] pub struct JWTAuthenticated< T: Serialize + DeserializeOwned + Clone + Debug, Q: FromRequestParts + Debug + Into, > where ApiError: From<>::Rejection>, { pub data: T, _marker: PhantomData, } impl< T: Serialize + DeserializeOwned + Clone + Debug, Q: FromRequestParts + Debug + Into, > JWTAuthenticated where ApiError: From<>::Rejection>, { pub fn new(data: T) -> Self { Self { data, _marker: Default::default(), } } } #[async_trait] impl< T: Serialize + DeserializeOwned + Clone + Debug, Q: FromRequestParts + Debug + Into, > FromRequestParts for JWTAuthenticated where ApiError: From<>::Rejection>, { type Rejection = ApiError; async fn from_request_parts( parts: &mut axum::http::request::Parts, state: &ApiState, ) -> Result { let JWTString(token) = Q::from_request_parts(parts, state).await?.into(); let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?; let verified: Token = Hs256 .validate_integrity(&token, &state.jwt_key) .map_err(ApiError::JWTVerify)?; Ok(Self { data: verified.claims().custom.clone(), _marker: Default::default(), }) } }