wip: JWTAuthenticated
TODO: move into api/ fix: test chore: move JWTAuthenticated into extract chore: fmt
This commit is contained in:
parent
dffbcceeba
commit
f47c57c1c0
@ -63,12 +63,8 @@ pub async fn read_pubkey_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<V
|
|||||||
while let Some(entry) = dir.next_entry().await? {
|
while let Some(entry) = dir.next_entry().await? {
|
||||||
//TODO: investigate why path().ends_with doesn't work
|
//TODO: investigate why path().ends_with doesn't work
|
||||||
let file_name = entry.file_name().into_string().unwrap();
|
let file_name = entry.file_name().into_string().unwrap();
|
||||||
if !file_name.ends_with(".pub") || file_name.ends_with("-cert.pub")
|
if !file_name.ends_with(".pub") || file_name.ends_with("-cert.pub") {
|
||||||
{
|
trace!("skipped {:?} due to missing '.pub' extension", entry.path());
|
||||||
trace!(
|
|
||||||
"skipped {:?} due to missing '.pub' extension",
|
|
||||||
entry.path()
|
|
||||||
);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let cert = load_public_key(entry.path()).await?;
|
let cert = load_public_key(entry.path()).await?;
|
||||||
@ -77,7 +73,6 @@ pub async fn read_pubkey_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<V
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(pubs)
|
Ok(pubs)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> {
|
fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> {
|
||||||
@ -158,5 +153,4 @@ pub async fn load_public_key(file: impl AsRef<Path> + Debug) -> anyhow::Result<O
|
|||||||
Ok(Some(PublicKey::from_openssh(&string_repr).with_context(
|
Ok(Some(PublicKey::from_openssh(&string_repr).with_context(
|
||||||
|| format!("parse {:?} as openssh public key", &file),
|
|| format!("parse {:?} as openssh public key", &file),
|
||||||
)?))
|
)?))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
|
|
||||||
|
|
||||||
use axum_extra::routing::TypedPath;
|
use axum_extra::routing::TypedPath;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@ -23,7 +21,7 @@ pub struct GetCertsPubkey {
|
|||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||||
pub struct CertIds {
|
pub struct CertIds {
|
||||||
pub ids: Vec<String>
|
pub ids: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(TypedPath, Deserialize)]
|
#[derive(TypedPath, Deserialize)]
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
mod extract;
|
mod extract;
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::path::{self, PathBuf};
|
use std::path::{self, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
|
||||||
use axum::body;
|
use axum::body;
|
||||||
|
use axum::extract::rejection::QueryRejection;
|
||||||
use axum::extract::{Query, State};
|
use axum::extract::{Query, State};
|
||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
|
|
||||||
use shell_escape::escape;
|
use shell_escape::escape;
|
||||||
use ssh_cert_dist_common::*;
|
use ssh_cert_dist_common::*;
|
||||||
|
|
||||||
@ -17,7 +22,7 @@ use axum::{http::StatusCode, response::IntoResponse, Json, Router};
|
|||||||
use axum_extra::routing::RouterExt;
|
use axum_extra::routing::RouterExt;
|
||||||
use clap::{Args, Parser};
|
use clap::{Args, Parser};
|
||||||
use jwt_compact::alg::{Hs256, Hs256Key};
|
use jwt_compact::alg::{Hs256, Hs256Key};
|
||||||
use jwt_compact::{AlgorithmExt, Token, UntrustedToken};
|
use jwt_compact::{AlgorithmExt};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use ssh_key::{Certificate, Fingerprint, PublicKey};
|
use ssh_key::{Certificate, Fingerprint, PublicKey};
|
||||||
@ -26,7 +31,7 @@ use tower::ServiceBuilder;
|
|||||||
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
||||||
use tracing::{debug, info, trace};
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
use self::extract::{CertificateBody, SignatureBody};
|
use self::extract::{CertificateBody, SignatureBody, JWTAuthenticated, JWTString};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
pub struct ApiArgs {
|
pub struct ApiArgs {
|
||||||
@ -74,7 +79,7 @@ impl Default for ApiArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct ApiState {
|
pub struct ApiState {
|
||||||
certs: Arc<Mutex<HashMap<String, Certificate>>>,
|
certs: Arc<Mutex<HashMap<String, Certificate>>>,
|
||||||
cert_dir: PathBuf,
|
cert_dir: PathBuf,
|
||||||
ca: PublicKey,
|
ca: PublicKey,
|
||||||
@ -179,6 +184,12 @@ pub enum ApiError {
|
|||||||
ParseSignature(anyhow::Error),
|
ParseSignature(anyhow::Error),
|
||||||
#[error("malformed ssh certificate: {0}")]
|
#[error("malformed ssh certificate: {0}")]
|
||||||
ParseCertificate(anyhow::Error),
|
ParseCertificate(anyhow::Error),
|
||||||
|
#[error("{0}")]
|
||||||
|
JWTParse(#[from] jwt_compact::ParseError),
|
||||||
|
#[error("{0}")]
|
||||||
|
JWTVerify(#[from] jwt_compact::ValidationError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Query(#[from] QueryRejection),
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApiResult<T> = Result<T, ApiError>;
|
type ApiResult<T> = Result<T, ApiError>;
|
||||||
@ -221,7 +232,7 @@ async fn list_certs(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "aud", rename = "get")]
|
#[serde(tag = "aud", rename = "get")]
|
||||||
struct AuthClaims {
|
struct AuthClaims {
|
||||||
identifier: String,
|
identifier: String,
|
||||||
@ -298,7 +309,10 @@ struct CertInfo {
|
|||||||
|
|
||||||
impl From<&Certificate> for CertInfo {
|
impl From<&Certificate> for CertInfo {
|
||||||
fn from(cert: &Certificate) -> Self {
|
fn from(cert: &Certificate) -> Self {
|
||||||
let validity = cert.valid_before_time().duration_since(cert.valid_after_time()).unwrap_or(Duration::zero().to_std().unwrap());
|
let validity = cert
|
||||||
|
.valid_before_time()
|
||||||
|
.duration_since(cert.valid_after_time())
|
||||||
|
.unwrap_or(Duration::zero().to_std().unwrap());
|
||||||
let expiry = cert.valid_before_time().checked_add(validity).unwrap();
|
let expiry = cert.valid_before_time().checked_add(validity).unwrap();
|
||||||
let expiry_date = expiry.duration_since(UNIX_EPOCH).unwrap();
|
let expiry_date = expiry.duration_since(UNIX_EPOCH).unwrap();
|
||||||
let host_key = if cert.cert_type().is_host() {
|
let host_key = if cert.cert_type().is_host() {
|
||||||
@ -367,22 +381,26 @@ struct PostCertsQuery {
|
|||||||
challenge: String,
|
challenge: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<Query<PostCertsQuery>> for JWTString {
|
||||||
|
fn from(Query(PostCertsQuery { challenge }): Query<PostCertsQuery>) -> Self {
|
||||||
|
Self::from(challenge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/// POST with signed challenge
|
/// POST with signed challenge
|
||||||
async fn post_certs_identifier(
|
async fn post_certs_identifier(
|
||||||
PostCertInfo { identifier }: PostCertInfo,
|
PostCertInfo { identifier }: PostCertInfo,
|
||||||
State(ApiState { certs, jwt_key, .. }): State<ApiState>,
|
State(ApiState { certs, .. }): State<ApiState>,
|
||||||
|
JWTAuthenticated {
|
||||||
|
data: auth_claims, ..
|
||||||
|
}: JWTAuthenticated<AuthClaims, Query<PostCertsQuery>>,
|
||||||
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
|
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
|
||||||
SignatureBody(sig): SignatureBody,
|
SignatureBody(sig): SignatureBody,
|
||||||
) -> ApiResult<String> {
|
) -> ApiResult<String> {
|
||||||
let certs = certs.lock().await;
|
let certs = certs.lock().await;
|
||||||
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
|
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
|
||||||
let token: Token<AuthClaims> = Hs256
|
if auth_claims.identifier != identifier {
|
||||||
.validate_integrity(
|
|
||||||
&UntrustedToken::new(&challenge).context("jwt parse")?,
|
|
||||||
&jwt_key,
|
|
||||||
)
|
|
||||||
.map_err(|_| ApiError::InvalidSignature)?;
|
|
||||||
if token.claims().custom.identifier != identifier {
|
|
||||||
return Err(ApiError::InvalidSignature);
|
return Err(ApiError::InvalidSignature);
|
||||||
}
|
}
|
||||||
let pubkey: PublicKey = cert.public_key().clone().into();
|
let pubkey: PublicKey = cert.public_key().clone().into();
|
||||||
@ -613,6 +631,9 @@ mod tests {
|
|||||||
identifier: "test_cert".into(),
|
identifier: "test_cert".into(),
|
||||||
},
|
},
|
||||||
State(state.clone()),
|
State(state.clone()),
|
||||||
|
JWTAuthenticated::new(AuthClaims {
|
||||||
|
identifier: "test_cert".into(),
|
||||||
|
}),
|
||||||
Query(PostCertsQuery { challenge }),
|
Query(PostCertsQuery { challenge }),
|
||||||
SignatureBody(sig),
|
SignatureBody(sig),
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,16 @@
|
|||||||
use super::ApiError;
|
use std::fmt::Debug;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use super::{ApiError, ApiState};
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::{async_trait, body::BoxBody, extract::FromRequest, http::Request};
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
body::BoxBody,
|
||||||
|
extract::{FromRequest, FromRequestParts},
|
||||||
|
http::Request,
|
||||||
|
};
|
||||||
|
use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken};
|
||||||
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
use ssh_key::{Certificate, SshSig};
|
use ssh_key::{Certificate, SshSig};
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
|
|
||||||
@ -21,7 +31,8 @@ where
|
|||||||
.context("failed to extract body")?;
|
.context("failed to extract body")?;
|
||||||
|
|
||||||
let cert = Certificate::from_openssh(&body)
|
let cert = Certificate::from_openssh(&body)
|
||||||
.with_context(|| format!("failed to parse '{}'", body)).map_err(ApiError::ParseCertificate)?;
|
.with_context(|| format!("failed to parse '{}'", body))
|
||||||
|
.map_err(ApiError::ParseCertificate)?;
|
||||||
trace!(%body, "extracted certificate");
|
trace!(%body, "extracted certificate");
|
||||||
Ok(Self(cert))
|
Ok(Self(cert))
|
||||||
}
|
}
|
||||||
@ -42,8 +53,71 @@ where
|
|||||||
.await
|
.await
|
||||||
.context("failed to extract body")?;
|
.context("failed to extract body")?;
|
||||||
|
|
||||||
let sig = SshSig::from_pem(&body).with_context(|| format!("failed to parse '{}'", body)).map_err(ApiError::ParseSignature)?;
|
let sig = SshSig::from_pem(&body)
|
||||||
|
.with_context(|| format!("failed to parse '{}'", body))
|
||||||
|
.map_err(ApiError::ParseSignature)?;
|
||||||
trace!(%body, "extracted signature");
|
trace!(%body, "extracted signature");
|
||||||
Ok(Self(sig))
|
Ok(Self(sig))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct JWTString(String);
|
||||||
|
|
||||||
|
impl From<String> for JWTString {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: be generic over ApiState -> AsRef<Target=Hs256>, AsRef<Target=A> where A: AlgorithmExt
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct JWTAuthenticated<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
|
||||||
|
> where
|
||||||
|
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
|
||||||
|
{
|
||||||
|
pub data: T,
|
||||||
|
_marker: PhantomData<Q>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
|
||||||
|
> JWTAuthenticated<T, Q>
|
||||||
|
where
|
||||||
|
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
|
||||||
|
{
|
||||||
|
pub fn new(data: T) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
_marker: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
|
||||||
|
> FromRequestParts<ApiState> for JWTAuthenticated<T, Q>
|
||||||
|
where
|
||||||
|
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
|
||||||
|
{
|
||||||
|
type Rejection = ApiError;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut axum::http::request::Parts,
|
||||||
|
state: &ApiState,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
|
||||||
|
let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?;
|
||||||
|
let verified: Token<T> = Hs256
|
||||||
|
.validate_integrity(&token, &state.jwt_key)
|
||||||
|
.map_err(ApiError::JWTVerify)?;
|
||||||
|
Ok(Self {
|
||||||
|
data: verified.claims().custom.clone(),
|
||||||
|
_marker: Default::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user