diff --git a/Cargo.toml b/Cargo.toml index 84bfa08..25f9051 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,3 +33,6 @@ url = { version = "2.3.1", optional = true } [patch.crates-io] ssh-key = { git = "https://github.com/a-dma/SSH.git", branch = "u2f_signatures" } +[dev-dependencies] +tempfile = "3.3.0" + diff --git a/src/api.rs b/src/api.rs index 7269dec..24fb24e 100644 --- a/src/api.rs +++ b/src/api.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::net::SocketAddr; use std::path::{self, PathBuf}; use std::sync::Arc; -use std::time::SystemTime; +use std::time::{Duration, SystemTime}; use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert}; use crate::env_key; @@ -19,7 +19,8 @@ use axum_extra::routing::{ }; use clap::{Args, Parser}; use serde::Deserialize; -use ssh_key::{Certificate, PublicKey}; +use ssh_key::private::Ed25519Keypair; +use ssh_key::{certificate, Certificate, PrivateKey, PublicKey}; use tokio::sync::Mutex; use tower::ServiceBuilder; use tower_http::{trace::TraceLayer, ServiceBuilderExt}; @@ -298,3 +299,110 @@ async fn put_cert_update( certs.lock().await.insert(cert.key_id().to_string(), cert); Ok(format!("{} -> {}", prev_serial, serial)) } + +#[cfg(test)] +mod tests { + use std::env::temp_dir; + + use super::*; + + fn ca_key() -> Ed25519Keypair { + Ed25519Keypair::from_seed(&[0u8; 32]) + } + + fn ca_key2() -> Ed25519Keypair { + Ed25519Keypair::from_seed(&[10u8; 32]) + } + + fn ca_pub() -> PublicKey { + PublicKey::new(ca_key().public.into(), "TEST CA") + } + + fn user_key() -> Ed25519Keypair { + Ed25519Keypair::from_seed(&[1u8; 32]) + } + + fn user_cert(ca: Ed25519Keypair, user_key: PublicKey) -> Certificate { + let ca_private: PrivateKey = ca.into(); + let unix_time = |time: SystemTime| -> u64 { + time.duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + }; + let mut builder = certificate::Builder::new( + [0u8; 16], + user_key, + unix_time(SystemTime::now()), + unix_time(SystemTime::now() + Duration::from_secs(30)), + ); + + builder + .valid_principal("git") + .unwrap() + .key_id("test_cert") + .unwrap() + .comment("A TEST CERT") + .unwrap(); + + builder.sign(&ca_private).unwrap() + } + + fn api_state() -> ApiState { + let ca: PublicKey = ca_pub(); + ApiState { + ca, + certs: Default::default(), + cert_dir: dbg!(temp_dir()), + validation_args: Default::default(), + } + } + + #[test] + fn test_certificate() { + let valid_cert = user_cert(ca_key(), user_key().public.into()); + let ca_pub: PublicKey = ca_pub(); + assert!(valid_cert + .validate(&[ca_pub.fingerprint(Default::default())]) + .is_ok()); + } + + #[tokio::test] + async fn routes() -> anyhow::Result<()> { + let state = api_state(); + let valid_cert = user_cert(ca_key(), user_key().public.into()); + let invalid_cert = user_cert(ca_key2(), user_key().public.into()); + let res = put_cert_update( + PutCert, + State(state.clone()), + CertificateBody(valid_cert.clone()), + ) + .await; + assert!(dbg!(res).is_ok()); + assert_eq!(state.certs.lock().await.get("test_cert"), Some(&valid_cert)); + let res = put_cert_update( + PutCert, + State(state.clone()), + CertificateBody(invalid_cert.clone()), + ) + .await; + assert!(matches!(res, Err(ApiError::CertificateInvalid))); + + let cert = get_certs_identifier( + GetCert { + identifier: "test_cert".into(), + }, + State(state.clone()), + ) + .await?; + assert_eq!(cert, valid_cert.to_openssh()?); + let res = get_certs_identifier( + GetCert { + identifier: "missing_cert".into(), + }, + State(state.clone()), + ) + .await; + assert!(matches!(res, Err(ApiError::CertificateNotFound))); + Ok(()) + } +} diff --git a/src/certs.rs b/src/certs.rs index 2e0a26c..67cff7b 100644 --- a/src/certs.rs +++ b/src/certs.rs @@ -8,11 +8,19 @@ use std::{ use tokio::fs; use tracing::{instrument, trace}; +#[derive(Debug, thiserror::Error)] +pub enum CertError { + #[error("missing key id")] + NoKID, + #[error("missing ca identifier (comment)")] + NoCAComment, +} + pub async fn read_certs( ca: &PublicKey, path: impl AsRef, ) -> anyhow::Result> { - let ca_dir = path.as_ref().join(ca_dir(ca)); + let ca_dir = path.as_ref().join(ca_dir(ca)?); if !ca_dir.exists() { return Ok(Vec::new()); } @@ -60,13 +68,19 @@ pub async fn read_pubkey(path: impl AsRef) -> anyhow::Result { .with_context(|| format!("parse '{}' as public key", string_repr)) } -fn ca_dir(ca: &PublicKey) -> String { - ca.comment().to_string() +fn ca_dir(ca: &PublicKey) -> Result { + if ca.comment().is_empty() { + return Err(CertError::NoCAComment); + } + Ok(ca.comment().to_string()) } #[instrument] -fn cert_path(ca: &PublicKey, identifier: &str) -> String { - format!("{}/{}-cert.pub", ca_dir(ca), identifier) +fn cert_path(ca: &PublicKey, identifier: &str) -> Result { + if identifier.is_empty() { + return Err(CertError::NoKID); + } + Ok(format!("{}/{}-cert.pub", ca_dir(ca)?, identifier)) } #[instrument] @@ -76,11 +90,15 @@ pub async fn store_cert( cert: &Certificate, ) -> anyhow::Result { // TODO: proper store - let path = cert_dir.as_ref().join(cert_path(&ca, cert.key_id())); + let path = cert_dir.as_ref().join(cert_path(&ca, cert.key_id())?); if let Some(parent) = path.parent() { - fs::create_dir_all(parent).await?; + fs::create_dir_all(parent) + .await + .with_context(|| format!("mkdir -p {parent:?}"))?; } - fs::write(&path, cert.to_openssh().context("encode cert")?).await?; + fs::write(&path, cert.to_openssh().context("encode cert")?) + .await + .context("write cert")?; Ok(path) } @@ -89,7 +107,7 @@ pub async fn load_cert_by_id( ca: &PublicKey, identifier: &str, ) -> anyhow::Result> { - let path = cert_dir.as_ref().join(cert_path(ca, identifier)); + let path = cert_dir.as_ref().join(cert_path(ca, identifier)?); load_cert(&path).await }