added: validation options

This commit is contained in:
shimun 2022-12-04 22:44:37 +01:00
parent b905fa802a
commit cdf2f0a5f8
Signed by: shimun
GPG Key ID: E0420647856EA39E
5 changed files with 144 additions and 46 deletions

21
Cargo.lock generated
View File

@ -1465,6 +1465,7 @@ dependencies = [
"reqwest", "reqwest",
"serde", "serde",
"ssh-key", "ssh-key",
"thiserror",
"tokio", "tokio",
"tower", "tower",
"tower-http", "tower-http",
@ -1565,6 +1566,26 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "thiserror"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "thread_local" name = "thread_local"
version = "1.1.4" version = "1.1.4"

View File

@ -2,7 +2,7 @@
name = "ssh-cert-dist" name = "ssh-cert-dist"
version = "0.1.0" version = "0.1.0"
authors = ["shimun <shimun@shimun.net>"] authors = ["shimun <shimun@shimun.net>"]
edition = "2018" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@ -21,10 +21,11 @@ clap = { version = "4.0.29", features = ["env", "derive"] }
reqwest = { version = "0.11.13", optional = true } reqwest = { version = "0.11.13", optional = true }
serde = { version = "1.0.148", features = ["derive"] } serde = { version = "1.0.148", features = ["derive"] }
ssh-key = { version = "0.5.1", features = ["ed25519", "p256", "p384", "rsa", "signature"] } ssh-key = { version = "0.5.1", features = ["ed25519", "p256", "p384", "rsa", "signature"] }
thiserror = "1.0.37"
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] } tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
tower = { version = "0.4.13", features = ["util"] } tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.3.4", features = ["map-request-body"] } tower-http = { version = "0.3.4", features = ["map-request-body"] }
tracing = "0.1.37" tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
tracing-subscriber = "0.3.16" tracing-subscriber = "0.3.16"
url = { version = "2.3.1", optional = true } url = { version = "2.3.1", optional = true }

View File

