TODO: move into api/ fix: test chore: move JWTAuthenticated into extract chore: fmt
124 lines
3.4 KiB
Rust
124 lines
3.4 KiB
Rust
use std::fmt::Debug;
|
|
use std::marker::PhantomData;
|
|
|
|
use super::{ApiError, ApiState};
|
|
use anyhow::Context;
|
|
use axum::{
|
|
async_trait,
|
|
body::BoxBody,
|
|
extract::{FromRequest, FromRequestParts},
|
|
http::Request,
|
|
};
|
|
use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken};
|
|
use serde::{de::DeserializeOwned, Serialize};
|
|
use ssh_key::{Certificate, SshSig};
|
|
use tracing::trace;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct CertificateBody(pub Certificate);
|
|
|
|
// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body
|
|
#[async_trait]
|
|
impl<S> FromRequest<S, BoxBody> for CertificateBody
|
|
where
|
|
S: Send + Sync,
|
|
{
|
|
type Rejection = ApiError;
|
|
|
|
async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
|
|
let body = String::from_request(req, state)
|
|
.await
|
|
.context("failed to extract body")?;
|
|
|
|
let cert = Certificate::from_openssh(&body)
|
|
.with_context(|| format!("failed to parse '{}'", body))
|
|
.map_err(ApiError::ParseCertificate)?;
|
|
trace!(%body, "extracted certificate");
|
|
Ok(Self(cert))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct SignatureBody(pub SshSig);
|
|
|
|
#[async_trait]
|
|
impl<S> FromRequest<S, BoxBody> for SignatureBody
|
|
where
|
|
S: Send + Sync,
|
|
{
|
|
type Rejection = ApiError;
|
|
|
|
async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
|
|
let body = String::from_request(req, state)
|
|
.await
|
|
.context("failed to extract body")?;
|
|
|
|
let sig = SshSig::from_pem(&body)
|
|
.with_context(|| format!("failed to parse '{}'", body))
|
|
.map_err(ApiError::ParseSignature)?;
|
|
trace!(%body, "extracted signature");
|
|
Ok(Self(sig))
|
|
}
|
|
}
|
|
|
|
/// Raw, not-yet-verified JWT text taken from a request.
///
/// The token-locating extractor used by `JWTAuthenticated` converts into this
/// type; the string is then parsed and verified before any claim is trusted.
/// `Debug` is intentionally not derived so the token is not easily logged.
pub struct JWTString(String);

impl From<String> for JWTString {
    /// Wraps the string as-is; no validation happens at this point.
    fn from(raw: String) -> Self {
        JWTString(raw)
    }
}
|
|
|
|
// TODO: be generic over ApiState -> AsRef<Target=Hs256>, AsRef<Target=A> where A: AlgorithmExt
/// Extractor yielding the custom claims `T` of a verified JWT.
///
/// `Q` is a type-level marker for the inner extractor that locates the raw
/// token in the request; no `Q` value is ever stored.
///
/// Trait bounds are deliberately kept off the type definition and placed only
/// on the `FromRequestParts<ApiState>` impl that needs them, so generic code
/// that merely names or constructs this type does not have to repeat them.
#[derive(Debug)]
pub struct JWTAuthenticated<T, Q> {
    /// Custom claims decoded from the token.
    pub data: T,
    _marker: PhantomData<Q>,
}

impl<T, Q> JWTAuthenticated<T, Q> {
    /// Builds the value directly from `data`; no token is parsed or verified.
    pub fn new(data: T) -> Self {
        Self {
            data,
            _marker: PhantomData,
        }
    }
}
|
|
|
|
#[async_trait]
|
|
impl<
|
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
|
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
|
|
> FromRequestParts<ApiState> for JWTAuthenticated<T, Q>
|
|
where
|
|
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
|
|
{
|
|
type Rejection = ApiError;
|
|
|
|
async fn from_request_parts(
|
|
parts: &mut axum::http::request::Parts,
|
|
state: &ApiState,
|
|
) -> Result<Self, Self::Rejection> {
|
|
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
|
|
let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?;
|
|
let verified: Token<T> = Hs256
|
|
.validate_integrity(&token, &state.jwt_key)
|
|
.map_err(ApiError::JWTVerify)?;
|
|
Ok(Self {
|
|
data: verified.claims().custom.clone(),
|
|
_marker: Default::default(),
|
|
})
|
|
}
|
|
}
|