shimun f47c57c1c0
wip: JWTAuthenticated
TODO: move into api/

fix: test

chore: move JWTAuthenticated into extract

chore: fmt
2023-03-12 16:25:17 +01:00

124 lines
3.4 KiB
Rust

use std::fmt::Debug;
use std::marker::PhantomData;
use super::{ApiError, ApiState};
use anyhow::Context;
use axum::{
async_trait,
body::BoxBody,
extract::{FromRequest, FromRequestParts},
http::Request,
};
use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken};
use serde::{de::DeserializeOwned, Serialize};
use ssh_key::{Certificate, SshSig};
use tracing::trace;
/// An OpenSSH certificate parsed from the plain-text request body.
#[derive(Clone, Debug)]
pub struct CertificateBody(pub Certificate);

// Parsing the certificate consumes the request body. Only `FromRequest` gets the body;
// `FromRequestParts` never sees it, so this must be a `FromRequest` impl.
#[async_trait]
impl<S> FromRequest<S, BoxBody> for CertificateBody
where
    S: Send + Sync,
{
    type Rejection = ApiError;

    /// Buffers the body as a `String` and parses it with `Certificate::from_openssh`.
    ///
    /// # Errors
    /// Fails if the body cannot be read as a string (converted into `ApiError` from the
    /// `anyhow` error). Fails with `ApiError::ParseCertificate` if the text is not a valid
    /// OpenSSH certificate; that error's context includes the raw body.
    /// NOTE(review): the body is untrusted input — confirm `ApiError` does not echo this
    /// context verbatim to clients.
    async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
        let raw = String::from_request(req, state)
            .await
            .context("failed to extract body")?;
        let parsed = Certificate::from_openssh(&raw)
            .with_context(|| format!("failed to parse '{}'", raw));
        let certificate = match parsed {
            Ok(certificate) => certificate,
            Err(err) => return Err(ApiError::ParseCertificate(err)),
        };
        trace!(body = %raw, "extracted certificate");
        Ok(CertificateBody(certificate))
    }
}
/// An SSH signature (`SshSig`) parsed from a PEM-encoded request body.
#[derive(Clone, Debug)]
pub struct SignatureBody(pub SshSig);

// Implemented via `FromRequest` because the signature is read from (and consumes) the body.
#[async_trait]
impl<S> FromRequest<S, BoxBody> for SignatureBody
where
    S: Send + Sync,
{
    type Rejection = ApiError;

    /// Buffers the body as a `String` and parses it with `SshSig::from_pem`.
    ///
    /// # Errors
    /// Fails if the body cannot be read as a string (converted into `ApiError` from the
    /// `anyhow` error). Fails with `ApiError::ParseSignature` if the text is not a valid PEM
    /// signature; that error's context includes the raw body.
    /// NOTE(review): the body is untrusted input — confirm `ApiError` does not echo this
    /// context verbatim to clients.
    async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
        let raw = String::from_request(req, state)
            .await
            .context("failed to extract body")?;
        let signature = match SshSig::from_pem(&raw)
            .with_context(|| format!("failed to parse '{}'", raw))
        {
            Ok(signature) => signature,
            Err(err) => return Err(ApiError::ParseSignature(err)),
        };
        trace!(body = %raw, "extracted signature");
        Ok(SignatureBody(signature))
    }
}
/// A raw, still-unverified JWT as located in a request.
///
/// Token-locating extractors convert into this type (`Into<JWTString>`), which lets
/// `JWTAuthenticated` read the token string out of them.
pub struct JWTString(String);

impl From<String> for JWTString {
    fn from(raw: String) -> Self {
        JWTString(raw)
    }
}
// TODO: be generic over ApiState -> AsRef<Target=Hs256>, AsRef<Target=A> where A: AlgorithmExt
/// Extractor that yields the custom claims `T` of a JWT verified against the server's key.
///
/// `Q` is the extractor that locates the raw token in the request. Its `impl`s require it
/// to convert into `JWTString`. It is tracked only at the type level through `PhantomData`.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct. Naming this
/// type (e.g. in a handler signature) therefore does not require restating them, and the
/// derived `Debug` only needs `T: Debug, Q: Debug`.
#[derive(Debug)]
pub struct JWTAuthenticated<T, Q> {
    /// Custom claims decoded from the verified token.
    pub data: T,
    _marker: PhantomData<Q>,
}
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> JWTAuthenticated<T, Q>
where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
pub fn new(data: T) -> Self {
Self {
data,
_marker: Default::default(),
}
}
}
#[async_trait]
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> FromRequestParts<ApiState> for JWTAuthenticated<T, Q>
where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
type Rejection = ApiError;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &ApiState,
) -> Result<Self, Self::Rejection> {
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?;
let verified: Token<T> = Hs256
.validate_integrity(&token, &state.jwt_key)
.map_err(ApiError::JWTVerify)?;
Ok(Self {
data: verified.claims().custom.clone(),
_marker: Default::default(),
})
}
}