Sophisticated type magic for JWT extraction #2
5
.woodpecker.yml
Normal file
5
.woodpecker.yml
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
pipeline:
|
||||||
|
test:
|
||||||
|
image: rust
|
||||||
|
commands:
|
||||||
|
- cargo test
|
41
src/api.rs
41
src/api.rs
@ -10,6 +10,7 @@ use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert};
|
|||||||
use crate::env_key;
|
use crate::env_key;
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::body;
|
use axum::body;
|
||||||
|
use axum::extract::rejection::QueryRejection;
|
||||||
use axum::extract::{Query, State};
|
use axum::extract::{Query, State};
|
||||||
|
|
||||||
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
|
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
|
||||||
@ -29,7 +30,7 @@ use tower::ServiceBuilder;
|
|||||||
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
||||||
use tracing::{debug, info, trace};
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
use self::extract::{CertificateBody, SignatureBody};
|
use self::extract::{AsJWTVerifier, CertificateBody, JWTAuthenticated, JWTString, SignatureBody};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
pub struct ApiArgs {
|
pub struct ApiArgs {
|
||||||
@ -86,6 +87,13 @@ struct ApiState {
|
|||||||
jwt_key: Hs256Key,
|
jwt_key: Hs256Key,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AsJWTVerifier for ApiState {
|
||||||
|
type Algo = Hs256;
|
||||||
|
fn as_secret(&self) -> &<Self::Algo as jwt_compact::Algorithm>::VerifyingKey {
|
||||||
|
&self.jwt_key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ApiState {
|
impl ApiState {
|
||||||
async fn new(
|
async fn new(
|
||||||
cert_dir: impl AsRef<path::Path>,
|
cert_dir: impl AsRef<path::Path>,
|
||||||
@ -177,6 +185,12 @@ pub enum ApiError {
|
|||||||
AuthenticationRequired(String),
|
AuthenticationRequired(String),
|
||||||
#[error("invalid ssh signature")]
|
#[error("invalid ssh signature")]
|
||||||
InvalidSignature,
|
InvalidSignature,
|
||||||
|
#[error("invalid jwt")]
|
||||||
|
JWTVerify(#[from] jwt_compact::ValidationError),
|
||||||
|
#[error("invalid jwt")]
|
||||||
|
JWTParse(#[from] jwt_compact::ParseError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Query(#[from] QueryRejection),
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApiResult<T> = Result<T, ApiError>;
|
type ApiResult<T> = Result<T, ApiError>;
|
||||||
@ -221,7 +235,7 @@ async fn list_certs(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "aud", rename = "get")]
|
#[serde(tag = "aud", rename = "get")]
|
||||||
struct AuthClaims {
|
struct AuthClaims {
|
||||||
identifier: String,
|
identifier: String,
|
||||||
@ -323,22 +337,28 @@ struct PostCertsQuery {
|
|||||||
challenge: String,
|
challenge: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Into<JWTString> for Query<PostCertsQuery> {
|
||||||
|
fn into(self) -> JWTString {
|
||||||
|
self.0.challenge.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// POST with signed challenge
|
/// POST with signed challenge
|
||||||
async fn post_certs_identifier(
|
async fn post_certs_identifier(
|
||||||
PostCertInfo { identifier }: PostCertInfo,
|
PostCertInfo { identifier }: PostCertInfo,
|
||||||
State(ApiState { certs, jwt_key, .. }): State<ApiState>,
|
State(ApiState { certs, jwt_key, .. }): State<ApiState>,
|
||||||
|
JWTAuthenticated {
|
||||||
|
data: AuthClaims {
|
||||||
|
identifier: authenticated_identifier,
|
||||||
|
},
|
||||||
|
..
|
||||||
|
}: JWTAuthenticated<AuthClaims, ApiState, Query<PostCertsQuery>, ApiError>,
|
||||||
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
|
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
|
||||||
SignatureBody(sig): SignatureBody,
|
SignatureBody(sig): SignatureBody,
|
||||||
) -> ApiResult<String> {
|
) -> ApiResult<String> {
|
||||||
let certs = certs.lock().await;
|
let certs = certs.lock().await;
|
||||||
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
|
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
|
||||||
let token: Token<AuthClaims> = Hs256
|
if authenticated_identifier != identifier {
|
||||||
.validate_integrity(
|
|
||||||
&UntrustedToken::new(&challenge).context("jwt parse")?,
|
|
||||||
&jwt_key,
|
|
||||||
)
|
|
||||||
.map_err(|_| ApiError::InvalidSignature)?;
|
|
||||||
if token.claims().custom.identifier != identifier {
|
|
||||||
return Err(ApiError::InvalidSignature);
|
return Err(ApiError::InvalidSignature);
|
||||||
}
|
}
|
||||||
let pubkey: PublicKey = cert.public_key().clone().into();
|
let pubkey: PublicKey = cert.public_key().clone().into();
|
||||||
@ -537,6 +557,9 @@ mod tests {
|
|||||||
identifier: "test_cert".into(),
|
identifier: "test_cert".into(),
|
||||||
},
|
},
|
||||||
State(state.clone()),
|
State(state.clone()),
|
||||||
|
JWTAuthenticated::from(AuthClaims {
|
||||||
|
identifier: "test_cert".into(),
|
||||||
|
}),
|
||||||
Query(PostCertsQuery { challenge }),
|
Query(PostCertsQuery { challenge }),
|
||||||
SignatureBody(sig),
|
SignatureBody(sig),
|
||||||
)
|
)
|
||||||
|
@ -1,10 +1,22 @@
|
|||||||
|
use super::ApiError;
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait, body::BoxBody, extract::FromRequest, http::Request, response::IntoResponse,
|
async_trait,
|
||||||
|
body::BoxBody,
|
||||||
|
extract::{FromRequest, FromRequestParts},
|
||||||
|
http::Request,
|
||||||
|
response::IntoResponse,
|
||||||
};
|
};
|
||||||
|
use jwt_compact::{
|
||||||
|
alg::{SigningKey, VerifyingKey},
|
||||||
|
AlgorithmSignature, ParseError, Token, UntrustedToken, ValidationError,
|
||||||
|
};
|
||||||
|
use jwt_compact::{Algorithm, AlgorithmExt};
|
||||||
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
use ssh_key::{Certificate, SshSig};
|
use ssh_key::{Certificate, SshSig};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
use std::{fmt::Debug, ops::Deref};
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
use super::ApiError;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CertificateBody(pub Certificate);
|
pub struct CertificateBody(pub Certificate);
|
||||||
@ -49,3 +61,89 @@ where
|
|||||||
Ok(Self(sig))
|
Ok(Self(sig))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait AsJWTVerifier: Send + Sync {
|
||||||
|
type Algo: Algorithm + Default;
|
||||||
|
fn as_secret(&self) -> &<Self::Algo as Algorithm>::VerifyingKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct JWTString(String);
|
||||||
|
|
||||||
|
impl From<String> for JWTString {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct JWTAuthenticated<T, S, Q, E> {
|
||||||
|
pub data: T,
|
||||||
|
_marker: PhantomData<(Q, S, E)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S, Q, E> From<T> for JWTAuthenticated<T, S, Q, E> {
|
||||||
|
fn from(data: T) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
_marker: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, S, Q, E> Deref for JWTAuthenticated<T, S, Q, E> {
|
||||||
|
type Target = T;
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
S: AsJWTVerifier,
|
||||||
|
Q: FromRequestParts<S> + Debug + Into<JWTString>,
|
||||||
|
E: From<<Q as FromRequestParts<S>>::Rejection>
|
||||||
|
+ From<ValidationError>
|
||||||
|
+ From<ParseError>
|
||||||
|
+ Debug
|
||||||
|
+ Send
|
||||||
|
+ Sync,
|
||||||
|
> JWTAuthenticated<T, S, Q, E>
|
||||||
|
{
|
||||||
|
pub fn new(data: T) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
_marker: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
S: AsJWTVerifier,
|
||||||
|
Q: FromRequestParts<S> + Debug + Into<JWTString>,
|
||||||
|
E: From<<Q as FromRequestParts<S>>::Rejection>
|
||||||
|
+ From<ValidationError>
|
||||||
|
+ From<ParseError>
|
||||||
|
+ Debug
|
||||||
|
+ Send
|
||||||
|
+ Sync
|
||||||
|
+ IntoResponse,
|
||||||
|
> FromRequestParts<S> for JWTAuthenticated<T, S, Q, E>
|
||||||
|
{
|
||||||
|
type Rejection = E;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut axum::http::request::Parts,
|
||||||
|
state: &S,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
|
||||||
|
let token = UntrustedToken::new(&token)?;
|
||||||
|
let verified: Token<T> =
|
||||||
|
<S::Algo as Default>::default().validate_integrity(&token, &state.as_secret())?;
|
||||||
|
Ok(Self {
|
||||||
|
data: verified.claims().custom.clone(),
|
||||||
|
_marker: Default::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user