From cdf2f0a5f82f8eea59e31aea860d8793ec69cacf Mon Sep 17 00:00:00 2001 From: shimun Date: Sun, 4 Dec 2022 22:44:37 +0100 Subject: [PATCH] added: validation options --- Cargo.lock | 21 ++++++++++ Cargo.toml | 5 ++- src/api.rs | 114 ++++++++++++++++++++++++++++++++++++-------------- src/client.rs | 43 +++++++++++++------ src/main.rs | 7 ++++ 5 files changed, 144 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 81be549..9226e0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1465,6 +1465,7 @@ dependencies = [ "reqwest", "serde", "ssh-key", + "thiserror", "tokio", "tower", "tower-http", @@ -1565,6 +1566,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.4" diff --git a/Cargo.toml b/Cargo.toml index 5e63d4b..2204e09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "ssh-cert-dist" version = "0.1.0" authors = ["shimun "] -edition = "2018" +edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -21,10 +21,11 @@ clap = { version = "4.0.29", features = ["env", "derive"] } reqwest = { version = "0.11.13", optional = true } serde = { version = "1.0.148", features = ["derive"] } ssh-key = { version = "0.5.1", features = ["ed25519", "p256", "p384", "rsa", "signature"] } +thiserror = "1.0.37" tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] } tower = { version = "0.4.13", features = ["util"] } tower-http = { version = "0.3.4", features = ["map-request-body"] } -tracing = "0.1.37" +tracing = { version = "0.1.37", features = ["release_max_level_debug"] } tracing-subscriber = "0.3.16" url = { version = "2.3.1", optional = true } diff --git a/src/api.rs b/src/api.rs index c5ba1cb..dc715bd 100644 --- a/src/api.rs +++ b/src/api.rs @@ -4,44 +4,70 @@ use std::collections::HashMap; use std::net::SocketAddr; use std::path::{self, PathBuf}; use std::sync::Arc; +use std::time::SystemTime; use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert}; +use crate::env_key; use anyhow::Context; -use axum::body::{self}; +use axum::body; use axum::extract::{Path, State}; -use axum::routing::{post, put}; +use axum::routing::post; use axum::{http::StatusCode, response::IntoResponse, Router}; use axum_extra::routing::{ RouterExt, // for `Router::typed_*` TypedPath, }; -use clap::Parser; +use clap::{Args, Parser}; use serde::Deserialize; use ssh_key::{Certificate, PublicKey}; use tokio::sync::Mutex; use tower::ServiceBuilder; use tower_http::ServiceBuilderExt; -use tracing::debug; +use tracing::{debug, instrument, trace}; use self::extract::CertificateBody; #[derive(Parser)] pub struct ApiArgs { - #[clap(short = 'a', long = "address")] + #[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))] address: SocketAddr, - #[clap(short = 'c', long = "cert-store")] + #[clap(short = 'c', long = "cert-store", env = env_key!("CERT_DIR"))] cert_dir: PathBuf, + #[clap(flatten)] + validation_args: CertificateValidationArgs, /// CA public key - #[clap(long = "ca")] + #[clap(long = "ca", env = env_key!("CA"))] ca: PathBuf, } +#[derive(Debug, Args, Copy, Clone)] +pub struct CertificateValidationArgs { + /// Check whether an certificate update contains an greater serial number than the existing + /// certificate + #[clap(short = 'e', long = "validate-expiry", env = env_key!("VALIDATE_EXPIRY"))] + validate_greater_expiry: bool, + /// Check whether an certificate update contains an expiry date further in the future than the + /// existing certificate + #[clap(short = 's', long = "validate-serial", env = env_key!("VALIDATE_SERIAL"))] + validate_greater_serial: bool, +} + +impl Default for CertificateValidationArgs { + fn default() -> Self { + Self { + validate_greater_expiry: true, + validate_greater_serial: false, + } + } +} + impl Default for ApiArgs { fn default() -> Self { Self { address: SocketAddr::from(([127, 0, 0, 1], 3000)), cert_dir: "certs".into(), ca: "certs/ca.pub".into(), + validation_args: Default::default(), } } } @@ -51,12 +77,14 @@ struct ApiState { certs: Arc>>, cert_dir: PathBuf, ca: PublicKey, + validation_args: CertificateValidationArgs, } impl ApiState { async fn new( cert_dir: impl AsRef, ca_file: impl AsRef, + validation_args: CertificateValidationArgs, ) -> anyhow::Result { let ca = read_pubkey(ca_file.as_ref()).await?; let certs = read_certs(&ca, cert_dir.as_ref()).await?; @@ -69,6 +97,7 @@ impl ApiState { )), cert_dir: cert_dir.as_ref().into(), ca, + validation_args, }) } } @@ -78,9 +107,10 @@ pub async fn run( address, cert_dir, ca, + validation_args, }: ApiArgs, ) -> anyhow::Result<()> { - let state = ApiState::new(&cert_dir, &ca).await?; + let state = ApiState::new(&cert_dir, &ca, validation_args).await?; #[cfg(feature = "reload")] { @@ -104,6 +134,7 @@ pub async fn run( .typed_get(get_certs_identifier) .typed_put(put_cert_update) .route("/certs/:identifier", post(post_certs_identifier)) + .fallback(fallback_404) .layer(ServiceBuilder::new().map_request_body(body::boxed)) .with_state(state); @@ -117,36 +148,38 @@ pub async fn run( Ok(()) } +#[derive(Debug, thiserror::Error)] pub enum ApiError { - Internal, - NotFound, - Invalid, + #[error("internal server error")] + Internal(#[from] anyhow::Error), + #[error("certificate not found")] + CertificateNotFound, + #[error("invalid certificate")] + CertificateInvalid, + #[error("serial must be greater than {1}")] LowSerial(u64, u64), + #[error("expiry date must be greater than {0:?}")] + InsufficientValidity(SystemTime), } type ApiResult = Result; impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { - match self { - Self::NotFound => (StatusCode::NOT_FOUND, "not here").into_response(), - Self::LowSerial(prev, next) => ( - StatusCode::BAD_REQUEST, - format!( - "new certificate serial must be greater than {}, got {}", - prev, next - ), - ) - .into_response(), - _ => (StatusCode::INTERNAL_SERVER_ERROR, "Oops").into_response(), - } + ( + match self { + Self::CertificateNotFound => StatusCode::NOT_FOUND, + Self::LowSerial(_, _) | Self::InsufficientValidity(_) => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + self.to_string(), + ) + .into_response() } } -impl From for ApiError { - fn from(_: anyhow::Error) -> Self { - ApiError::Internal - } +async fn fallback_404() -> ApiResult<()> { + Err(ApiError::CertificateNotFound) } #[derive(TypedPath, Deserialize)] @@ -159,17 +192,21 @@ pub struct GetCert { /// TODO: add option to require auth /// return Unauthorized with an challenge /// upon which the client will ssh-keysign -/// the challene an issue an post request +/// the challenge an issue an post request +#[instrument(skip_all, ret)] async fn get_certs_identifier( GetCert { identifier }: GetCert, State(ApiState { certs, .. }): State, ) -> ApiResult { let certs = certs.lock().await; - let cert = certs.get(&identifier).ok_or(ApiError::NotFound)?; + let cert = certs + .get(&identifier) + .ok_or(ApiError::CertificateNotFound)?; Ok(cert.to_openssh().context("to openssh")?) } /// POST with signed challenge +#[instrument(skip_all, ret)] async fn post_certs_identifier( State(ApiState { .. }): State, Path(_identifier): Path, @@ -182,27 +219,40 @@ async fn post_certs_identifier( pub struct PutCert; /// Upload an cert with an higher serial than the previous +#[instrument(skip_all, ret)] async fn put_cert_update( _: PutCert, State(ApiState { ca, cert_dir, certs, + validation_args: + CertificateValidationArgs { + validate_greater_expiry, + validate_greater_serial, + }, .. }): State, CertificateBody(cert): CertificateBody, ) -> ApiResult { cert.validate(&[ca.fingerprint(Default::default())]) - .map_err(|_| ApiError::Invalid)?; + .map_err(|_| ApiError::CertificateInvalid)?; let _string_repr = cert.to_openssh(); - let prev = load_cert_by_id(&cert_dir, &ca, &cert.key_id()).await?; + let prev = load_cert_by_id(&cert_dir, &ca, cert.key_id()).await?; let mut prev_serial = 0; let serial = cert.serial(); if let Some(prev) = prev { + let prev_exp = prev.valid_before(); + let exp = cert.valid_before(); + trace!(%prev_serial, %serial, ?prev_exp, ?exp, "comparing to previous certificate"); prev_serial = prev.serial(); - if prev.serial() >= cert.serial() { + if validate_greater_serial && prev.serial() >= cert.serial() { return Err(ApiError::LowSerial(prev_serial, serial)); } + // check if new certificate is valid for longer than the old one + if validate_greater_expiry && prev_exp >= exp { + return Err(ApiError::InsufficientValidity(prev.valid_before_time())); + } } store_cert(&cert_dir, &ca, &cert).await?; certs.lock().await.insert(cert.key_id().to_string(), cert); diff --git a/src/client.rs b/src/client.rs index 2e81701..a9c6ab3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,25 +1,27 @@ -use anyhow::{bail, Context}; +use anyhow::{bail}; use axum_extra::routing::TypedPath; use clap::{Args, Parser, Subcommand}; use reqwest::{Client, StatusCode}; use ssh_key::Certificate; use std::path::PathBuf; +use std::time::{Duration, SystemTime}; use tokio::fs; use tracing::{debug, error, info, instrument}; -use tracing::{info_span, Instrument}; + use url::Url; use crate::api::PutCert; use crate::certs::load_cert; +use crate::env_key; use crate::{ api::GetCert, - certs::{self, read_dir}, + certs::{read_dir}, }; #[derive(Parser)] pub struct ClientArgs { /// Url for the API endpoint - #[clap(short = 'a', long = "api-endpoint")] + #[clap(short = 'a', long = "api-endpoint", env = env_key!("API"))] api: Url, } @@ -27,8 +29,11 @@ pub struct ClientArgs { pub struct FetchArgs { #[clap(flatten)] args: ClientArgs, - #[clap(short = 'c', long = "cert-dir", default_value = "~/.ssh")] + #[clap(short = 'c', long = "cert-dir", env = env_key!("CERT_DIR") )] cert_dir: PathBuf, + /// minimum time in days between now and expiry to consider checking + #[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))] + min_delta_days: Option, } #[derive(Parser)] @@ -36,6 +41,7 @@ pub struct UploadArgs { #[clap(flatten)] args: ClientArgs, /// Certificates to be uploaded + #[clap(env = env_key!("FILES"))] files: Vec, } @@ -58,6 +64,9 @@ pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> { } } +#[derive(Debug, thiserror::Error)] +enum UploadError {} + async fn upload( UploadArgs { args: ClientArgs { api }, @@ -102,19 +111,29 @@ async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Res async fn fetch( FetchArgs { cert_dir, + min_delta_days: min_delta, args: ClientArgs { api }, }: FetchArgs, ) -> anyhow::Result<()> { let certs = read_dir(&cert_dir).await?; let client = reqwest::Client::new(); - let updates = certs.into_iter().map(|cert| { - let path = GetCert { - identifier: cert.key_id().to_string(), - }; - let url = api.join(path.to_uri().path()).unwrap(); - let client = client.clone(); - tokio::spawn(async move { fetch_cert(client, url, cert).await }) + let threshold_exp = min_delta.and_then(|min_delta| { + SystemTime::now().checked_sub(Duration::from_secs(60 * 60 * 24 * min_delta as u64)) }); + let updates = certs + .into_iter() + .filter(|cert| { + let exp = cert.valid_before_time(); + threshold_exp.as_ref().map(|th| &exp < th).unwrap_or(true) + }) + .map(|cert| { + let path = GetCert { + identifier: cert.key_id().to_string(), + }; + let url = api.join(path.to_uri().path()).unwrap(); + let client = client.clone(); + tokio::spawn(async move { fetch_cert(client, url, cert).await }) + }); for cert in updates { if let Ok(Some((cert, update))) = cert.await? { fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?; diff --git a/src/main.rs b/src/main.rs index b576320..e3185a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,13 @@ mod certs; #[cfg(feature = "client")] mod client; +#[macro_export] +macro_rules! env_key { + ( $var:expr ) => { + concat!("SSH_CD_", $var) + }; +} + #[derive(Parser)] enum Command { Server(ApiArgs),