wip: JWTAuthenticated

TODO: move into api/

fix: test

chore: move JWTAuthenticated into extract

chore: fmt
This commit is contained in:
shimun 2023-03-10 12:15:48 +01:00
parent dffbcceeba
commit f47c57c1c0
Signed by: shimun
GPG Key ID: E0420647856EA39E
4 changed files with 115 additions and 28 deletions

View File

@ -63,12 +63,8 @@ pub async fn read_pubkey_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<V
while let Some(entry) = dir.next_entry().await? { while let Some(entry) = dir.next_entry().await? {
//TODO: investigate why path().ends_with doesn't work //TODO: investigate why path().ends_with doesn't work
let file_name = entry.file_name().into_string().unwrap(); let file_name = entry.file_name().into_string().unwrap();
if !file_name.ends_with(".pub") || file_name.ends_with("-cert.pub") if !file_name.ends_with(".pub") || file_name.ends_with("-cert.pub") {
{ trace!("skipped {:?} due to missing '.pub' extension", entry.path());
trace!(
"skipped {:?} due to missing '.pub' extension",
entry.path()
);
continue; continue;
} }
let cert = load_public_key(entry.path()).await?; let cert = load_public_key(entry.path()).await?;
@ -77,7 +73,6 @@ pub async fn read_pubkey_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<V
} }
} }
Ok(pubs) Ok(pubs)
} }
fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> { fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> {
@ -158,5 +153,4 @@ pub async fn load_public_key(file: impl AsRef<Path> + Debug) -> anyhow::Result<O
Ok(Some(PublicKey::from_openssh(&string_repr).with_context( Ok(Some(PublicKey::from_openssh(&string_repr).with_context(
|| format!("parse {:?} as openssh public key", &file), || format!("parse {:?} as openssh public key", &file),
)?)) )?))
} }

View File

