From f47c57c1c072f03027ee7232f78c079f65784dca Mon Sep 17 00:00:00 2001 From: shimun Date: Fri, 10 Mar 2023 12:15:48 +0100 Subject: [PATCH] wip: JWTAuthenticated TODO: move into api/ fix: test chore: move JWTAuthenticated into extract chore: fmt --- common/src/certs.rs | 10 +---- common/src/routes.rs | 4 +- server/src/api.rs | 47 +++++++++++++++------- server/src/api/extract.rs | 82 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 115 insertions(+), 28 deletions(-) diff --git a/common/src/certs.rs b/common/src/certs.rs index 5dd761c..b0a4575 100644 --- a/common/src/certs.rs +++ b/common/src/certs.rs @@ -63,12 +63,8 @@ pub async fn read_pubkey_dir(path: impl AsRef + Debug) -> anyhow::Result + Debug) -> anyhow::Result) -> anyhow::Result { @@ -158,5 +153,4 @@ pub async fn load_public_key(file: impl AsRef + Debug) -> anyhow::Result + pub ids: Vec, } #[derive(TypedPath, Deserialize)] diff --git a/server/src/api.rs b/server/src/api.rs index f4423ab..576633b 100644 --- a/server/src/api.rs +++ b/server/src/api.rs @@ -1,15 +1,20 @@ mod extract; use std::collections::HashMap; +use std::fmt::Debug; + use std::net::SocketAddr; use std::path::{self, PathBuf}; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use anyhow::Context; + use axum::body; +use axum::extract::rejection::QueryRejection; use axum::extract::{Query, State}; use chrono::Duration; + use shell_escape::escape; use ssh_cert_dist_common::*; @@ -17,7 +22,7 @@ use axum::{http::StatusCode, response::IntoResponse, Json, Router}; use axum_extra::routing::RouterExt; use clap::{Args, Parser}; use jwt_compact::alg::{Hs256, Hs256Key}; -use jwt_compact::{AlgorithmExt, Token, UntrustedToken}; +use jwt_compact::{AlgorithmExt}; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use ssh_key::{Certificate, Fingerprint, PublicKey}; @@ -26,7 +31,7 @@ use tower::ServiceBuilder; use tower_http::{trace::TraceLayer, ServiceBuilderExt}; use tracing::{debug, info, trace}; -use self::extract::{CertificateBody, SignatureBody}; +use self::extract::{CertificateBody, SignatureBody, JWTAuthenticated, JWTString}; #[derive(Parser)] pub struct ApiArgs { @@ -74,7 +79,7 @@ impl Default for ApiArgs { } #[derive(Debug, Clone)] -struct ApiState { +pub struct ApiState { certs: Arc>>, cert_dir: PathBuf, ca: PublicKey, @@ -179,6 +184,12 @@ pub enum ApiError { ParseSignature(anyhow::Error), #[error("malformed ssh certificate: {0}")] ParseCertificate(anyhow::Error), + #[error("{0}")] + JWTParse(#[from] jwt_compact::ParseError), + #[error("{0}")] + JWTVerify(#[from] jwt_compact::ValidationError), + #[error("{0}")] + Query(#[from] QueryRejection), } type ApiResult = Result; @@ -221,7 +232,7 @@ async fn list_certs( )) } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "aud", rename = "get")] struct AuthClaims { identifier: String, @@ -298,7 +309,10 @@ struct CertInfo { impl From<&Certificate> for CertInfo { fn from(cert: &Certificate) -> Self { - let validity = cert.valid_before_time().duration_since(cert.valid_after_time()).unwrap_or(Duration::zero().to_std().unwrap()); + let validity = cert + .valid_before_time() + .duration_since(cert.valid_after_time()) + .unwrap_or(Duration::zero().to_std().unwrap()); let expiry = cert.valid_before_time().checked_add(validity).unwrap(); let expiry_date = expiry.duration_since(UNIX_EPOCH).unwrap(); let host_key = if cert.cert_type().is_host() { @@ -367,22 +381,26 @@ struct PostCertsQuery { challenge: String, } +impl From> for JWTString { + fn from(Query(PostCertsQuery { challenge }): Query) -> Self { + Self::from(challenge) + } +} + + /// POST with signed challenge async fn post_certs_identifier( PostCertInfo { identifier }: PostCertInfo, - State(ApiState { certs, jwt_key, .. }): State, + State(ApiState { certs, .. }): State, + JWTAuthenticated { + data: auth_claims, .. + }: JWTAuthenticated>, Query(PostCertsQuery { challenge }): Query, SignatureBody(sig): SignatureBody, ) -> ApiResult { let certs = certs.lock().await; let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?; - let token: Token = Hs256 - .validate_integrity( - &UntrustedToken::new(&challenge).context("jwt parse")?, - &jwt_key, - ) - .map_err(|_| ApiError::InvalidSignature)?; - if token.claims().custom.identifier != identifier { + if auth_claims.identifier != identifier { return Err(ApiError::InvalidSignature); } let pubkey: PublicKey = cert.public_key().clone().into(); @@ -613,6 +631,9 @@ mod tests { identifier: "test_cert".into(), }, State(state.clone()), + JWTAuthenticated::new(AuthClaims { + identifier: "test_cert".into(), + }), Query(PostCertsQuery { challenge }), SignatureBody(sig), ) diff --git a/server/src/api/extract.rs b/server/src/api/extract.rs index 8e1731f..4ea447e 100644 --- a/server/src/api/extract.rs +++ b/server/src/api/extract.rs @@ -1,6 +1,16 @@ -use super::ApiError; +use std::fmt::Debug; +use std::marker::PhantomData; + +use super::{ApiError, ApiState}; use anyhow::Context; -use axum::{async_trait, body::BoxBody, extract::FromRequest, http::Request}; +use axum::{ + async_trait, + body::BoxBody, + extract::{FromRequest, FromRequestParts}, + http::Request, +}; +use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken}; +use serde::{de::DeserializeOwned, Serialize}; use ssh_key::{Certificate, SshSig}; use tracing::trace; @@ -21,7 +31,8 @@ where .context("failed to extract body")?; let cert = Certificate::from_openssh(&body) - .with_context(|| format!("failed to parse '{}'", body)).map_err(ApiError::ParseCertificate)?; + .with_context(|| format!("failed to parse '{}'", body)) + .map_err(ApiError::ParseCertificate)?; trace!(%body, "extracted certificate"); Ok(Self(cert)) } @@ -42,8 +53,71 @@ where .await .context("failed to extract body")?; - let sig = SshSig::from_pem(&body).with_context(|| format!("failed to parse '{}'", body)).map_err(ApiError::ParseSignature)?; + let sig = SshSig::from_pem(&body) + .with_context(|| format!("failed to parse '{}'", body)) + .map_err(ApiError::ParseSignature)?; trace!(%body, "extracted signature"); Ok(Self(sig)) } } + +pub struct JWTString(String); + +impl From for JWTString { + fn from(s: String) -> Self { + Self(s) + } +} + +// TODO: be generic over ApiState -> AsRef, AsRef where A: AlgorithmExt +#[derive(Debug)] +pub struct JWTAuthenticated< + T: Serialize + DeserializeOwned + Clone + Debug, + Q: FromRequestParts + Debug + Into, +> where + ApiError: From<>::Rejection>, +{ + pub data: T, + _marker: PhantomData, +} + +impl< + T: Serialize + DeserializeOwned + Clone + Debug, + Q: FromRequestParts + Debug + Into, + > JWTAuthenticated +where + ApiError: From<>::Rejection>, +{ + pub fn new(data: T) -> Self { + Self { + data, + _marker: Default::default(), + } + } +} + +#[async_trait] +impl< + T: Serialize + DeserializeOwned + Clone + Debug, + Q: FromRequestParts + Debug + Into, + > FromRequestParts for JWTAuthenticated +where + ApiError: From<>::Rejection>, +{ + type Rejection = ApiError; + + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &ApiState, + ) -> Result { + let JWTString(token) = Q::from_request_parts(parts, state).await?.into(); + let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?; + let verified: Token = Hs256 + .validate_integrity(&token, &state.jwt_key) + .map_err(ApiError::JWTVerify)?; + Ok(Self { + data: verified.claims().custom.clone(), + _marker: Default::default(), + }) + } +}