2023-03-10 10:38:31 +01:00

596 lines
18 KiB
Rust

mod extract;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::path::{self, PathBuf};
use std::sync::Arc;
use std::time::SystemTime;
use anyhow::Context;
use axum::body;
use axum::extract::{Query, State};
use chrono::Duration;
use ssh_cert_dist_common::*;
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
use axum_extra::routing::RouterExt;
use clap::{Args, Parser};
use jwt_compact::alg::{Hs256, Hs256Key};
use jwt_compact::{AlgorithmExt, Token, UntrustedToken};
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use ssh_key::{Certificate, Fingerprint, PublicKey};
use tokio::sync::Mutex;
use tower::ServiceBuilder;
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
use tracing::{debug, info, trace};
use self::extract::{CertificateBody, SignatureBody};
// Command line / environment configuration of the certificate distribution API server.
// (Plain comment on purpose: a `///` doc comment on a clap `Parser` struct becomes its `about`.)
#[derive(Parser)]
pub struct ApiArgs {
    /// Socket address the HTTP API listens on
    #[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))]
    address: SocketAddr,
    /// Directory certificates are loaded from and stored in
    #[clap(short = 'c', long = "cert-store", env = env_key!("CERT_DIR"))]
    cert_dir: PathBuf,
    #[clap(flatten)]
    validation_args: CertificateValidationArgs,
    /// CA public key
    #[clap(long = "ca", env = env_key!("CA"))]
    ca: PathBuf,
}
// Checks `put_cert_update` applies before an uploaded certificate replaces a stored one with the
// same key id. (Plain comments here on purpose: `///` on a clap `Args` struct is used as help
// text.)
// NOTE(review): as plain clap `bool` flags both checks are off unless given on the command line
// or via env, whereas `Default::default()` enables the expiry check — confirm which default is
// intended.
#[derive(Debug, Args, Copy, Clone)]
pub struct CertificateValidationArgs {
    /// Check whether a certificate update contains an expiry date further in the future than the
    /// existing certificate
    #[clap(short = 'e', long = "validate-expiry", env = env_key!("VALIDATE_EXPIRY"))]
    validate_greater_expiry: bool,
    /// Check whether a certificate update contains a greater serial number than the existing
    /// certificate
    #[clap(short = 's', long = "validate-serial", env = env_key!("VALIDATE_SERIAL"))]
    validate_greater_serial: bool,
}
impl Default for CertificateValidationArgs {
fn default() -> Self {
Self {
validate_greater_expiry: true,
validate_greater_serial: false,
}
}
}
impl Default for ApiArgs {
fn default() -> Self {
Self {
address: SocketAddr::from(([127, 0, 0, 1], 3000)),
cert_dir: "certs".into(),
ca: "certs/ca.pub".into(),
validation_args: Default::default(),
}
}
}
/// State shared by all request handlers; handlers receive a clone through axum's `State`
/// extractor, and the certificate cache is shared between clones through the `Arc`.
#[derive(Debug, Clone)]
struct ApiState {
/// In-memory certificate cache keyed by certificate key id.
certs: Arc<Mutex<HashMap<String, Certificate>>>,
/// Directory certificates are read from and written to.
cert_dir: PathBuf,
/// CA public key; uploaded certificates are validated against its fingerprint.
ca: PublicKey,
/// When true, `get_certs_identifier` answers with a signed challenge instead of the certificate.
client_auth: bool,
/// Checks applied by `put_cert_update` when a certificate is replaced.
validation_args: CertificateValidationArgs,
/// HMAC-SHA256 key used to sign and verify client-auth challenge tokens.
jwt_key: Hs256Key,
}
impl ApiState {
async fn new(
cert_dir: impl AsRef<path::Path>,
ca_file: impl AsRef<path::Path>,
validation_args: CertificateValidationArgs,
) -> anyhow::Result<Self> {
let ca = read_pubkey(ca_file.as_ref()).await?;
let certs = read_certs(&ca, cert_dir.as_ref()).await?;
Ok(Self {
certs: Arc::new(Mutex::new(
certs
.into_iter()
.map(|cert| (cert.key_id().to_string(), cert))
.collect(),
)),
cert_dir: cert_dir.as_ref().into(),
ca,
client_auth: false,
validation_args,
jwt_key: Hs256Key::new(thread_rng().gen::<[u8; 16]>()),
})
}
}
/// Runs the certificate distribution HTTP API until the server stops.
///
/// Loads the CA key and the certificate store, registers all routes and serves them on
/// `address`.
///
/// # Errors
///
/// Returns an error if the CA key or the certificates cannot be loaded, if `address` cannot be
/// bound, or if the server terminates with an error.
pub async fn run(
    ApiArgs {
        address,
        cert_dir,
        ca,
        validation_args,
    }: ApiArgs,
) -> anyhow::Result<()> {
    let state = ApiState::new(&cert_dir, &ca, validation_args).await?;
    #[cfg(feature = "reload")]
    {
        let state = state.clone();
        // Re-read the certificate directory every 30s so changes made to it outside this
        // process are picked up without a restart.
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(30)).await;
                match read_certs(&state.ca, &state.cert_dir).await {
                    Ok(certs) => {
                        *state.certs.lock().await = certs
                            .into_iter()
                            .map(|cert| (cert.key_id().to_string(), cert))
                            .collect();
                        trace!("reloaded certs");
                    }
                    Err(err) => {
                        // Keep serving the previously loaded certificates, but surface the failure.
                        tracing::warn!(error = ?err, "failed to reload certs");
                    }
                }
            }
        });
    }
    let app = Router::new()
        .typed_get(get_certs_identifier)
        .typed_get(get_certs_pubkey)
        .typed_put(put_cert_update)
        .typed_get(get_cert_info)
        .typed_post(post_certs_identifier);
    #[cfg(feature = "index")]
    let app = app.typed_get(list_certs);
    let app = app
        .fallback(fallback_404)
        .layer(ServiceBuilder::new().map_request_body(body::boxed))
        .layer(TraceLayer::new_for_http())
        .with_state(state);
    debug!("listening on {}", address);
    // `axum::Server` is a re-export of `hyper::Server`; `try_bind` reports a bind failure
    // (e.g. address in use) as an error instead of panicking like `bind`.
    axum::Server::try_bind(&address)
        .with_context(|| format!("failed to bind {address}"))?
        .serve(app.into_make_service())
        .await
        .context("http server error")?;
    Ok(())
}
/// Errors returned by the API handlers; turned into HTTP responses by the `IntoResponse` impl.
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
    /// Unexpected failure; the cause is not included in the message.
    #[error("internal server error")]
    Internal(#[from] anyhow::Error),
    /// No certificate for the requested identifier (also returned for unknown routes).
    #[error("certificate not found")]
    CertificateNotFound,
    /// The uploaded certificate did not validate against the CA.
    #[error("invalid certificate")]
    CertificateInvalid,
    /// The uploaded serial is not greater than the stored one.
    /// Fields: `(previous serial, uploaded serial)`.
    #[error("serial must be greater than {0}")]
    LowSerial(u64, u64),
    /// The uploaded certificate does not expire later than the stored one, whose expiry is
    /// carried here.
    #[error("expiry date must be greater than {0:?}")]
    InsufficientValidity(SystemTime),
    /// Client authentication is enabled; carries the challenge token the client has to sign.
    #[error("authentication required")]
    AuthenticationRequired(String),
    /// A challenge token or SSH signature could not be verified.
    #[error("invalid ssh signature")]
    InvalidSignature,
}
/// Result type returned by the API handlers.
type ApiResult<T> = Result<T, ApiError>;
impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response {
(
match self {
Self::CertificateNotFound => StatusCode::NOT_FOUND,
Self::LowSerial(_, _) | Self::InsufficientValidity(_) => StatusCode::BAD_REQUEST,
Self::AuthenticationRequired(challenge) => {
return (StatusCode::UNAUTHORIZED, challenge).into_response()
}
_ => StatusCode::INTERNAL_SERVER_ERROR,
},
self.to_string(),
)
.into_response()
}
}
/// Fallback for unmatched routes: every unknown path is answered as a missing
/// certificate.
async fn fallback_404() -> ApiResult<()> {
    let not_found = ApiError::CertificateNotFound;
    Err(not_found)
}
/// Lists the key ids of all cached certificates.
///
/// The ids are sorted so the response is stable; the backing `HashMap` has no defined
/// iteration order.
#[cfg(feature = "index")]
async fn list_certs(
    _: CertList,
    State(ApiState { certs, .. }): State<ApiState>,
) -> ApiResult<Json<Vec<String>>> {
    let mut ids: Vec<String> = certs
        .lock()
        .await
        .values()
        .map(|cert| cert.key_id().to_string())
        .collect();
    ids.sort_unstable();
    Ok(Json(ids))
}
/// Custom JWT claims of a client-auth challenge (issued by `request_client_auth`, checked by
/// `post_certs_identifier`).
///
/// Serde's struct `tag` attribute adds an `"aud": "get"` field when serializing.
/// NOTE(review): whether deserialization rejects tokens with a different `aud` value depends on
/// serde's handling of struct tags — confirm before relying on it as an audience check.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "aud", rename = "get")]
struct AuthClaims {
/// Certificate identifier the challenge was issued for.
identifier: String,
}
/// Demands client authentication when `enabled`.
///
/// If enabled, this fails with `ApiError::AuthenticationRequired`. The error carries a freshly
/// signed, short-lived JWT challenge for `identifier`, which the client signs with its SSH key
/// and submits to `post_certs_identifier`. If disabled, it returns `Ok(())`.
///
/// # Errors
///
/// `AuthenticationRequired` when enabled; `Internal` if signing the token fails.
async fn request_client_auth(enabled: bool, identifier: &str, jwt_key: &Hs256Key) -> ApiResult<()> {
    use jwt_compact::{Claims, Header, TimeOptions};
    // How long an issued challenge remains valid.
    const CHALLENGE_VALIDITY_SECS: i64 = 120;
    if !enabled {
        return Ok(());
    }
    let claims = Claims::new(AuthClaims {
        identifier: identifier.into(),
    })
    .set_duration(
        &TimeOptions::default(),
        chrono::Duration::seconds(CHALLENGE_VALIDITY_SECS),
    );
    let challenge = Hs256
        .compact_token(Header::default(), &claims, jwt_key)
        .context("jwt sign")?;
    Err(ApiError::AuthenticationRequired(challenge))
}
/// `GET` the certificate stored under `identifier`, in OpenSSH format.
///
/// With client authentication enabled the certificate is not returned directly. Instead the
/// client receives a signed challenge (`AuthenticationRequired`), which it signs with its SSH
/// key and submits to `post_certs_identifier` to obtain the certificate.
async fn get_certs_identifier(
    GetCert { identifier }: GetCert,
    State(ApiState {
        certs,
        jwt_key,
        client_auth,
        ..
    }): State<ApiState>,
) -> ApiResult<String> {
    request_client_auth(client_auth, &identifier, &jwt_key).await?;
    let guard = certs.lock().await;
    let openssh = match guard.get(&identifier) {
        Some(found) => found.to_openssh().context("to openssh")?,
        None => return Err(ApiError::CertificateNotFound),
    };
    Ok(openssh)
}
/// Lists the key ids of all cached certificates whose public key has fingerprint `pubkey_hash`.
///
/// Each stored key is fingerprinted with the same hash algorithm as `pubkey_hash`. The ids are
/// sorted for a stable response. This endpoint does not call `request_client_auth`; it only
/// reveals key ids, never certificates.
async fn get_certs_pubkey(
    GetCertsPubkey { pubkey_hash }: GetCertsPubkey,
    State(ApiState { certs, .. }): State<ApiState>,
) -> ApiResult<Json<CertIds>> {
    let algorithm = pubkey_hash.algorithm();
    let mut ids: Vec<String> = certs
        .lock()
        .await
        .values()
        .filter(|cert| cert.public_key().fingerprint(algorithm) == pubkey_hash)
        .map(|cert| cert.key_id().to_string())
        .collect();
    ids.sort_unstable();
    Ok(Json(CertIds { ids }))
}
/// JSON summary of a stored certificate, returned by `get_cert_info`.
#[cfg(feature = "info")]
#[derive(Debug, Serialize)]
struct CertInfo {
    principals: Vec<String>,
    ca: PublicKey,
    ca_hash: Fingerprint,
    identity: PublicKey,
    identity_hash: Fingerprint,
    key_id: String,
    expiry: SystemTime,
    /// Suggested `ssh-keygen` invocation that re-issues a certificate with the same key id,
    /// principals, type, critical options and (truncated) validity length. `./ca_key` is a
    /// placeholder, and the public key file to sign still has to be appended.
    renew_command: String,
}
// Gated like `CertInfo` itself; otherwise the crate does not build without the `info` feature.
#[cfg(feature = "info")]
impl From<&Certificate> for CertInfo {
    fn from(cert: &Certificate) -> Self {
        CertInfo {
            principals: cert.valid_principals().to_vec(),
            ca: cert.signature_key().clone().into(),
            ca_hash: cert.signature_key().fingerprint(ssh_key::HashAlg::Sha256),
            identity: cert.public_key().clone().into(),
            identity_hash: cert.public_key().fingerprint(ssh_key::HashAlg::Sha256),
            key_id: cert.key_id().to_string(),
            expiry: cert.valid_before_time(),
            renew_command: renew_command(cert),
        }
    }
}
/// Builds the `ssh-keygen` command line suggested for renewing `cert`.
#[cfg(feature = "info")]
fn renew_command(cert: &Certificate) -> String {
    let mut args: Vec<String> = vec!["ssh-keygen".into(), "-s".into(), "./ca_key".into()];
    if cert.cert_type().is_host() {
        args.push("-h".into());
    }
    args.push("-I".into());
    args.push(shell_quote(cert.key_id()));
    // A bare `-n` would consume the next flag as its value. Leaving `-n` out also yields an
    // empty principal list, matching the original certificate.
    let principals = cert.valid_principals();
    if !principals.is_empty() {
        args.push("-n".into());
        args.push(shell_quote(&principals.join(",")));
    }
    let validity = cert
        .valid_before_time()
        .duration_since(cert.valid_after_time())
        .unwrap_or_default();
    let validity_arg = match Duration::from_std(validity) {
        Ok(validity) => relative_validity(validity),
        // Only spans too long for chrono to represent end up here; presumably "forever" certs.
        Err(_) => "always:forever".to_string(),
    };
    args.push("-V".into());
    args.push(validity_arg);
    for (name, value) in cert.critical_options().iter() {
        let option = if value.is_empty() {
            name.clone()
        } else {
            format!("{name}={value}")
        };
        args.push("-O".into());
        args.push(shell_quote(&option));
    }
    args.join(" ")
}
/// Formats a validity length as an `ssh-keygen -V` relative interval (`+<n><unit>`; the `+` is
/// mandatory). The length is truncated to the largest unit with a non-zero count, so
/// sub-day certificates do not become `+0d`.
#[cfg(feature = "info")]
fn relative_validity(validity: Duration) -> String {
    let (count, unit) = if validity.num_days() > 0 {
        (validity.num_days(), 'd')
    } else if validity.num_hours() > 0 {
        (validity.num_hours(), 'h')
    } else if validity.num_minutes() > 0 {
        (validity.num_minutes(), 'm')
    } else {
        (validity.num_seconds().max(0), 's')
    };
    format!("+{count}{unit}")
}
/// Quotes `arg` for a POSIX shell unless it only contains characters that need no quoting.
#[cfg(feature = "info")]
fn shell_quote(arg: &str) -> String {
    let is_plain = !arg.is_empty()
        && arg
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || "-_.,:/@=+%".contains(c));
    if is_plain {
        arg.to_string()
    } else {
        format!("'{}'", arg.replace('\'', r"'\''"))
    }
}
/// `GET` a JSON summary (`CertInfo`) of the certificate stored under `identifier`.
#[cfg(feature = "info")]
async fn get_cert_info(
    GetCertInfo { identifier }: GetCertInfo,
    State(ApiState { certs, .. }): State<ApiState>,
) -> ApiResult<Json<CertInfo>> {
    let guard = certs.lock().await;
    let info = match guard.get(&identifier) {
        Some(found) => CertInfo::from(found),
        None => return Err(ApiError::CertificateNotFound),
    };
    Ok(Json(info))
}
#[cfg(not(feature = "info"))]
async fn get_cert_info(
GetCertInfo { identifier: _ }: GetCertInfo,
State(ApiState { certs: _, .. }): State<ApiState>,
) -> ApiResult<()> {
unimplemented!()
}
/// Query string of `post_certs_identifier`.
#[derive(Debug, Deserialize)]
struct PostCertsQuery {
/// JWT challenge previously returned by `get_certs_identifier` via `request_client_auth`.
challenge: String,
}
/// POST with signed challenge
async fn post_certs_identifier(
PostCertInfo { identifier }: PostCertInfo,
State(ApiState { certs, jwt_key, .. }): State<ApiState>,
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
SignatureBody(sig): SignatureBody,
) -> ApiResult<String> {
let certs = certs.lock().await;
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
let token: Token<AuthClaims> = Hs256
.validate_integrity(
&UntrustedToken::new(&challenge).context("jwt parse")?,
&jwt_key,
)
.map_err(|_| ApiError::InvalidSignature)?;
if token.claims().custom.identifier != identifier {
return Err(ApiError::InvalidSignature);
}
let pubkey: PublicKey = cert.public_key().clone().into();
let verification = tokio::task::spawn_blocking(move || {
pubkey
.verify(&identifier, challenge.as_bytes(), &sig)
.map_err(|_| ApiError::InvalidSignature)
})
.await
.context("tokio blocking")?;
verification?;
Ok(cert.to_openssh().context("to openssh")?)
}
/// `PUT` a certificate, replacing any stored certificate with the same key id.
///
/// The certificate must validate against the configured CA. If a certificate with the same
/// key id is already stored, the enabled `CertificateValidationArgs` checks apply:
/// - `validate_greater_serial`: the serial must be strictly greater.
/// - `validate_greater_expiry`: the expiry must be strictly later.
///
/// On success the certificate is written to `cert_dir` and inserted into the in-memory cache.
/// The response is `"<previous serial> -> <new serial>"`, with the previous serial 0 when there
/// was none.
async fn put_cert_update(
    _: PutCert,
    State(ApiState {
        ca,
        cert_dir,
        certs,
        validation_args:
            CertificateValidationArgs {
                validate_greater_expiry,
                validate_greater_serial,
            },
        ..
    }): State<ApiState>,
    CertificateBody(cert): CertificateBody,
) -> ApiResult<String> {
    // Validation runs on tokio's blocking thread pool instead of the async executor.
    let cert = {
        let ca = ca.clone();
        tokio::task::spawn_blocking(move || -> ApiResult<Certificate> {
            cert.validate(&[ca.fingerprint(Default::default())])
                .map_err(|_| ApiError::CertificateInvalid)?;
            Ok(cert)
        })
        .await
        .context("signature verification")??
    };
    let serial = cert.serial();
    let mut prev_serial = 0;
    if let Some(prev) = load_cert_by_id(&cert_dir, &ca, cert.key_id()).await? {
        // Read the previous serial before tracing it, so the log shows the real value.
        prev_serial = prev.serial();
        let prev_exp = prev.valid_before();
        let exp = cert.valid_before();
        trace!(%prev_serial, %serial, ?prev_exp, ?exp, "comparing to previous certificate");
        if validate_greater_serial && prev_serial >= serial {
            return Err(ApiError::LowSerial(prev_serial, serial));
        }
        // check if new certificate is valid for longer than the old one
        if validate_greater_expiry && prev_exp >= exp {
            return Err(ApiError::InsufficientValidity(prev.valid_before_time()));
        }
    }
    store_cert(&cert_dir, &ca, &cert).await?;
    let principals = cert.valid_principals();
    let identity = cert.key_id();
    info!(%identity, ?principals, "updating certificate");
    certs.lock().await.insert(cert.key_id().to_string(), cert);
    Ok(format!("{prev_serial} -> {serial}"))
}
#[cfg(test)]
mod tests {
    use ssh_key::{certificate, private::Ed25519Keypair, PrivateKey};
    use std::env::temp_dir;
    use std::time::Duration;
    use super::*;
    /// Deterministic CA key pair used to sign the test certificates.
    fn ca_key() -> Ed25519Keypair {
        Ed25519Keypair::from_seed(&[0u8; 32])
    }
    /// A second CA, not trusted by `api_state`, used to produce invalid certificates.
    fn ca_key2() -> Ed25519Keypair {
        Ed25519Keypair::from_seed(&[10u8; 32])
    }
    /// Public half of `ca_key()`, as configured in the test API state.
    fn ca_pub() -> PublicKey {
        PublicKey::new(
            ca_key().public.into(),
            format!(
                "TEST CA {}",
                SystemTime::now()
                    .duration_since(SystemTime::UNIX_EPOCH)
                    .unwrap()
                    .as_secs()
            ),
        )
    }
    fn user_key() -> Ed25519Keypair {
        Ed25519Keypair::from_seed(&[1u8; 32])
    }
    /// Signs a `test_cert` user certificate for `user_key`, valid from now for `validity`.
    fn user_cert(ca: Ed25519Keypair, user_key: PublicKey, validity: Duration) -> Certificate {
        let ca_private: PrivateKey = ca.into();
        let unix_time = |time: SystemTime| -> u64 {
            time.duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs()
        };
        let mut builder = certificate::Builder::new(
            [0u8; 16],
            user_key,
            unix_time(SystemTime::now()),
            unix_time(SystemTime::now() + validity),
        );
        builder
            .valid_principal("git")
            .unwrap()
            .key_id("test_cert")
            .unwrap()
            .comment(&format!("A TEST CERT, VALID FOR {}s", validity.as_secs()))
            .unwrap();
        builder.sign(&ca_private).unwrap()
    }
    /// Builds a handler state whose certificate store is a fresh, empty directory unique to
    /// `test_name`, this process and this moment.
    ///
    /// All test certificates share the key id `test_cert`, and tests run in parallel. A shared
    /// directory, or one left over from an earlier run, would make the expiry check reject
    /// uploads depending on timing.
    fn api_state(test_name: &str) -> ApiState {
        let ca: PublicKey = ca_pub();
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let cert_dir = temp_dir().join(format!(
            "ssh-cert-dist-{test_name}-{}-{nanos}",
            std::process::id()
        ));
        std::fs::create_dir_all(&cert_dir).unwrap();
        ApiState {
            ca,
            certs: Default::default(),
            cert_dir,
            validation_args: Default::default(),
            client_auth: false,
            jwt_key: Hs256Key::new([0u8; 16]),
        }
    }
    #[test]
    fn test_certificate() {
        let valid_cert = user_cert(ca_key(), user_key().public.into(), Duration::from_secs(30));
        let ca_pub: PublicKey = ca_pub();
        assert!(valid_cert
            .validate(&[ca_pub.fingerprint(Default::default())])
            .is_ok());
    }
    #[tokio::test]
    async fn update_cert() {
        let state = api_state("update_cert");
        let ca = ca_key();
        let user: PublicKey = user_key().public.into();
        let (cert_first, cert_newer, cert_outdated) = {
            (
                user_cert(ca.clone(), user.clone(), Duration::from_secs(300)),
                user_cert(ca.clone(), user.clone(), Duration::from_secs(600)),
                user_cert(ca.clone(), user.clone(), Duration::from_secs(30)),
            )
        };
        let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_first)).await;
        assert!(res.is_ok(), "{res:?}");
        let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_newer)).await;
        assert!(res.is_ok(), "{res:?}");
        let res = put_cert_update(
            PutCert,
            State(state.clone()),
            CertificateBody(cert_outdated),
        )
        .await;
        assert!(
            matches!(res, Err(ApiError::InsufficientValidity(_))),
            "{res:?}"
        );
        let _ = std::fs::remove_dir_all(&state.cert_dir);
    }
    #[tokio::test]
    async fn routes() -> anyhow::Result<()> {
        let state = api_state("routes");
        let valid_cert = user_cert(ca_key(), user_key().public.into(), Duration::from_secs(30));
        let invalid_cert = user_cert(ca_key2(), user_key().public.into(), Duration::from_secs(30));
        let res = put_cert_update(
            PutCert,
            State(state.clone()),
            CertificateBody(valid_cert.clone()),
        )
        .await;
        assert!(res.is_ok(), "{res:?}");
        assert_eq!(state.certs.lock().await.get("test_cert"), Some(&valid_cert));
        let res = put_cert_update(
            PutCert,
            State(state.clone()),
            CertificateBody(invalid_cert.clone()),
        )
        .await;
        assert!(matches!(res, Err(ApiError::CertificateInvalid)), "{res:?}");
        let cert = get_certs_identifier(
            GetCert {
                identifier: "test_cert".into(),
            },
            State(state.clone()),
        )
        .await?;
        assert_eq!(cert, valid_cert.to_openssh()?);
        let res = get_certs_identifier(
            GetCert {
                identifier: "missing_cert".into(),
            },
            State(state.clone()),
        )
        .await;
        assert!(matches!(res, Err(ApiError::CertificateNotFound)), "{res:?}");
        let state = ApiState {
            client_auth: true,
            ..state
        };
        let res = get_certs_identifier(
            GetCert {
                identifier: "test_cert".into(),
            },
            State(state.clone()),
        )
        .await;
        let challenge = match res {
            Err(ApiError::AuthenticationRequired(challenge)) => challenge,
            other => panic!("expected an authentication challenge, got {other:?}"),
        };
        let signing_key: PrivateKey = user_key().into();
        let sig = signing_key.sign("test_cert", Default::default(), challenge.as_bytes())?;
        let cert = post_certs_identifier(
            PostCertInfo {
                identifier: "test_cert".into(),
            },
            State(state.clone()),
            Query(PostCertsQuery { challenge }),
            SignatureBody(sig),
        )
        .await?;
        assert_eq!(cert, valid_cert.to_openssh()?);
        let _ = std::fs::remove_dir_all(&state.cert_dir);
        Ok(())
    }
}