@ -1,5 +1,3 @@
use axum_extra::routing::TypedPath; use axum_extra::routing::TypedPath;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -23,7 +21,7 @@ pub struct GetCertsPubkey {
#[derive(Debug, Serialize, Deserialize, Default)] #[derive(Debug, Serialize, Deserialize, Default)]
pub struct CertIds { pub struct CertIds {
pub ids: Vec<String> pub ids: Vec<String>,
} }
#[derive(TypedPath, Deserialize)] #[derive(TypedPath, Deserialize)]

View File

@ -1,15 +1,20 @@
mod extract; mod extract;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::{self, PathBuf}; use std::path::{self, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::Context; use anyhow::Context;
use axum::body; use axum::body;
use axum::extract::rejection::QueryRejection;
use axum::extract::{Query, State}; use axum::extract::{Query, State};
use chrono::Duration; use chrono::Duration;
use shell_escape::escape; use shell_escape::escape;
use ssh_cert_dist_common::*; use ssh_cert_dist_common::*;
@ -17,7 +22,7 @@ use axum::{http::StatusCode, response::IntoResponse, Json, Router};
use axum_extra::routing::RouterExt; use axum_extra::routing::RouterExt;
use clap::{Args, Parser}; use clap::{Args, Parser};
use jwt_compact::alg::{Hs256, Hs256Key}; use jwt_compact::alg::{Hs256, Hs256Key};
use jwt_compact::{AlgorithmExt, Token, UntrustedToken}; use jwt_compact::{AlgorithmExt};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ssh_key::{Certificate, Fingerprint, PublicKey}; use ssh_key::{Certificate, Fingerprint, PublicKey};
@ -26,7 +31,7 @@ use tower::ServiceBuilder;
use tower_http::{trace::TraceLayer, ServiceBuilderExt}; use tower_http::{trace::TraceLayer, ServiceBuilderExt};
use tracing::{debug, info, trace}; use tracing::{debug, info, trace};
use self::extract::{CertificateBody, SignatureBody}; use self::extract::{CertificateBody, SignatureBody, JWTAuthenticated, JWTString};
#[derive(Parser)] #[derive(Parser)]
pub struct ApiArgs { pub struct ApiArgs {
@ -74,7 +79,7 @@ impl Default for ApiArgs {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct ApiState { pub struct ApiState {
certs: Arc<Mutex<HashMap<String, Certificate>>>, certs: Arc<Mutex<HashMap<String, Certificate>>>,
cert_dir: PathBuf, cert_dir: PathBuf,
ca: PublicKey, ca: PublicKey,
@ -179,6 +184,12 @@ pub enum ApiError {
ParseSignature(anyhow::Error), ParseSignature(anyhow::Error),
#[error("malformed ssh certificate: {0}")] #[error("malformed ssh certificate: {0}")]
ParseCertificate(anyhow::Error), ParseCertificate(anyhow::Error),
#[error("{0}")]
JWTParse(#[from] jwt_compact::ParseError),
#[error("{0}")]
JWTVerify(#[from] jwt_compact::ValidationError),
#[error("{0}")]
Query(#[from] QueryRejection),
} }
type ApiResult<T> = Result<T, ApiError>; type ApiResult<T> = Result<T, ApiError>;
@ -221,7 +232,7 @@ async fn list_certs(
)) ))
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "aud", rename = "get")] #[serde(tag = "aud", rename = "get")]
struct AuthClaims { struct AuthClaims {
identifier: String, identifier: String,
@ -298,7 +309,10 @@ struct CertInfo {
impl From<&Certificate> for CertInfo { impl From<&Certificate> for CertInfo {
fn from(cert: &Certificate) -> Self { fn from(cert: &Certificate) -> Self {
let validity = cert.valid_before_time().duration_since(cert.valid_after_time()).unwrap_or(Duration::zero().to_std().unwrap()); let validity = cert
.valid_before_time()
.duration_since(cert.valid_after_time())
.unwrap_or(Duration::zero().to_std().unwrap());
let expiry = cert.valid_before_time().checked_add(validity).unwrap(); let expiry = cert.valid_before_time().checked_add(validity).unwrap();
let expiry_date = expiry.duration_since(UNIX_EPOCH).unwrap(); let expiry_date = expiry.duration_since(UNIX_EPOCH).unwrap();
let host_key = if cert.cert_type().is_host() { let host_key = if cert.cert_type().is_host() {
@ -367,22 +381,26 @@ struct PostCertsQuery {
challenge: String, challenge: String,
} }
impl From<Query<PostCertsQuery>> for JWTString {
fn from(Query(PostCertsQuery { challenge }): Query<PostCertsQuery>) -> Self {
Self::from(challenge)
}
}
/// POST with signed challenge /// POST with signed challenge
async fn post_certs_identifier( async fn post_certs_identifier(
PostCertInfo { identifier }: PostCertInfo, PostCertInfo { identifier }: PostCertInfo,
State(ApiState { certs, jwt_key, .. }): State<ApiState>, State(ApiState { certs, .. }): State<ApiState>,
JWTAuthenticated {
data: auth_claims, ..
}: JWTAuthenticated<AuthClaims, Query<PostCertsQuery>>,
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>, Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
SignatureBody(sig): SignatureBody, SignatureBody(sig): SignatureBody,
) -> ApiResult<String> { ) -> ApiResult<String> {
let certs = certs.lock().await; let certs = certs.lock().await;
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?; let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
let token: Token<AuthClaims> = Hs256 if auth_claims.identifier != identifier {
.validate_integrity(
&UntrustedToken::new(&challenge).context("jwt parse")?,
&jwt_key,
)
.map_err(|_| ApiError::InvalidSignature)?;
if token.claims().custom.identifier != identifier {
return Err(ApiError::InvalidSignature); return Err(ApiError::InvalidSignature);
} }
let pubkey: PublicKey = cert.public_key().clone().into(); let pubkey: PublicKey = cert.public_key().clone().into();
@ -613,6 +631,9 @@ mod tests {
identifier: "test_cert".into(), identifier: "test_cert".into(),
}, },
State(state.clone()), State(state.clone()),
JWTAuthenticated::new(AuthClaims {
identifier: "test_cert".into(),
}),
Query(PostCertsQuery { challenge }), Query(PostCertsQuery { challenge }),
SignatureBody(sig), SignatureBody(sig),
) )

View File

@ -1,6 +1,16 @@
use super::ApiError; use std::fmt::Debug;
use std::marker::PhantomData;
use super::{ApiError, ApiState};
use anyhow::Context; use anyhow::Context;
use axum::{async_trait, body::BoxBody, extract::FromRequest, http::Request}; use axum::{
async_trait,
body::BoxBody,
extract::{FromRequest, FromRequestParts},
http::Request,
};
use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken};
use serde::{de::DeserializeOwned, Serialize};
use ssh_key::{Certificate, SshSig}; use ssh_key::{Certificate, SshSig};
use tracing::trace; use tracing::trace;
@ -21,7 +31,8 @@ where
.context("failed to extract body")?; .context("failed to extract body")?;
let cert = Certificate::from_openssh(&body) let cert = Certificate::from_openssh(&body)
.with_context(|| format!("failed to parse '{}'", body)).map_err(ApiError::ParseCertificate)?; .with_context(|| format!("failed to parse '{}'", body))
.map_err(ApiError::ParseCertificate)?;
trace!(%body, "extracted certificate"); trace!(%body, "extracted certificate");
Ok(Self(cert)) Ok(Self(cert))
} }
@ -42,8 +53,71 @@ where
.await .await
.context("failed to extract body")?; .context("failed to extract body")?;
let sig = SshSig::from_pem(&body).with_context(|| format!("failed to parse '{}'", body)).map_err(ApiError::ParseSignature)?; let sig = SshSig::from_pem(&body)
.with_context(|| format!("failed to parse '{}'", body))
.map_err(ApiError::ParseSignature)?;
trace!(%body, "extracted signature"); trace!(%body, "extracted signature");
Ok(Self(sig)) Ok(Self(sig))
} }
} }
pub struct JWTString(String);
impl From<String> for JWTString {
fn from(s: String) -> Self {
Self(s)
}
}
// TODO: be generic over ApiState -> AsRef<Target=Hs256>, AsRef<Target=A> where A: AlgorithmExt
#[derive(Debug)]
pub struct JWTAuthenticated<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
pub data: T,
_marker: PhantomData<Q>,
}
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> JWTAuthenticated<T, Q>
where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
pub fn new(data: T) -> Self {
Self {
data,
_marker: Default::default(),
}
}
}
#[async_trait]
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> FromRequestParts<ApiState> for JWTAuthenticated<T, Q>
where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
type Rejection = ApiError;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &ApiState,
) -> Result<Self, Self::Rejection> {
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?;
let verified: Token<T> = Hs256
.validate_integrity(&token, &state.jwt_key)
.map_err(ApiError::JWTVerify)?;
Ok(Self {
data: verified.claims().custom.clone(),
_marker: Default::default(),
})
}
}