573 lines
17 KiB
Rust
573 lines
17 KiB
Rust
mod extract;
|
|
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::path::{self, PathBuf};
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, SystemTime};
|
|
|
|
use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert};
|
|
use crate::env_key;
|
|
use anyhow::Context;
|
|
use axum::body;
|
|
use axum::extract::rejection::QueryRejection;
|
|
use axum::extract::{Query, State};
|
|
|
|
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
|
|
use axum_extra::routing::{
|
|
RouterExt, // for `Router::typed_*`
|
|
TypedPath,
|
|
};
|
|
use clap::{Args, Parser};
|
|
use jwt_compact::alg::{Hs256, Hs256Key};
|
|
use jwt_compact::{AlgorithmExt, Token, UntrustedToken};
|
|
use rand::{thread_rng, Rng};
|
|
use serde::{Deserialize, Serialize};
|
|
use ssh_key::private::Ed25519Keypair;
|
|
use ssh_key::{certificate, Certificate, PrivateKey, PublicKey};
|
|
use tokio::sync::Mutex;
|
|
use tower::ServiceBuilder;
|
|
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
|
use tracing::{debug, info, trace};
|
|
|
|
use self::extract::{AsJWTVerifier, CertificateBody, JWTAuthenticated, JWTString, SignatureBody};
|
|
|
|
#[derive(Parser)]
|
|
pub struct ApiArgs {
|
|
#[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))]
|
|
address: SocketAddr,
|
|
#[clap(short = 'c', long = "cert-store", env = env_key!("CERT_DIR"))]
|
|
cert_dir: PathBuf,
|
|
#[clap(flatten)]
|
|
validation_args: CertificateValidationArgs,
|
|
/// CA public key
|
|
#[clap(long = "ca", env = env_key!("CA"))]
|
|
ca: PathBuf,
|
|
}
|
|
|
|
#[derive(Debug, Args, Copy, Clone)]
|
|
pub struct CertificateValidationArgs {
|
|
/// Check whether an certificate update contains an greater serial number than the existing
|
|
/// certificate
|
|
#[clap(short = 'e', long = "validate-expiry", env = env_key!("VALIDATE_EXPIRY"))]
|
|
validate_greater_expiry: bool,
|
|
/// Check whether an certificate update contains an expiry date further in the future than the
|
|
/// existing certificate
|
|
#[clap(short = 's', long = "validate-serial", env = env_key!("VALIDATE_SERIAL"))]
|
|
validate_greater_serial: bool,
|
|
}
|
|
|
|
impl Default for CertificateValidationArgs {
|
|
fn default() -> Self {
|
|
Self {
|
|
validate_greater_expiry: true,
|
|
validate_greater_serial: false,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for ApiArgs {
|
|
fn default() -> Self {
|
|
Self {
|
|
address: SocketAddr::from(([127, 0, 0, 1], 3000)),
|
|
cert_dir: "certs".into(),
|
|
ca: "certs/ca.pub".into(),
|
|
validation_args: Default::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
struct ApiState {
|
|
certs: Arc<Mutex<HashMap<String, Certificate>>>,
|
|
cert_dir: PathBuf,
|
|
ca: PublicKey,
|
|
client_auth: bool,
|
|
validation_args: CertificateValidationArgs,
|
|
jwt_key: Hs256Key,
|
|
}
|
|
|
|
impl AsJWTVerifier for ApiState {
|
|
type Algo = Hs256;
|
|
fn as_secret(&self) -> &<Self::Algo as jwt_compact::Algorithm>::VerifyingKey {
|
|
&self.jwt_key
|
|
}
|
|
}
|
|
|
|
impl ApiState {
|
|
async fn new(
|
|
cert_dir: impl AsRef<path::Path>,
|
|
ca_file: impl AsRef<path::Path>,
|
|
validation_args: CertificateValidationArgs,
|
|
) -> anyhow::Result<Self> {
|
|
let ca = read_pubkey(ca_file.as_ref()).await?;
|
|
let certs = read_certs(&ca, cert_dir.as_ref()).await?;
|
|
Ok(Self {
|
|
certs: Arc::new(Mutex::new(
|
|
certs
|
|
.into_iter()
|
|
.map(|cert| (cert.key_id().to_string(), cert))
|
|
.collect(),
|
|
)),
|
|
cert_dir: cert_dir.as_ref().into(),
|
|
ca,
|
|
client_auth: false,
|
|
validation_args,
|
|
jwt_key: Hs256Key::new(thread_rng().gen::<[u8; 16]>()),
|
|
})
|
|
}
|
|
}
|
|
|
|
pub async fn run(
|
|
ApiArgs {
|
|
address,
|
|
cert_dir,
|
|
ca,
|
|
validation_args,
|
|
}: ApiArgs,
|
|
) -> anyhow::Result<()> {
|
|
let state = ApiState::new(&cert_dir, &ca, validation_args).await?;
|
|
|
|
#[cfg(feature = "reload")]
|
|
{
|
|
let state = state.clone();
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
|
if let Ok(certs) = read_certs(&state.ca, &state.cert_dir).await {
|
|
*state.certs.lock().await = certs
|
|
.into_iter()
|
|
.map(|cert| (cert.key_id().to_string(), cert))
|
|
.collect();
|
|
trace!("reloaded certs");
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
let app = Router::new()
|
|
.typed_get(get_certs_identifier)
|
|
.typed_put(put_cert_update)
|
|
.typed_get(get_cert_info)
|
|
.typed_post(post_certs_identifier);
|
|
#[cfg(feature = "index")]
|
|
let app = app.typed_get(list_certs);
|
|
let app = app
|
|
.fallback(fallback_404)
|
|
.layer(ServiceBuilder::new().map_request_body(body::boxed))
|
|
.layer(TraceLayer::new_for_http())
|
|
.with_state(state);
|
|
|
|
// run our app with hyper
|
|
// `axum::Server` is a re-export of `hyper::Server`
|
|
debug!("listening on {}", address);
|
|
axum::Server::bind(&address)
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum ApiError {
|
|
#[error("internal server error")]
|
|
Internal(#[from] anyhow::Error),
|
|
#[error("certificate not found")]
|
|
CertificateNotFound,
|
|
#[error("invalid certificate")]
|
|
CertificateInvalid,
|
|
#[error("serial must be greater than {1}")]
|
|
LowSerial(u64, u64),
|
|
#[error("expiry date must be greater than {0:?}")]
|
|
InsufficientValidity(SystemTime),
|
|
#[error("authentication required")]
|
|
AuthenticationRequired(String),
|
|
#[error("invalid ssh signature")]
|
|
InvalidSignature,
|
|
#[error("invalid jwt")]
|
|
JWTVerify(#[from] jwt_compact::ValidationError),
|
|
#[error("invalid jwt")]
|
|
JWTParse(#[from] jwt_compact::ParseError),
|
|
#[error("{0}")]
|
|
Query(#[from] QueryRejection),
|
|
}
|
|
|
|
type ApiResult<T> = Result<T, ApiError>;
|
|
|
|
impl IntoResponse for ApiError {
|
|
fn into_response(self) -> axum::response::Response {
|
|
(
|
|
match self {
|
|
Self::CertificateNotFound => StatusCode::NOT_FOUND,
|
|
Self::LowSerial(_, _) | Self::InsufficientValidity(_) => StatusCode::BAD_REQUEST,
|
|
Self::AuthenticationRequired(challenge) => {
|
|
return (StatusCode::UNAUTHORIZED, challenge).into_response()
|
|
}
|
|
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
|
},
|
|
self.to_string(),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
async fn fallback_404() -> ApiResult<()> {
|
|
Err(ApiError::CertificateNotFound)
|
|
}
|
|
|
|
#[derive(TypedPath, Deserialize)]
|
|
#[typed_path("/certs")]
|
|
pub struct CertList;
|
|
|
|
async fn list_certs(
|
|
_: CertList,
|
|
State(ApiState { certs, .. }): State<ApiState>,
|
|
) -> ApiResult<Json<Vec<String>>> {
|
|
Ok(Json(
|
|
certs
|
|
.lock()
|
|
.await
|
|
.values()
|
|
.into_iter()
|
|
.map(|cert| cert.key_id().to_string())
|
|
.collect(),
|
|
))
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "aud", rename = "get")]
|
|
struct AuthClaims {
|
|
identifier: String,
|
|
}
|
|
|
|
#[derive(TypedPath, Deserialize)]
|
|
#[typed_path("/certs/:identifier")]
|
|
pub struct GetCert {
|
|
pub identifier: String,
|
|
}
|
|
|
|
/// Retrieve an certificate for identifier
|
|
/// TODO: add option to require auth
|
|
/// return Unauthorized with an challenge
|
|
/// upon which the client will ssh-keysign
|
|
/// the challenge an issue an post request
|
|
async fn get_certs_identifier(
|
|
GetCert { identifier }: GetCert,
|
|
State(ApiState {
|
|
certs,
|
|
jwt_key,
|
|
client_auth,
|
|
..
|
|
}): State<ApiState>,
|
|
) -> ApiResult<String> {
|
|
use jwt_compact::{AlgorithmExt, Claims, Header, TimeOptions};
|
|
|
|
if client_auth {
|
|
let claims = Claims::new(AuthClaims { identifier })
|
|
.set_duration(&TimeOptions::default(), chrono::Duration::seconds(120));
|
|
let challenge = Hs256
|
|
.compact_token(Header::default(), &claims, &jwt_key)
|
|
.context("jwt sign")?;
|
|
return Err(ApiError::AuthenticationRequired(challenge));
|
|
}
|
|
let certs = certs.lock().await;
|
|
let cert = certs
|
|
.get(&identifier)
|
|
.ok_or(ApiError::CertificateNotFound)?;
|
|
Ok(cert.to_openssh().context("to openssh")?)
|
|
}
|
|
|
|
#[derive(TypedPath, Deserialize)]
|
|
#[typed_path("/certs/:identifier/info")]
|
|
pub struct GetCertInfo {
|
|
pub identifier: String,
|
|
}
|
|
|
|
/// JSON summary of a stored certificate, served by `GET /certs/:identifier/info`.
#[cfg(feature = "info")]
#[derive(Debug, Serialize)]
struct CertInfo {
    /// Principals the certificate is valid for.
    principals: Vec<String>,
    /// Key that signed the certificate.
    ca: PublicKey,
    /// Public key the certificate was issued for.
    identity: PublicKey,
    /// Certificate key id.
    key_id: String,
    /// End of the certificate's validity window (`valid_before`).
    expiry: SystemTime,
}

// Gated exactly like `CertInfo`: without this `cfg`, the impl refers to a type that
// does not exist when the `info` feature is disabled, and the crate fails to compile.
#[cfg(feature = "info")]
impl From<&Certificate> for CertInfo {
    fn from(cert: &Certificate) -> Self {
        CertInfo {
            principals: cert.valid_principals().to_vec(),
            ca: cert.signature_key().clone().into(),
            identity: cert.public_key().clone().into(),
            key_id: cert.key_id().to_string(),
            expiry: cert.valid_before_time(),
        }
    }
}
|
|
|
|
/// Return a JSON summary (`CertInfo`) of the certificate stored for `identifier`.
#[cfg(feature = "info")]
async fn get_cert_info(
    GetCertInfo { identifier }: GetCertInfo,
    State(ApiState { certs, .. }): State<ApiState>,
) -> ApiResult<Json<CertInfo>> {
    let guard = certs.lock().await;
    match guard.get(&identifier) {
        Some(found) => Ok(Json(CertInfo::from(found))),
        None => Err(ApiError::CertificateNotFound),
    }
}
|
|
|
|
#[cfg(not(feature = "info"))]
|
|
async fn get_cert_info(
|
|
GetCertInfo { identifier: _ }: GetCertInfo,
|
|
State(ApiState { certs: _, .. }): State<ApiState>,
|
|
) -> ApiResult<()> {
|
|
unimplemented!()
|
|
}
|
|
|
|
#[derive(TypedPath, Deserialize)]
|
|
#[typed_path("/certs/:identifier")]
|
|
pub struct PostCertInfo {
|
|
pub identifier: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct PostCertsQuery {
|
|
challenge: String,
|
|
}
|
|
|
|
impl Into<JWTString> for Query<PostCertsQuery> {
|
|
fn into(self) -> JWTString {
|
|
self.0.challenge.into()
|
|
}
|
|
}
|
|
|
|
/// POST with signed challenge
|
|
async fn post_certs_identifier(
|
|
PostCertInfo { identifier }: PostCertInfo,
|
|
State(ApiState { certs, jwt_key, .. }): State<ApiState>,
|
|
JWTAuthenticated {
|
|
data: AuthClaims {
|
|
identifier: authenticated_identifier,
|
|
},
|
|
..
|
|
}: JWTAuthenticated<AuthClaims, ApiState, Query<PostCertsQuery>, ApiError>,
|
|
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
|
|
SignatureBody(sig): SignatureBody,
|
|
) -> ApiResult<String> {
|
|
let certs = certs.lock().await;
|
|
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
|
|
if authenticated_identifier != identifier {
|
|
return Err(ApiError::InvalidSignature);
|
|
}
|
|
let pubkey: PublicKey = cert.public_key().clone().into();
|
|
let verification = tokio::task::spawn_blocking(move || {
|
|
pubkey
|
|
.verify(&identifier, challenge.as_bytes(), &sig)
|
|
.map_err(|_| ApiError::InvalidSignature)
|
|
})
|
|
.await
|
|
.context("tokio blocking")?;
|
|
verification?;
|
|
Ok(cert.to_openssh().context("to openssh")?)
|
|
}
|
|
|
|
#[derive(TypedPath)]
|
|
#[typed_path("/cert")]
|
|
pub struct PutCert;
|
|
|
|
/// Upload an cert with an higher serial than the previous
|
|
async fn put_cert_update(
|
|
_: PutCert,
|
|
State(ApiState {
|
|
ca,
|
|
cert_dir,
|
|
certs,
|
|
validation_args:
|
|
CertificateValidationArgs {
|
|
validate_greater_expiry,
|
|
validate_greater_serial,
|
|
},
|
|
..
|
|
}): State<ApiState>,
|
|
CertificateBody(cert): CertificateBody,
|
|
) -> ApiResult<String> {
|
|
let cert = {
|
|
let ca = ca.clone();
|
|
tokio::task::spawn_blocking(move || -> ApiResult<Certificate> {
|
|
let cert = cert;
|
|
cert.validate(&[ca.fingerprint(Default::default())])
|
|
.map_err(|_| ApiError::CertificateInvalid)?;
|
|
Ok(cert)
|
|
})
|
|
.await
|
|
.context("signature verification")??
|
|
};
|
|
let prev = load_cert_by_id(&cert_dir, &ca, cert.key_id()).await?;
|
|
let mut prev_serial = 0;
|
|
let serial = cert.serial();
|
|
if let Some(prev) = prev {
|
|
let prev_exp = prev.valid_before();
|
|
let exp = cert.valid_before();
|
|
trace!(%prev_serial, %serial, ?prev_exp, ?exp, "comparing to previous certificate");
|
|
prev_serial = prev.serial();
|
|
if validate_greater_serial && prev.serial() >= cert.serial() {
|
|
return Err(ApiError::LowSerial(prev_serial, serial));
|
|
}
|
|
// check if new certificate is valid for longer than the old one
|
|
if validate_greater_expiry && prev_exp >= exp {
|
|
return Err(ApiError::InsufficientValidity(prev.valid_before_time()));
|
|
}
|
|
}
|
|
store_cert(&cert_dir, &ca, &cert).await?;
|
|
let principals = cert.valid_principals();
|
|
let identity = cert.key_id();
|
|
info!(%identity, ?principals, "updating certificate");
|
|
certs.lock().await.insert(cert.key_id().to_string(), cert);
|
|
Ok(format!("{} -> {}", prev_serial, serial))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::env::temp_dir;

    use super::*;

    fn ca_key() -> Ed25519Keypair {
        Ed25519Keypair::from_seed(&[0u8; 32])
    }

    fn ca_key2() -> Ed25519Keypair {
        Ed25519Keypair::from_seed(&[10u8; 32])
    }

    fn ca_pub() -> PublicKey {
        PublicKey::new(ca_key().public.into(), "TEST CA")
    }

    fn user_key() -> Ed25519Keypair {
        Ed25519Keypair::from_seed(&[1u8; 32])
    }

    fn user_cert(ca: Ed25519Keypair, user_key: PublicKey) -> Certificate {
        let ca_private: PrivateKey = ca.into();
        let unix_time = |time: SystemTime| -> u64 {
            time.duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs()
        };
        let mut builder = certificate::Builder::new(
            [0u8; 16],
            user_key,
            unix_time(SystemTime::now()),
            unix_time(SystemTime::now() + Duration::from_secs(30)),
        );

        builder
            .valid_principal("git")
            .unwrap()
            .key_id("test_cert")
            .unwrap()
            .comment("A TEST CERT")
            .unwrap();

        builder.sign(&ca_private).unwrap()
    }

    /// Create a fresh, empty certificate directory unique to this process and call.
    ///
    /// This keeps a test run from seeing certificates persisted by an earlier run,
    /// which could trip the expiry/serial checks in `put_cert_update`.
    fn unique_cert_dir() -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = temp_dir().join(format!(
            "ssh-cert-api-test-{}-{}",
            std::process::id(),
            nanos
        ));
        std::fs::create_dir_all(&dir).expect("create test cert dir");
        dir
    }

    fn api_state() -> ApiState {
        ApiState {
            ca: ca_pub(),
            certs: Default::default(),
            cert_dir: unique_cert_dir(),
            validation_args: Default::default(),
            client_auth: false,
            jwt_key: Hs256Key::new(&[0u8; 16]),
        }
    }

    #[test]
    fn test_certificate() {
        let valid_cert = user_cert(ca_key(), user_key().public.into());
        let ca_pub: PublicKey = ca_pub();
        assert!(valid_cert
            .validate(&[ca_pub.fingerprint(Default::default())])
            .is_ok());
    }

    #[tokio::test]
    async fn routes() -> anyhow::Result<()> {
        let state = api_state();
        let valid_cert = user_cert(ca_key(), user_key().public.into());
        let invalid_cert = user_cert(ca_key2(), user_key().public.into());
        let res = put_cert_update(
            PutCert,
            State(state.clone()),
            CertificateBody(valid_cert.clone()),
        )
        .await;
        assert!(res.is_ok(), "{:?}", res);
        assert_eq!(state.certs.lock().await.get("test_cert"), Some(&valid_cert));
        let res = put_cert_update(
            PutCert,
            State(state.clone()),
            CertificateBody(invalid_cert.clone()),
        )
        .await;
        assert!(matches!(res, Err(ApiError::CertificateInvalid)));

        let cert = get_certs_identifier(
            GetCert {
                identifier: "test_cert".into(),
            },
            State(state.clone()),
        )
        .await?;
        assert_eq!(cert, valid_cert.to_openssh()?);
        let res = get_certs_identifier(
            GetCert {
                identifier: "missing_cert".into(),
            },
            State(state.clone()),
        )
        .await;
        assert!(matches!(res, Err(ApiError::CertificateNotFound)));

        let state = ApiState {
            client_auth: true,
            ..state
        };

        let res = get_certs_identifier(
            GetCert {
                identifier: "test_cert".into(),
            },
            State(state.clone()),
        )
        .await;

        assert!(matches!(res, Err(ApiError::AuthenticationRequired(_))));

        if let Err(ApiError::AuthenticationRequired(challenge)) = res {
            let signing_key: PrivateKey = user_key().into();
            let sig = signing_key.sign("test_cert", Default::default(), challenge.as_bytes())?;
            let cert = post_certs_identifier(
                PostCertInfo {
                    identifier: "test_cert".into(),
                },
                State(state.clone()),
                JWTAuthenticated::from(AuthClaims {
                    identifier: "test_cert".into(),
                }),
                Query(PostCertsQuery { challenge }),
                SignatureBody(sig),
            )
            .await?;
            assert_eq!(cert, valid_cert.to_openssh()?);
        }

        // Best-effort cleanup of the per-run directory.
        let _ = std::fs::remove_dir_all(&state.cert_dir);
        Ok(())
    }
}
|