use anyhow::{bail, Context}; use axum_extra::routing::TypedPath; use clap::{CommandFactory, Parser, Subcommand, ValueHint}; use clap_complete_command::Shell; use reqwest::{Client, StatusCode}; use ssh_key::Certificate; use std::path::PathBuf; use std::process; use std::time::{Duration, SystemTime}; use tokio::fs; use tokio::io::{stdin, AsyncBufReadExt, BufReader}; use tracing::{debug, error, info, instrument, trace}; use url::Url; use ssh_cert_dist_common::*; #[derive(Parser)] pub struct ClientArgs { /// Url for the API endpoint #[clap(short = 'a', long = "api-endpoint",value_hint = ValueHint::Url, env = env_key!("API"))] api: Url, /// Require interaction before writing certificates #[clap(short = 'i', long = "interactive", env = env_key!("INTERACTIVE"))] interactive: bool, } #[derive(Parser)] pub struct FetchArgs { #[clap(flatten)] args: ClientArgs, #[clap(short = 'k', long = "key-update", env = env_key!("KEY_UPDATE"))] prohibit_key_update: bool, #[clap(short = 'c', long = "cert-dir",value_hint = ValueHint::DirPath, env = env_key!("CERT_DIR"))] cert_dir: PathBuf, /// minimum time in days between now and expiry to consider checking #[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))] min_delta_days: Option, } #[derive(Parser)] pub struct UploadArgs { #[clap(flatten)] args: ClientArgs, /// Certificates to be uploaded #[clap(value_hint = ValueHint::FilePath, env = env_key!("FILES"))] files: Vec, } #[derive(Parser)] pub struct RenewCommandArgs { /// Execute the renew command #[clap(short = 'x')] execute: bool, /// Path to the CA private key #[clap(long="ca",value_hint = ValueHint::DirPath, env = env_key!("CA_KEY"))] ca_key: Option, /// Certificates to generate commands for #[clap(value_hint = ValueHint::FilePath,env = env_key!("FILES"))] files: Vec, } #[derive(Parser)] pub struct ClientCommand { #[clap(subcommand)] cmd: ClientCommands, } #[derive(Subcommand)] pub enum ClientCommands { Fetch(FetchArgs), Upload(UploadArgs), RenewCommand(RenewCommandArgs), #[clap(hide = true)] Completions { #[arg(long = "shell", value_enum)] shell: Shell, } } pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> { match cmd { ClientCommands::Fetch(args) => fetch(args).await, ClientCommands::Upload(args) => upload(args).await, ClientCommands::RenewCommand(args) => renew(args).await, ClientCommands::Completions { shell } => { shell.generate(&mut ClientCommand::command(), &mut std::io::stdout()); Ok(()) } } } #[derive(Debug, thiserror::Error)] enum UploadError {} async fn upload( UploadArgs { args: ClientArgs { api, .. }, files, }: UploadArgs, ) -> anyhow::Result<()> { let client = reqwest::Client::new(); let mut certs = Vec::new(); for path in files.into_iter() { if let Some(cert) = load_cert(&path).await? { certs.push(cert); } } let uploads = certs.into_iter().map(|cert| { let client = client.clone(); let path = PutCert; let url = api.join(path.to_uri().path()).unwrap(); tokio::spawn(async move { upload_cert(client, url, cert).await }) }); for upload in uploads { let _ = upload.await; } Ok(()) } #[instrument(skip(client, cert), ret)] async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Result<()> { let resp = client .put(url.clone()) .body(cert.to_openssh()?) .send() .await?; let status = resp.status(); if ![StatusCode::OK, StatusCode::CREATED].contains(&status) { let id = cert.key_id(); error!(%id, %status, "failed to upload cert"); bail!("failed to upload {id}, error code {status}"); } Ok(()) } async fn fetch( FetchArgs { cert_dir, prohibit_key_update, min_delta_days: min_delta, args: ClientArgs { api, interactive }, }: FetchArgs, ) -> anyhow::Result<()> { let certs = read_certs_dir(&cert_dir).await?; // let publics_keys = read_pubkey_dir(&cert_dir).await?; let client = reqwest::Client::new(); let threshold_exp = min_delta.and_then(|min_delta| { SystemTime::now().checked_add(Duration::from_secs(60 * 60 * 24 * min_delta as u64)) }); // let standalone_certs = publics_keys.into_iter().map(|(name, key)| ) let updates = certs .into_iter() .filter(|cert| { let exp = cert.valid_before_time(); let must_check = threshold_exp.as_ref().map(|th| &exp < th); trace!(?cert, ?must_check, "filter"); must_check.unwrap_or(true) }) .map(|cert| { let path = GetCert { identifier: cert.key_id().to_string(), }; let url = api.join(path.to_uri().path()).unwrap(); let client = client.clone(); tokio::spawn(async move { fetch_cert(client, url, cert).await }) }); let mut stdin = BufReader::new(stdin()).lines(); for cert in updates { if let Ok(Some((cert, update))) = cert.await? { if prohibit_key_update && cert.public_key() != update.public_key() { debug!(?update, "skipping cert due to key change"); continue; } if interactive { println!("certificate update: {}", cert.key_id()); println!( "principals: {:?}, expiry: {}", update.valid_principals(), update.valid_before() ); println!("update? : (y/n)"); let yes = stdin.next_line().await?; if !matches!(yes, Some(line) if line.starts_with(['y', 'Y'])) { break; } } fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?; let key_id = cert.key_id(); info!( %key_id, "updated certificate", ); } } Ok(()) } async fn renew( RenewCommandArgs { files, ca_key, execute, }: RenewCommandArgs, ) -> anyhow::Result<()> { for file in files.iter() { let cert = load_cert(&file).await?; if let Some(cert) = cert { let command = renew_command( &cert, ca_key .as_deref() .map(|path| path.to_str()) .flatten() .unwrap_or("ca"), file.to_str(), ); println!("{}", command); if execute { process::Command::new("sh") .arg("-c") .arg(&command) .spawn() .with_context(|| format!("{command}"))?; } } else { bail!("{file:?} doesn't exist"); } } Ok(()) } #[instrument(skip(client, current))] async fn fetch_cert( client: Client, url: Url, current: Certificate, ) -> anyhow::Result> { debug!("checking {}", current.key_id()); let resp = client.get(url.clone()).send().await?; if resp.status() != StatusCode::OK { return Ok(None); } let string_repr = resp.text().await?; let remote_cert = Certificate::from_openssh(&string_repr)?; if remote_cert .validate(&[current.signature_key().fingerprint(Default::default())]) .is_err() { info!("invalid signature {}, skipping", &url); return Ok(None); } if current.serial() >= remote_cert.serial() { debug!("{} is not newer than local version", &url); return Ok(None); } Ok(Some((current, remote_cert))) }