mod extract; use std::collections::HashMap; use std::net::SocketAddr; use std::path::{self, PathBuf}; use std::sync::Arc; use std::time::SystemTime; use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert}; use crate::env_key; use anyhow::Context; use axum::body; use axum::extract::{Path, State}; use axum::{http::StatusCode, response::IntoResponse, Json, Router}; use axum_extra::routing::{ RouterExt, // for `Router::typed_*` TypedPath, }; use clap::{Args, Parser}; use serde::Deserialize; use ssh_key::{Certificate, PublicKey}; use tokio::sync::Mutex; use tower::ServiceBuilder; use tower_http::{trace::TraceLayer, ServiceBuilderExt}; use tracing::{debug, trace}; use self::extract::CertificateBody; #[derive(Parser)] pub struct ApiArgs { #[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))] address: SocketAddr, #[clap(short = 'c', long = "cert-store", env = env_key!("CERT_DIR"))] cert_dir: PathBuf, #[clap(flatten)] validation_args: CertificateValidationArgs, /// CA public key #[clap(long = "ca", env = env_key!("CA"))] ca: PathBuf, } #[derive(Debug, Args, Copy, Clone)] pub struct CertificateValidationArgs { /// Check whether an certificate update contains an greater serial number than the existing /// certificate #[clap(short = 'e', long = "validate-expiry", env = env_key!("VALIDATE_EXPIRY"))] validate_greater_expiry: bool, /// Check whether an certificate update contains an expiry date further in the future than the /// existing certificate #[clap(short = 's', long = "validate-serial", env = env_key!("VALIDATE_SERIAL"))] validate_greater_serial: bool, } impl Default for CertificateValidationArgs { fn default() -> Self { Self { validate_greater_expiry: true, validate_greater_serial: false, } } } impl Default for ApiArgs { fn default() -> Self { Self { address: SocketAddr::from(([127, 0, 0, 1], 3000)), cert_dir: "certs".into(), ca: "certs/ca.pub".into(), validation_args: Default::default(), } } } #[derive(Debug, Clone)] struct ApiState { certs: Arc>>, cert_dir: PathBuf, ca: PublicKey, validation_args: CertificateValidationArgs, } impl ApiState { async fn new( cert_dir: impl AsRef, ca_file: impl AsRef, validation_args: CertificateValidationArgs, ) -> anyhow::Result { let ca = read_pubkey(ca_file.as_ref()).await?; let certs = read_certs(&ca, cert_dir.as_ref()).await?; Ok(Self { certs: Arc::new(Mutex::new( certs .into_iter() .map(|cert| (cert.key_id().to_string(), cert)) .collect(), )), cert_dir: cert_dir.as_ref().into(), ca, validation_args, }) } } pub async fn run( ApiArgs { address, cert_dir, ca, validation_args, }: ApiArgs, ) -> anyhow::Result<()> { let state = ApiState::new(&cert_dir, &ca, validation_args).await?; #[cfg(feature = "reload")] { let state = state.clone(); tokio::spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(30)).await; if let Ok(certs) = read_certs(&state.ca, &state.cert_dir).await { *state.certs.lock().await = certs .into_iter() .map(|cert| (cert.key_id().to_string(), cert)) .collect(); debug!("reloaded certs"); } } }); } let app = Router::new() .typed_get(get_certs_identifier) .typed_put(put_cert_update) .typed_get(get_cert_info) .typed_post(post_certs_identifier) .fallback(fallback_404) .layer(ServiceBuilder::new().map_request_body(body::boxed)) .layer(TraceLayer::new_for_http()) .with_state(state); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` debug!("listening on {}", address); axum::Server::bind(&address) .serve(app.into_make_service()) .await .unwrap(); Ok(()) } #[derive(Debug, thiserror::Error)] pub enum ApiError { #[error("internal server error")] Internal(#[from] anyhow::Error), #[error("certificate not found")] CertificateNotFound, #[error("invalid certificate")] CertificateInvalid, #[error("serial must be greater than {1}")] LowSerial(u64, u64), #[error("expiry date must be greater than {0:?}")] InsufficientValidity(SystemTime), } type ApiResult = Result; impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { ( match self { Self::CertificateNotFound => StatusCode::NOT_FOUND, Self::LowSerial(_, _) | Self::InsufficientValidity(_) => StatusCode::BAD_REQUEST, _ => StatusCode::INTERNAL_SERVER_ERROR, }, self.to_string(), ) .into_response() } } async fn fallback_404() -> ApiResult<()> { Err(ApiError::CertificateNotFound) } #[derive(TypedPath, Deserialize)] #[typed_path("/certs/:identifier")] pub struct GetCert { pub identifier: String, } /// Retrieve an certificate for identifier /// TODO: add option to require auth /// return Unauthorized with an challenge /// upon which the client will ssh-keysign /// the challenge an issue an post request async fn get_certs_identifier( GetCert { identifier }: GetCert, State(ApiState { certs, .. }): State, ) -> ApiResult { let certs = certs.lock().await; let cert = certs .get(&identifier) .ok_or(ApiError::CertificateNotFound)?; Ok(cert.to_openssh().context("to openssh")?) } #[derive(TypedPath, Deserialize)] #[typed_path("/certs/:identifier/info")] pub struct GetCertInfo { pub identifier: String, } #[cfg(feature = "info")] async fn get_cert_info( GetCertInfo { identifier }: GetCertInfo, State(ApiState { certs, .. }): State, ) -> ApiResult> { let certs = certs.lock().await; let cert = certs .get(&identifier) .ok_or(ApiError::CertificateNotFound)?; Ok(Json(cert.clone())) } #[cfg(not(feature = "info"))] async fn get_cert_info( GetCertInfo { identifier: _ }: GetCertInfo, State(ApiState { certs: _, .. }): State, ) -> ApiResult<()> { unimplemented!() } #[derive(TypedPath, Deserialize)] #[typed_path("/certs/:identifier")] pub struct PostCertInfo { pub identifier: String, } /// POST with signed challenge async fn post_certs_identifier( PostCertInfo { identifier: _ }: PostCertInfo, State(ApiState { .. }): State, Path(_identifier): Path, ) -> ApiResult { unimplemented!() } #[derive(TypedPath)] #[typed_path("/cert")] pub struct PutCert; /// Upload an cert with an higher serial than the previous async fn put_cert_update( _: PutCert, State(ApiState { ca, cert_dir, certs, validation_args: CertificateValidationArgs { validate_greater_expiry, validate_greater_serial, }, .. }): State, CertificateBody(cert): CertificateBody, ) -> ApiResult { let cert = { let ca = ca.clone(); tokio::task::spawn_blocking(move || -> ApiResult { let cert = cert; cert.validate(&[ca.fingerprint(Default::default())]) .map_err(|_| ApiError::CertificateInvalid)?; Ok(cert) }) .await .context("signature verification")?? }; let prev = load_cert_by_id(&cert_dir, &ca, cert.key_id()).await?; let mut prev_serial = 0; let serial = cert.serial(); if let Some(prev) = prev { let prev_exp = prev.valid_before(); let exp = cert.valid_before(); trace!(%prev_serial, %serial, ?prev_exp, ?exp, "comparing to previous certificate"); prev_serial = prev.serial(); if validate_greater_serial && prev.serial() >= cert.serial() { return Err(ApiError::LowSerial(prev_serial, serial)); } // check if new certificate is valid for longer than the old one if validate_greater_expiry && prev_exp >= exp { return Err(ApiError::InsufficientValidity(prev.valid_before_time())); } } store_cert(&cert_dir, &ca, &cert).await?; certs.lock().await.insert(cert.key_id().to_string(), cert); Ok(format!("{} -> {}", prev_serial, serial)) }