mod extract; use std::collections::HashMap; use std::net::SocketAddr; use std::path::{self, PathBuf}; use std::sync::Arc; use std::time::SystemTime; use anyhow::Context; use axum::body; use axum::extract::{Query, State}; use chrono::Duration; use ssh_cert_dist_common::*; use axum::{http::StatusCode, response::IntoResponse, Json, Router}; use axum_extra::routing::RouterExt; use clap::{Args, Parser}; use jwt_compact::alg::{Hs256, Hs256Key}; use jwt_compact::{AlgorithmExt, Token, UntrustedToken}; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use ssh_key::{Certificate, Fingerprint, PublicKey}; use tokio::sync::Mutex; use tower::ServiceBuilder; use tower_http::{trace::TraceLayer, ServiceBuilderExt}; use tracing::{debug, info, trace}; use self::extract::{CertificateBody, SignatureBody}; #[derive(Parser)] pub struct ApiArgs { #[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))] address: SocketAddr, #[clap(short = 'c', long = "cert-store", env = env_key!("CERT_DIR"))] cert_dir: PathBuf, #[clap(flatten)] validation_args: CertificateValidationArgs, /// CA public key #[clap(long = "ca", env = env_key!("CA"))] ca: PathBuf, } #[derive(Debug, Args, Copy, Clone)] pub struct CertificateValidationArgs { /// Check whether an certificate update contains an greater serial number than the existing /// certificate #[clap(short = 'e', long = "validate-expiry", env = env_key!("VALIDATE_EXPIRY"))] validate_greater_expiry: bool, /// Check whether an certificate update contains an expiry date further in the future than the /// existing certificate #[clap(short = 's', long = "validate-serial", env = env_key!("VALIDATE_SERIAL"))] validate_greater_serial: bool, } impl Default for CertificateValidationArgs { fn default() -> Self { Self { validate_greater_expiry: true, validate_greater_serial: false, } } } impl Default for ApiArgs { fn default() -> Self { Self { address: SocketAddr::from(([127, 0, 0, 1], 3000)), cert_dir: "certs".into(), ca: "certs/ca.pub".into(), validation_args: Default::default(), } } } #[derive(Debug, Clone)] struct ApiState { certs: Arc>>, cert_dir: PathBuf, ca: PublicKey, client_auth: bool, validation_args: CertificateValidationArgs, jwt_key: Hs256Key, } impl ApiState { async fn new( cert_dir: impl AsRef, ca_file: impl AsRef, validation_args: CertificateValidationArgs, ) -> anyhow::Result { let ca = read_pubkey(ca_file.as_ref()).await?; let certs = read_certs(&ca, cert_dir.as_ref()).await?; Ok(Self { certs: Arc::new(Mutex::new( certs .into_iter() .map(|cert| (cert.key_id().to_string(), cert)) .collect(), )), cert_dir: cert_dir.as_ref().into(), ca, client_auth: false, validation_args, jwt_key: Hs256Key::new(thread_rng().gen::<[u8; 16]>()), }) } } pub async fn run( ApiArgs { address, cert_dir, ca, validation_args, }: ApiArgs, ) -> anyhow::Result<()> { let state = ApiState::new(&cert_dir, &ca, validation_args).await?; #[cfg(feature = "reload")] { let state = state.clone(); tokio::spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(30)).await; if let Ok(certs) = read_certs(&state.ca, &state.cert_dir).await { *state.certs.lock().await = certs .into_iter() .map(|cert| (cert.key_id().to_string(), cert)) .collect(); trace!("reloaded certs"); } } }); } let app = Router::new() .typed_get(get_certs_identifier) .typed_get(get_certs_pubkey) .typed_put(put_cert_update) .typed_get(get_cert_info) .typed_post(post_certs_identifier); #[cfg(feature = "index")] let app = app.typed_get(list_certs); let app = app .fallback(fallback_404) .layer(ServiceBuilder::new().map_request_body(body::boxed)) .layer(TraceLayer::new_for_http()) .with_state(state); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` debug!("listening on {}", address); axum::Server::bind(&address) .serve(app.into_make_service()) .await .unwrap(); Ok(()) } #[derive(Debug, thiserror::Error)] pub enum ApiError { #[error("internal server error")] Internal(#[from] anyhow::Error), #[error("certificate not found")] CertificateNotFound, #[error("invalid certificate")] CertificateInvalid, #[error("serial must be greater than {1}")] LowSerial(u64, u64), #[error("expiry date must be greater than {0:?}")] InsufficientValidity(SystemTime), #[error("authentication required")] AuthenticationRequired(String), #[error("invalid ssh signature")] InvalidSignature, } type ApiResult = Result; impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { ( match self { Self::CertificateNotFound => StatusCode::NOT_FOUND, Self::LowSerial(_, _) | Self::InsufficientValidity(_) => StatusCode::BAD_REQUEST, Self::AuthenticationRequired(challenge) => { return (StatusCode::UNAUTHORIZED, challenge).into_response() } _ => StatusCode::INTERNAL_SERVER_ERROR, }, self.to_string(), ) .into_response() } } async fn fallback_404() -> ApiResult<()> { Err(ApiError::CertificateNotFound) } #[cfg(feature = "index")] async fn list_certs( _: CertList, State(ApiState { certs, .. }): State, ) -> ApiResult>> { Ok(Json( certs .lock() .await .values() .into_iter() .map(|cert| cert.key_id().to_string()) .collect(), )) } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "aud", rename = "get")] struct AuthClaims { identifier: String, } async fn request_client_auth(enabled: bool, identifier: &str, jwt_key: &Hs256Key) -> ApiResult<()> { use jwt_compact::{Claims, Header, TimeOptions}; if enabled { let claims = Claims::new(AuthClaims { identifier: identifier.into(), }) .set_duration(&TimeOptions::default(), chrono::Duration::seconds(120)); let challenge = Hs256 .compact_token(Header::default(), &claims, &jwt_key) .context("jwt sign")?; return Err(ApiError::AuthenticationRequired(challenge)); } else { Ok(()) } } /// Retrieve an certificate for identifier /// TODO: add option to require auth /// return Unauthorized with an challenge /// upon which the client will ssh-keysign /// the challenge an issue an post request async fn get_certs_identifier( GetCert { identifier }: GetCert, State(ApiState { certs, jwt_key, client_auth, .. }): State, ) -> ApiResult { request_client_auth(client_auth, &identifier, &jwt_key).await?; let certs = certs.lock().await; let cert = certs .get(&identifier) .ok_or(ApiError::CertificateNotFound)?; Ok(cert.to_openssh().context("to openssh")?) } async fn get_certs_pubkey( GetCertsPubkey { pubkey_hash }: GetCertsPubkey, State(ApiState { certs, jwt_key: _, client_auth: _, .. }): State, ) -> ApiResult> { let certs = certs.lock().await; let ids = certs .values() .filter(|cert| &cert.public_key().fingerprint(pubkey_hash.algorithm()) == &pubkey_hash) .map(|cert| cert.key_id().to_string()) .collect::>(); Ok(Json(CertIds { ids })) } #[cfg(feature = "info")] #[derive(Debug, Serialize)] struct CertInfo { principals: Vec, ca: PublicKey, ca_hash: Fingerprint, identity: PublicKey, identity_hash: Fingerprint, key_id: String, expiry: SystemTime, renew_command: String, } impl From<&Certificate> for CertInfo { fn from(cert: &Certificate) -> Self { let validity = cert.valid_before_time().duration_since(cert.valid_after_time()).unwrap_or(Duration::zero().to_std().unwrap()); let validity_days = validity.as_secs() / ((60*60) * 24); let host_key = if cert.cert_type().is_host() { " -h" } else { "" }; let opts = cert.critical_options().iter().map(|(opt, val)| if val.is_empty() { opt.clone() } else { format!("{opt}={val}") }).map(|arg| format!("-O {arg}")).collect::>().join(" "); let renew_command = format!("ssh-keygen -s ./ca_key {host_key} -I {} -n {} -V {validity_days}d {opts}", cert.key_id(), cert.valid_principals().join(",")); CertInfo { principals: cert.valid_principals().to_vec(), ca: cert.signature_key().clone().into(), ca_hash: cert.signature_key().fingerprint(ssh_key::HashAlg::Sha256), identity: cert.public_key().clone().into(), identity_hash: cert.public_key().fingerprint(ssh_key::HashAlg::Sha256), key_id: cert.key_id().to_string(), expiry: cert.valid_before_time(), renew_command } } } #[cfg(feature = "info")] async fn get_cert_info( GetCertInfo { identifier }: GetCertInfo, State(ApiState { certs, .. }): State, ) -> ApiResult> { let certs = certs.lock().await; let cert = certs .get(&identifier) .ok_or(ApiError::CertificateNotFound)?; Ok(Json(cert.into())) } #[cfg(not(feature = "info"))] async fn get_cert_info( GetCertInfo { identifier: _ }: GetCertInfo, State(ApiState { certs: _, .. }): State, ) -> ApiResult<()> { unimplemented!() } #[derive(Debug, Deserialize)] struct PostCertsQuery { challenge: String, } /// POST with signed challenge async fn post_certs_identifier( PostCertInfo { identifier }: PostCertInfo, State(ApiState { certs, jwt_key, .. }): State, Query(PostCertsQuery { challenge }): Query, SignatureBody(sig): SignatureBody, ) -> ApiResult { let certs = certs.lock().await; let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?; let token: Token = Hs256 .validate_integrity( &UntrustedToken::new(&challenge).context("jwt parse")?, &jwt_key, ) .map_err(|_| ApiError::InvalidSignature)?; if token.claims().custom.identifier != identifier { return Err(ApiError::InvalidSignature); } let pubkey: PublicKey = cert.public_key().clone().into(); let verification = tokio::task::spawn_blocking(move || { pubkey .verify(&identifier, challenge.as_bytes(), &sig) .map_err(|_| ApiError::InvalidSignature) }) .await .context("tokio blocking")?; verification?; Ok(cert.to_openssh().context("to openssh")?) } /// Upload an cert with an higher serial than the previous async fn put_cert_update( _: PutCert, State(ApiState { ca, cert_dir, certs, validation_args: CertificateValidationArgs { validate_greater_expiry, validate_greater_serial, }, .. }): State, CertificateBody(cert): CertificateBody, ) -> ApiResult { let cert = { let ca = ca.clone(); tokio::task::spawn_blocking(move || -> ApiResult { let cert = cert; cert.validate(&[ca.fingerprint(Default::default())]) .map_err(|_| ApiError::CertificateInvalid)?; Ok(cert) }) .await .context("signature verification")?? }; let prev = load_cert_by_id(&cert_dir, &ca, cert.key_id()).await?; let mut prev_serial = 0; let serial = cert.serial(); if let Some(prev) = prev { let prev_exp = prev.valid_before(); let exp = cert.valid_before(); trace!(%prev_serial, %serial, ?prev_exp, ?exp, "comparing to previous certificate"); prev_serial = prev.serial(); if validate_greater_serial && prev.serial() >= cert.serial() { return Err(ApiError::LowSerial(prev_serial, serial)); } // check if new certificate is valid for longer than the old one if validate_greater_expiry && prev_exp >= exp { return Err(ApiError::InsufficientValidity(prev.valid_before_time())); } } store_cert(&cert_dir, &ca, &cert).await?; let principals = cert.valid_principals(); let identity = cert.key_id(); info!(%identity, ?principals, "updating certificate"); certs.lock().await.insert(cert.key_id().to_string(), cert); Ok(format!("{prev_serial} -> {serial}")) } #[cfg(test)] mod tests { use ssh_key::{certificate, private::Ed25519Keypair, PrivateKey}; use std::env::temp_dir; use std::time::Duration; use super::*; fn ca_key() -> Ed25519Keypair { Ed25519Keypair::from_seed(&[0u8; 32]) } fn ca_key2() -> Ed25519Keypair { Ed25519Keypair::from_seed(&[10u8; 32]) } fn ca_pub() -> PublicKey { PublicKey::new( ca_key().public.into(), format!( "TEST CA {}", SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs() ), ) } fn user_key() -> Ed25519Keypair { Ed25519Keypair::from_seed(&[1u8; 32]) } fn user_cert(ca: Ed25519Keypair, user_key: PublicKey, validity: Duration) -> Certificate { let ca_private: PrivateKey = ca.into(); let unix_time = |time: SystemTime| -> u64 { time.duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs() }; let mut builder = certificate::Builder::new( [0u8; 16], user_key, unix_time(SystemTime::now()), unix_time(SystemTime::now() + validity), ); builder .valid_principal("git") .unwrap() .key_id("test_cert") .unwrap() .comment(&format!("A TEST CERT, VALID FOR {}s", validity.as_secs())) .unwrap(); builder.sign(&ca_private).unwrap() } fn api_state() -> ApiState { let ca: PublicKey = ca_pub(); ApiState { ca, certs: Default::default(), cert_dir: dbg!(temp_dir()), validation_args: Default::default(), client_auth: false, jwt_key: Hs256Key::new([0u8; 16]), } } #[test] fn test_certificate() { let valid_cert = user_cert(ca_key(), user_key().public.into(), Duration::from_secs(30)); let ca_pub: PublicKey = ca_pub(); assert!(valid_cert .validate(&[ca_pub.fingerprint(Default::default())]) .is_ok()); } #[tokio::test] async fn update_cert() { let state = api_state(); let ca = ca_key(); let user: PublicKey = user_key().public.into(); let (cert_first, cert_newer, cert_outdated) = { ( user_cert(ca.clone(), user.clone(), Duration::from_secs(300)), user_cert(ca.clone(), user.clone(), Duration::from_secs(600)), user_cert(ca.clone(), user.clone(), Duration::from_secs(30)), ) }; let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_first)).await; assert!(res.is_ok()); let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_newer)).await; assert!(res.is_ok()); let res = put_cert_update( PutCert, State(state.clone()), CertificateBody(cert_outdated), ) .await; assert!(res.is_err()); } #[tokio::test] async fn routes() -> anyhow::Result<()> { let state = api_state(); let valid_cert = user_cert(ca_key(), user_key().public.into(), Duration::from_secs(30)); let invalid_cert = user_cert(ca_key2(), user_key().public.into(), Duration::from_secs(30)); let res = put_cert_update( PutCert, State(state.clone()), CertificateBody(valid_cert.clone()), ) .await; assert!(dbg!(res).is_ok()); assert_eq!(state.certs.lock().await.get("test_cert"), Some(&valid_cert)); let res = put_cert_update( PutCert, State(state.clone()), CertificateBody(invalid_cert.clone()), ) .await; assert!(matches!(res, Err(ApiError::CertificateInvalid))); let cert = get_certs_identifier( GetCert { identifier: "test_cert".into(), }, State(state.clone()), ) .await?; assert_eq!(cert, valid_cert.to_openssh()?); let res = get_certs_identifier( GetCert { identifier: "missing_cert".into(), }, State(state.clone()), ) .await; assert!(matches!(res, Err(ApiError::CertificateNotFound))); let state = ApiState { client_auth: true, ..state }; let res = get_certs_identifier( GetCert { identifier: "test_cert".into(), }, State(state.clone()), ) .await; assert!(matches!(res, Err(ApiError::AuthenticationRequired(_)))); if let Err(ApiError::AuthenticationRequired(challenge)) = res { let signing_key: PrivateKey = user_key().into(); let sig = signing_key.sign("test_cert", Default::default(), challenge.as_bytes())?; let cert = post_certs_identifier( PostCertInfo { identifier: "test_cert".into(), }, State(state.clone()), Query(PostCertsQuery { challenge }), SignatureBody(sig), ) .await?; assert_eq!(cert, valid_cert.to_openssh()?); } Ok(()) } }