@ -4,44 +4,70 @@ use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::{self, PathBuf}; use std::path::{self, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime;
use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert}; use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert};
use crate::env_key;
use anyhow::Context; use anyhow::Context;
use axum::body::{self}; use axum::body;
use axum::extract::{Path, State}; use axum::extract::{Path, State};
use axum::routing::{post, put}; use axum::routing::post;
use axum::{http::StatusCode, response::IntoResponse, Router}; use axum::{http::StatusCode, response::IntoResponse, Router};
use axum_extra::routing::{ use axum_extra::routing::{
RouterExt, // for `Router::typed_*` RouterExt, // for `Router::typed_*`
TypedPath, TypedPath,
}; };
use clap::Parser; use clap::{Args, Parser};
use serde::Deserialize; use serde::Deserialize;
use ssh_key::{Certificate, PublicKey}; use ssh_key::{Certificate, PublicKey};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_http::ServiceBuilderExt; use tower_http::ServiceBuilderExt;
use tracing::debug; use tracing::{debug, instrument, trace};
use self::extract::CertificateBody; use self::extract::CertificateBody;
#[derive(Parser)] #[derive(Parser)]
pub struct ApiArgs { pub struct ApiArgs {
#[clap(short = 'a', long = "address")] #[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))]
address: SocketAddr, address: SocketAddr,
#[clap(short = 'c', long = "cert-store")] #[clap(short = 'c', long = "cert-store", env = env_key!("CERT_DIR"))]
cert_dir: PathBuf, cert_dir: PathBuf,
#[clap(flatten)]
validation_args: CertificateValidationArgs,
/// CA public key /// CA public key
#[clap(long = "ca")] #[clap(long = "ca", env = env_key!("CA"))]
ca: PathBuf, ca: PathBuf,
} }
#[derive(Debug, Args, Copy, Clone)]
pub struct CertificateValidationArgs {
/// Check whether an certificate update contains an greater serial number than the existing
/// certificate
#[clap(short = 'e', long = "validate-expiry", env = env_key!("VALIDATE_EXPIRY"))]
validate_greater_expiry: bool,
/// Check whether an certificate update contains an expiry date further in the future than the
/// existing certificate
#[clap(short = 's', long = "validate-serial", env = env_key!("VALIDATE_SERIAL"))]
validate_greater_serial: bool,
}
impl Default for CertificateValidationArgs {
fn default() -> Self {
Self {
validate_greater_expiry: true,
validate_greater_serial: false,
}
}
}
impl Default for ApiArgs { impl Default for ApiArgs {
fn default() -> Self { fn default() -> Self {
Self { Self {
address: SocketAddr::from(([127, 0, 0, 1], 3000)), address: SocketAddr::from(([127, 0, 0, 1], 3000)),
cert_dir: "certs".into(), cert_dir: "certs".into(),
ca: "certs/ca.pub".into(), ca: "certs/ca.pub".into(),
validation_args: Default::default(),
} }
} }
} }
@ -51,12 +77,14 @@ struct ApiState {
certs: Arc<Mutex<HashMap<String, Certificate>>>, certs: Arc<Mutex<HashMap<String, Certificate>>>,
cert_dir: PathBuf, cert_dir: PathBuf,
ca: PublicKey, ca: PublicKey,
validation_args: CertificateValidationArgs,
} }
impl ApiState { impl ApiState {
async fn new( async fn new(
cert_dir: impl AsRef<path::Path>, cert_dir: impl AsRef<path::Path>,
ca_file: impl AsRef<path::Path>, ca_file: impl AsRef<path::Path>,
validation_args: CertificateValidationArgs,
) -> anyhow::Result<Self> { ) -> anyhow::Result<Self> {
let ca = read_pubkey(ca_file.as_ref()).await?; let ca = read_pubkey(ca_file.as_ref()).await?;
let certs = read_certs(&ca, cert_dir.as_ref()).await?; let certs = read_certs(&ca, cert_dir.as_ref()).await?;
@ -69,6 +97,7 @@ impl ApiState {
)), )),
cert_dir: cert_dir.as_ref().into(), cert_dir: cert_dir.as_ref().into(),
ca, ca,
validation_args,
}) })
} }
} }
@ -78,9 +107,10 @@ pub async fn run(
address, address,
cert_dir, cert_dir,
ca, ca,
validation_args,
}: ApiArgs, }: ApiArgs,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let state = ApiState::new(&cert_dir, &ca).await?; let state = ApiState::new(&cert_dir, &ca, validation_args).await?;
#[cfg(feature = "reload")] #[cfg(feature = "reload")]
{ {
@ -104,6 +134,7 @@ pub async fn run(
.typed_get(get_certs_identifier) .typed_get(get_certs_identifier)
.typed_put(put_cert_update) .typed_put(put_cert_update)
.route("/certs/:identifier", post(post_certs_identifier)) .route("/certs/:identifier", post(post_certs_identifier))
.fallback(fallback_404)
.layer(ServiceBuilder::new().map_request_body(body::boxed)) .layer(ServiceBuilder::new().map_request_body(body::boxed))
.with_state(state); .with_state(state);
@ -117,36 +148,38 @@ pub async fn run(
Ok(()) Ok(())
} }
#[derive(Debug, thiserror::Error)]
pub enum ApiError { pub enum ApiError {
Internal, #[error("internal server error")]
NotFound, Internal(#[from] anyhow::Error),
Invalid, #[error("certificate not found")]
CertificateNotFound,
#[error("invalid certificate")]
CertificateInvalid,
#[error("serial must be greater than {1}")]
LowSerial(u64, u64), LowSerial(u64, u64),
#[error("expiry date must be greater than {0:?}")]
InsufficientValidity(SystemTime),
} }
type ApiResult<T> = Result<T, ApiError>; type ApiResult<T> = Result<T, ApiError>;
impl IntoResponse for ApiError { impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { (
Self::NotFound => (StatusCode::NOT_FOUND, "not here").into_response(), match self {
Self::LowSerial(prev, next) => ( Self::CertificateNotFound => StatusCode::NOT_FOUND,
StatusCode::BAD_REQUEST, Self::LowSerial(_, _) | Self::InsufficientValidity(_) => StatusCode::BAD_REQUEST,
format!( _ => StatusCode::INTERNAL_SERVER_ERROR,
"new certificate serial must be greater than {}, got {}", },
prev, next self.to_string(),
), )
) .into_response()
.into_response(),
_ => (StatusCode::INTERNAL_SERVER_ERROR, "Oops").into_response(),
}
} }
} }
impl From<anyhow::Error> for ApiError { async fn fallback_404() -> ApiResult<()> {
fn from(_: anyhow::Error) -> Self { Err(ApiError::CertificateNotFound)
ApiError::Internal
}
} }
#[derive(TypedPath, Deserialize)] #[derive(TypedPath, Deserialize)]
@ -159,17 +192,21 @@ pub struct GetCert {
/// TODO: add option to require auth /// TODO: add option to require auth
/// return Unauthorized with an challenge /// return Unauthorized with an challenge
/// upon which the client will ssh-keysign /// upon which the client will ssh-keysign
/// the challene an issue an post request /// the challenge an issue an post request
#[instrument(skip_all, ret)]
async fn get_certs_identifier( async fn get_certs_identifier(
GetCert { identifier }: GetCert, GetCert { identifier }: GetCert,
State(ApiState { certs, .. }): State<ApiState>, State(ApiState { certs, .. }): State<ApiState>,
) -> ApiResult<String> { ) -> ApiResult<String> {
let certs = certs.lock().await; let certs = certs.lock().await;
let cert = certs.get(&identifier).ok_or(ApiError::NotFound)?; let cert = certs
.get(&identifier)
.ok_or(ApiError::CertificateNotFound)?;
Ok(cert.to_openssh().context("to openssh")?) Ok(cert.to_openssh().context("to openssh")?)
} }
/// POST with signed challenge /// POST with signed challenge
#[instrument(skip_all, ret)]
async fn post_certs_identifier( async fn post_certs_identifier(
State(ApiState { .. }): State<ApiState>, State(ApiState { .. }): State<ApiState>,
Path(_identifier): Path<String>, Path(_identifier): Path<String>,
@ -182,27 +219,40 @@ async fn post_certs_identifier(
pub struct PutCert; pub struct PutCert;
/// Upload an cert with an higher serial than the previous /// Upload an cert with an higher serial than the previous
#[instrument(skip_all, ret)]
async fn put_cert_update( async fn put_cert_update(
_: PutCert, _: PutCert,
State(ApiState { State(ApiState {
ca, ca,
cert_dir, cert_dir,
certs, certs,
validation_args:
CertificateValidationArgs {
validate_greater_expiry,
validate_greater_serial,
},
.. ..
}): State<ApiState>, }): State<ApiState>,
CertificateBody(cert): CertificateBody, CertificateBody(cert): CertificateBody,
) -> ApiResult<String> { ) -> ApiResult<String> {
cert.validate(&[ca.fingerprint(Default::default())]) cert.validate(&[ca.fingerprint(Default::default())])
.map_err(|_| ApiError::Invalid)?; .map_err(|_| ApiError::CertificateInvalid)?;
let _string_repr = cert.to_openssh(); let _string_repr = cert.to_openssh();
let prev = load_cert_by_id(&cert_dir, &ca, &cert.key_id()).await?; let prev = load_cert_by_id(&cert_dir, &ca, cert.key_id()).await?;
let mut prev_serial = 0; let mut prev_serial = 0;
let serial = cert.serial(); let serial = cert.serial();
if let Some(prev) = prev { if let Some(prev) = prev {
let prev_exp = prev.valid_before();
let exp = cert.valid_before();
trace!(%prev_serial, %serial, ?prev_exp, ?exp, "comparing to previous certificate");
prev_serial = prev.serial(); prev_serial = prev.serial();
if prev.serial() >= cert.serial() { if validate_greater_serial && prev.serial() >= cert.serial() {
return Err(ApiError::LowSerial(prev_serial, serial)); return Err(ApiError::LowSerial(prev_serial, serial));
} }
// check if new certificate is valid for longer than the old one
if validate_greater_expiry && prev_exp >= exp {
return Err(ApiError::InsufficientValidity(prev.valid_before_time()));
}
} }
store_cert(&cert_dir, &ca, &cert).await?; store_cert(&cert_dir, &ca, &cert).await?;
certs.lock().await.insert(cert.key_id().to_string(), cert); certs.lock().await.insert(cert.key_id().to_string(), cert);

View File

@ -1,25 +1,27 @@
use anyhow::{bail, Context}; use anyhow::{bail};
use axum_extra::routing::TypedPath; use axum_extra::routing::TypedPath;
use clap::{Args, Parser, Subcommand}; use clap::{Args, Parser, Subcommand};
use reqwest::{Client, StatusCode}; use reqwest::{Client, StatusCode};
use ssh_key::Certificate; use ssh_key::Certificate;
use std::path::PathBuf; use std::path::PathBuf;
use std::time::{Duration, SystemTime};
use tokio::fs; use tokio::fs;
use tracing::{debug, error, info, instrument}; use tracing::{debug, error, info, instrument};
use tracing::{info_span, Instrument};
use url::Url; use url::Url;
use crate::api::PutCert; use crate::api::PutCert;
use crate::certs::load_cert; use crate::certs::load_cert;
use crate::env_key;
use crate::{ use crate::{
api::GetCert, api::GetCert,
certs::{self, read_dir}, certs::{read_dir},
}; };
#[derive(Parser)] #[derive(Parser)]
pub struct ClientArgs { pub struct ClientArgs {
/// Url for the API endpoint /// Url for the API endpoint
#[clap(short = 'a', long = "api-endpoint")] #[clap(short = 'a', long = "api-endpoint", env = env_key!("API"))]
api: Url, api: Url,
} }
@ -27,8 +29,11 @@ pub struct ClientArgs {
pub struct FetchArgs { pub struct FetchArgs {
#[clap(flatten)] #[clap(flatten)]
args: ClientArgs, args: ClientArgs,
#[clap(short = 'c', long = "cert-dir", default_value = "~/.ssh")] #[clap(short = 'c', long = "cert-dir", env = env_key!("CERT_DIR") )]
cert_dir: PathBuf, cert_dir: PathBuf,
/// minimum time in days between now and expiry to consider checking
#[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))]
min_delta_days: Option<u32>,
} }
#[derive(Parser)] #[derive(Parser)]
@ -36,6 +41,7 @@ pub struct UploadArgs {
#[clap(flatten)] #[clap(flatten)]
args: ClientArgs, args: ClientArgs,
/// Certificates to be uploaded /// Certificates to be uploaded
#[clap(env = env_key!("FILES"))]
files: Vec<PathBuf>, files: Vec<PathBuf>,
} }
@ -58,6 +64,9 @@ pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> {
} }
} }
#[derive(Debug, thiserror::Error)]
enum UploadError {}
async fn upload( async fn upload(
UploadArgs { UploadArgs {
args: ClientArgs { api }, args: ClientArgs { api },
@ -102,19 +111,29 @@ async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Res
async fn fetch( async fn fetch(
FetchArgs { FetchArgs {
cert_dir, cert_dir,
min_delta_days: min_delta,
args: ClientArgs { api }, args: ClientArgs { api },
}: FetchArgs, }: FetchArgs,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let certs = read_dir(&cert_dir).await?; let certs = read_dir(&cert_dir).await?;
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let updates = certs.into_iter().map(|cert| { let threshold_exp = min_delta.and_then(|min_delta| {
let path = GetCert { SystemTime::now().checked_sub(Duration::from_secs(60 * 60 * 24 * min_delta as u64))
identifier: cert.key_id().to_string(),
};
let url = api.join(path.to_uri().path()).unwrap();
let client = client.clone();
tokio::spawn(async move { fetch_cert(client, url, cert).await })
}); });
let updates = certs
.into_iter()
.filter(|cert| {
let exp = cert.valid_before_time();
threshold_exp.as_ref().map(|th| &exp < th).unwrap_or(true)
})
.map(|cert| {
let path = GetCert {
identifier: cert.key_id().to_string(),
};
let url = api.join(path.to_uri().path()).unwrap();
let client = client.clone();
tokio::spawn(async move { fetch_cert(client, url, cert).await })
});
for cert in updates { for cert in updates {
if let Ok(Some((cert, update))) = cert.await? { if let Ok(Some((cert, update))) = cert.await? {
fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?; fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?;

View File

@ -8,6 +8,13 @@ mod certs;
#[cfg(feature = "client")] #[cfg(feature = "client")]
mod client; mod client;
#[macro_export]
macro_rules! env_key {
( $var:expr ) => {
concat!("SSH_CD_", $var)
};
}
#[derive(Parser)] #[derive(Parser)]
enum Command { enum Command {
Server(ApiArgs), Server(ApiArgs),