301 lines
8.7 KiB
Rust
301 lines
8.7 KiB
Rust
mod extract;
|
|
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::path::{self, PathBuf};
|
|
use std::sync::Arc;
|
|
use std::time::SystemTime;
|
|
|
|
use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert};
|
|
use crate::env_key;
|
|
use anyhow::Context;
|
|
use axum::body;
|
|
use axum::extract::{Path, State};
|
|
|
|
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
|
|
use axum_extra::routing::{
|
|
RouterExt, // for `Router::typed_*`
|
|
TypedPath,
|
|
};
|
|
use clap::{Args, Parser};
|
|
use serde::Deserialize;
|
|
use ssh_key::{Certificate, PublicKey};
|
|
use tokio::sync::Mutex;
|
|
use tower::ServiceBuilder;
|
|
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
|
use tracing::{debug, trace};
|
|
|
|
use self::extract::CertificateBody;
|
|
|
|
#[derive(Parser)]
|
|
pub struct ApiArgs {
|
|
#[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))]
|
|
address: SocketAddr,
|
|
#[clap(short = 'c', long = "cert-store", env = env_key!("CERT_DIR"))]
|
|
cert_dir: PathBuf,
|
|
#[clap(flatten)]
|
|
validation_args: CertificateValidationArgs,
|
|
/// CA public key
|
|
#[clap(long = "ca", env = env_key!("CA"))]
|
|
ca: PathBuf,
|
|
}
|
|
|
|
#[derive(Debug, Args, Copy, Clone)]
|
|
pub struct CertificateValidationArgs {
|
|
/// Check whether an certificate update contains an greater serial number than the existing
|
|
/// certificate
|
|
#[clap(short = 'e', long = "validate-expiry", env = env_key!("VALIDATE_EXPIRY"))]
|
|
validate_greater_expiry: bool,
|
|
/// Check whether an certificate update contains an expiry date further in the future than the
|
|
/// existing certificate
|
|
#[clap(short = 's', long = "validate-serial", env = env_key!("VALIDATE_SERIAL"))]
|
|
validate_greater_serial: bool,
|
|
}
|
|
|
|
impl Default for CertificateValidationArgs {
|
|
fn default() -> Self {
|
|
Self {
|
|
validate_greater_expiry: true,
|
|
validate_greater_serial: false,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for ApiArgs {
|
|
fn default() -> Self {
|
|
Self {
|
|
address: SocketAddr::from(([127, 0, 0, 1], 3000)),
|
|
cert_dir: "certs".into(),
|
|
ca: "certs/ca.pub".into(),
|
|
validation_args: Default::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
struct ApiState {
|
|
certs: Arc<Mutex<HashMap<String, Certificate>>>,
|
|
cert_dir: PathBuf,
|
|
ca: PublicKey,
|
|
validation_args: CertificateValidationArgs,
|
|
}
|
|
|
|
impl ApiState {
|
|
async fn new(
|
|
cert_dir: impl AsRef<path::Path>,
|
|
ca_file: impl AsRef<path::Path>,
|
|
validation_args: CertificateValidationArgs,
|
|
) -> anyhow::Result<Self> {
|
|
let ca = read_pubkey(ca_file.as_ref()).await?;
|
|
let certs = read_certs(&ca, cert_dir.as_ref()).await?;
|
|
Ok(Self {
|
|
certs: Arc::new(Mutex::new(
|
|
certs
|
|
.into_iter()
|
|
.map(|cert| (cert.key_id().to_string(), cert))
|
|
.collect(),
|
|
)),
|
|
cert_dir: cert_dir.as_ref().into(),
|
|
ca,
|
|
validation_args,
|
|
})
|
|
}
|
|
}
|
|
|
|
pub async fn run(
|
|
ApiArgs {
|
|
address,
|
|
cert_dir,
|
|
ca,
|
|
validation_args,
|
|
}: ApiArgs,
|
|
) -> anyhow::Result<()> {
|
|
let state = ApiState::new(&cert_dir, &ca, validation_args).await?;
|
|
|
|
#[cfg(feature = "reload")]
|
|
{
|
|
let state = state.clone();
|
|
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
|
if let Ok(certs) = read_certs(&state.ca, &state.cert_dir).await {
|
|
*state.certs.lock().await = certs
|
|
.into_iter()
|
|
.map(|cert| (cert.key_id().to_string(), cert))
|
|
.collect();
|
|
debug!("reloaded certs");
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
let app = Router::new()
|
|
.typed_get(get_certs_identifier)
|
|
.typed_put(put_cert_update)
|
|
.typed_get(get_cert_info)
|
|
.typed_post(post_certs_identifier)
|
|
.fallback(fallback_404)
|
|
.layer(ServiceBuilder::new().map_request_body(body::boxed))
|
|
.layer(TraceLayer::new_for_http())
|
|
.with_state(state);
|
|
|
|
// run our app with hyper
|
|
// `axum::Server` is a re-export of `hyper::Server`
|
|
debug!("listening on {}", address);
|
|
axum::Server::bind(&address)
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum ApiError {
|
|
#[error("internal server error")]
|
|
Internal(#[from] anyhow::Error),
|
|
#[error("certificate not found")]
|
|
CertificateNotFound,
|
|
#[error("invalid certificate")]
|
|
CertificateInvalid,
|
|
#[error("serial must be greater than {1}")]
|
|
LowSerial(u64, u64),
|
|
#[error("expiry date must be greater than {0:?}")]
|
|
InsufficientValidity(SystemTime),
|
|
}
|
|
|
|
/// Result type returned by the API handlers.
type ApiResult<T> = std::result::Result<T, ApiError>;
|
|
|
|
impl IntoResponse for ApiError {
|
|
fn into_response(self) -> axum::response::Response {
|
|
(
|
|
match self {
|
|
Self::CertificateNotFound => StatusCode::NOT_FOUND,
|
|
Self::LowSerial(_, _) | Self::InsufficientValidity(_) => StatusCode::BAD_REQUEST,
|
|
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
|
},
|
|
self.to_string(),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
async fn fallback_404() -> ApiResult<()> {
|
|
Err(ApiError::CertificateNotFound)
|
|
}
|
|
|
|
#[derive(TypedPath, Deserialize)]
|
|
#[typed_path("/certs/:identifier")]
|
|
pub struct GetCert {
|
|
pub identifier: String,
|
|
}
|
|
|
|
/// Retrieve an certificate for identifier
|
|
/// TODO: add option to require auth
|
|
/// return Unauthorized with an challenge
|
|
/// upon which the client will ssh-keysign
|
|
/// the challenge an issue an post request
|
|
async fn get_certs_identifier(
|
|
GetCert { identifier }: GetCert,
|
|
State(ApiState { certs, .. }): State<ApiState>,
|
|
) -> ApiResult<String> {
|
|
let certs = certs.lock().await;
|
|
let cert = certs
|
|
.get(&identifier)
|
|
.ok_or(ApiError::CertificateNotFound)?;
|
|
Ok(cert.to_openssh().context("to openssh")?)
|
|
}
|
|
|
|
#[derive(TypedPath, Deserialize)]
|
|
#[typed_path("/certs/:identifier/info")]
|
|
pub struct GetCertInfo {
|
|
pub identifier: String,
|
|
}
|
|
|
|
/// Return the certificate stored under `identifier` as JSON.
///
/// Responds with [`ApiError::CertificateNotFound`] when no certificate has that key ID.
#[cfg(feature = "info")]
async fn get_cert_info(
    GetCertInfo { identifier }: GetCertInfo,
    State(ApiState { certs, .. }): State<ApiState>,
) -> ApiResult<Json<Certificate>> {
    // Clone under the lock; the guard is released at the end of this statement.
    let snapshot = certs.lock().await.get(&identifier).cloned();
    snapshot.map(Json).ok_or(ApiError::CertificateNotFound)
}
|
|
|
|
#[cfg(not(feature = "info"))]
|
|
async fn get_cert_info(
|
|
GetCertInfo { identifier: _ }: GetCertInfo,
|
|
State(ApiState { certs: _, .. }): State<ApiState>,
|
|
) -> ApiResult<()> {
|
|
unimplemented!()
|
|
}
|
|
|
|
#[derive(TypedPath, Deserialize)]
|
|
#[typed_path("/certs/:identifier")]
|
|
pub struct PostCertInfo {
|
|
pub identifier: String,
|
|
}
|
|
|
|
/// POST with signed challenge
|
|
async fn post_certs_identifier(
|
|
PostCertInfo { identifier: _ }: PostCertInfo,
|
|
State(ApiState { .. }): State<ApiState>,
|
|
Path(_identifier): Path<String>,
|
|
) -> ApiResult<String> {
|
|
unimplemented!()
|
|
}
|
|
|
|
#[derive(TypedPath)]
|
|
#[typed_path("/cert")]
|
|
pub struct PutCert;
|
|
|
|
/// Upload an cert with an higher serial than the previous
|
|
async fn put_cert_update(
|
|
_: PutCert,
|
|
State(ApiState {
|
|
ca,
|
|
cert_dir,
|
|
certs,
|
|
validation_args:
|
|
CertificateValidationArgs {
|
|
validate_greater_expiry,
|
|
validate_greater_serial,
|
|
},
|
|
..
|
|
}): State<ApiState>,
|
|
CertificateBody(cert): CertificateBody,
|
|
) -> ApiResult<String> {
|
|
let cert = {
|
|
let ca = ca.clone();
|
|
tokio::task::spawn_blocking(move || -> ApiResult<Certificate> {
|
|
let cert = cert;
|
|
cert.validate(&[ca.fingerprint(Default::default())])
|
|
.map_err(|_| ApiError::CertificateInvalid)?;
|
|
Ok(cert)
|
|
})
|
|
.await
|
|
.context("signature verification")??
|
|
};
|
|
let prev = load_cert_by_id(&cert_dir, &ca, cert.key_id()).await?;
|
|
let mut prev_serial = 0;
|
|
let serial = cert.serial();
|
|
if let Some(prev) = prev {
|
|
let prev_exp = prev.valid_before();
|
|
let exp = cert.valid_before();
|
|
trace!(%prev_serial, %serial, ?prev_exp, ?exp, "comparing to previous certificate");
|
|
prev_serial = prev.serial();
|
|
if validate_greater_serial && prev.serial() >= cert.serial() {
|
|
return Err(ApiError::LowSerial(prev_serial, serial));
|
|
}
|
|
// check if new certificate is valid for longer than the old one
|
|
if validate_greater_expiry && prev_exp >= exp {
|
|
return Err(ApiError::InsufficientValidity(prev.valid_before_time()));
|
|
}
|
|
}
|
|
store_cert(&cert_dir, &ca, &cert).await?;
|
|
certs.lock().await.insert(cert.key_id().to_string(), cert);
|
|
Ok(format!("{} -> {}", prev_serial, serial))
|
|
}
|