Sophisticated type magic for JWT extraction #2

Open
shimun wants to merge 2 commits from extract into master
3 changed files with 137 additions and 11 deletions

5
.woodpecker.yml Normal file
View File

@ -0,0 +1,5 @@
pipeline:
test:
image: rust
commands:
- cargo test

View File

@ -10,6 +10,7 @@ use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert};
use crate::env_key;
use anyhow::Context;
use axum::body;
use axum::extract::rejection::QueryRejection;
use axum::extract::{Query, State};
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
@ -29,7 +30,7 @@ use tower::ServiceBuilder;
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
use tracing::{debug, info, trace};
use self::extract::{CertificateBody, SignatureBody};
use self::extract::{AsJWTVerifier, CertificateBody, JWTAuthenticated, JWTString, SignatureBody};
#[derive(Parser)]
pub struct ApiArgs {
@ -86,6 +87,13 @@ struct ApiState {
jwt_key: Hs256Key,
}
impl AsJWTVerifier for ApiState {
type Algo = Hs256;
fn as_secret(&self) -> &<Self::Algo as jwt_compact::Algorithm>::VerifyingKey {
&self.jwt_key
}
}
impl ApiState {
async fn new(
cert_dir: impl AsRef<path::Path>,
@ -177,6 +185,12 @@ pub enum ApiError {
AuthenticationRequired(String),
#[error("invalid ssh signature")]
InvalidSignature,
#[error("invalid jwt")]
JWTVerify(#[from] jwt_compact::ValidationError),
#[error("invalid jwt")]
JWTParse(#[from] jwt_compact::ParseError),
#[error("{0}")]
Query(#[from] QueryRejection),
}
type ApiResult<T> = Result<T, ApiError>;
@ -221,7 +235,7 @@ async fn list_certs(
))
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "aud", rename = "get")]
struct AuthClaims {
identifier: String,
@ -323,22 +337,28 @@ struct PostCertsQuery {
challenge: String,
}
impl Into<JWTString> for Query<PostCertsQuery> {
fn into(self) -> JWTString {
self.0.challenge.into()
}
}
/// POST with signed challenge
async fn post_certs_identifier(
PostCertInfo { identifier }: PostCertInfo,
State(ApiState { certs, jwt_key, .. }): State<ApiState>,
JWTAuthenticated {
data: AuthClaims {
identifier: authenticated_identifier,
},
..
}: JWTAuthenticated<AuthClaims, ApiState, Query<PostCertsQuery>, ApiError>,
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
SignatureBody(sig): SignatureBody,
) -> ApiResult<String> {
let certs = certs.lock().await;
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
let token: Token<AuthClaims> = Hs256
.validate_integrity(
&UntrustedToken::new(&challenge).context("jwt parse")?,
&jwt_key,
)
.map_err(|_| ApiError::InvalidSignature)?;
if token.claims().custom.identifier != identifier {
if authenticated_identifier != identifier {
return Err(ApiError::InvalidSignature);
}
let pubkey: PublicKey = cert.public_key().clone().into();
@ -537,6 +557,9 @@ mod tests {
identifier: "test_cert".into(),
},
State(state.clone()),
JWTAuthenticated::from(AuthClaims {
identifier: "test_cert".into(),
}),
Query(PostCertsQuery { challenge }),
SignatureBody(sig),
)

View File

@ -1,10 +1,22 @@
use super::ApiError;
use anyhow::Context;
use axum::{
async_trait, body::BoxBody, extract::FromRequest, http::Request, response::IntoResponse,
async_trait,
body::BoxBody,
extract::{FromRequest, FromRequestParts},
http::Request,
response::IntoResponse,
};
use jwt_compact::{
alg::{SigningKey, VerifyingKey},
AlgorithmSignature, ParseError, Token, UntrustedToken, ValidationError,
};
use jwt_compact::{Algorithm, AlgorithmExt};
use serde::{de::DeserializeOwned, Serialize};
use ssh_key::{Certificate, SshSig};
use std::marker::PhantomData;
use std::{fmt::Debug, ops::Deref};
use tracing::trace;
use super::ApiError;
#[derive(Debug, Clone)]
pub struct CertificateBody(pub Certificate);
@ -49,3 +61,89 @@ where
Ok(Self(sig))
}
}
pub trait AsJWTVerifier: Send + Sync {
type Algo: Algorithm + Default;
fn as_secret(&self) -> &<Self::Algo as Algorithm>::VerifyingKey;
}
pub struct JWTString(String);
impl From<String> for JWTString {
fn from(s: String) -> Self {
Self(s)
}
}
#[derive(Debug)]
pub struct JWTAuthenticated<T, S, Q, E> {
pub data: T,
_marker: PhantomData<(Q, S, E)>,
}
impl<T, S, Q, E> From<T> for JWTAuthenticated<T, S, Q, E> {
fn from(data: T) -> Self {
Self {
data,
_marker: Default::default(),
}
}
}
impl<T, S, Q, E> Deref for JWTAuthenticated<T, S, Q, E> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
S: AsJWTVerifier,
Q: FromRequestParts<S> + Debug + Into<JWTString>,
E: From<<Q as FromRequestParts<S>>::Rejection>
+ From<ValidationError>
+ From<ParseError>
+ Debug
+ Send
+ Sync,
> JWTAuthenticated<T, S, Q, E>
{
pub fn new(data: T) -> Self {
Self {
data,
_marker: Default::default(),
}
}
}
#[async_trait]
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
S: AsJWTVerifier,
Q: FromRequestParts<S> + Debug + Into<JWTString>,
E: From<<Q as FromRequestParts<S>>::Rejection>
+ From<ValidationError>
+ From<ParseError>
+ Debug
+ Send
+ Sync
+ IntoResponse,
> FromRequestParts<S> for JWTAuthenticated<T, S, Q, E>
{
type Rejection = E;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
let token = UntrustedToken::new(&token)?;
let verified: Token<T> =
<S::Algo as Default>::default().validate_integrity(&token, &state.as_secret())?;
Ok(Self {
data: verified.claims().custom.clone(),
_marker: Default::default(),
})
}
}