shimun 9f6a5e03c9
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
feat(complete): add command name
2023-07-12 22:52:59 +02:00

260 lines
7.9 KiB
Rust

use anyhow::{bail, Context};
use axum_extra::routing::TypedPath;
use clap::{CommandFactory, Parser, Subcommand, ValueHint};
use clap_complete_command::Shell;
use reqwest::{Client, StatusCode};
use ssh_key::Certificate;
use std::path::PathBuf;
use std::process;
use std::time::{Duration, SystemTime};
use tokio::fs;
use tokio::io::{stdin, AsyncBufReadExt, BufReader};
use tracing::{debug, error, info, instrument, trace};
use url::Url;
use ssh_cert_dist_common::*;
// Options shared by the subcommands that talk to the API. Embedded into
// `FetchArgs` and `UploadArgs` through `#[clap(flatten)]`.
// (Plain `//` comment on purpose: a `///` here would be picked up by clap.)
#[derive(Parser)]
pub struct ClientArgs {
    /// Url for the API endpoint
    #[clap(
        short = 'a',
        long = "api-endpoint",
        value_hint = ValueHint::Url,
        env = env_key!("API")
    )]
    api: Url,
    /// Require interaction before writing certificates
    #[clap(
        short = 'i',
        long = "interactive",
        env = env_key!("INTERACTIVE")
    )]
    interactive: bool,
}
// Arguments of the `fetch` subcommand.
// (Plain `//` comment on purpose: a `///` here would be picked up by clap.)
#[derive(Parser)]
pub struct FetchArgs {
    #[clap(flatten)]
    args: ClientArgs,
    /// Skip updates whose public key differs from the local certificate's key
    #[clap(
        short = 'k',
        long = "key-update",
        env = env_key!("KEY_UPDATE")
    )]
    prohibit_key_update: bool,
    /// Directory the local certificates are read from and updates are written to
    #[clap(
        short = 'c',
        long = "cert-dir",
        value_hint = ValueHint::DirPath,
        env = env_key!("CERT_DIR")
    )]
    cert_dir: PathBuf,
    /// Only check certificates that expire within this many days
    #[clap(
        short = 'd',
        long = "days",
        default_value = "60",
        env = env_key!("MIN_DELTA_DAYS")
    )]
    min_delta_days: Option<u32>,
}
// Arguments of the `upload` subcommand.
// (Plain `//` comment on purpose: a `///` here would be picked up by clap.)
#[derive(Parser)]
pub struct UploadArgs {
    #[clap(flatten)]
    args: ClientArgs,
    /// Certificates to be uploaded
    #[clap(
        value_hint = ValueHint::FilePath,
        env = env_key!("FILES")
    )]
    files: Vec<PathBuf>,
}
// Arguments of the `renew-command` subcommand.
// (Plain `//` comment on purpose: a `///` here would be picked up by clap.)
#[derive(Parser)]
pub struct RenewCommandArgs {
    /// Execute the renew command
    #[clap(short = 'x')]
    execute: bool,
    /// Path to the CA private key
    // The CA key is a file, so completion must offer files rather than
    // directories.
    #[clap(
        long = "ca",
        value_hint = ValueHint::FilePath,
        env = env_key!("CA_KEY")
    )]
    ca_key: Option<PathBuf>,
    /// Certificates to generate commands for
    #[clap(
        value_hint = ValueHint::FilePath,
        env = env_key!("FILES")
    )]
    files: Vec<PathBuf>,
}
// Top-level command line parser; the binary name is fixed to `sshcd` and all
// functionality lives in the `ClientCommands` subcommands.
// (Plain `//` comment on purpose: a `///` here would be picked up by clap.)
#[derive(Parser)]
#[command(name = "sshcd")]
pub struct ClientCommand {
#[clap(subcommand)]
cmd: ClientCommands,
}
#[derive(Subcommand)]
pub enum ClientCommands {
    /// Fetch newer versions of the certificates in the cert dir from the server
    Fetch(FetchArgs),
    /// Upload certificates to the server
    Upload(UploadArgs),
    /// Print, and optionally execute, the commands that renew the given certificates
    RenewCommand(RenewCommandArgs),
    /// Generate shell completions
    #[clap(hide = true)]
    Completions {
        /// Shell to generate completions for
        #[arg(long = "shell", value_enum)]
        shell: Shell,
    },
}
/// Entry point of the client: dispatches the parsed subcommand to its handler.
///
/// `Completions` writes the shell completion script for `ClientCommand` to
/// stdout and succeeds; every other variant forwards its arguments to the
/// matching async handler and returns that handler's result.
pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> {
    match cmd {
        ClientCommands::Completions { shell } => {
            let mut cli = ClientCommand::command();
            let mut out = std::io::stdout();
            shell.generate(&mut cli, &mut out);
            Ok(())
        }
        ClientCommands::RenewCommand(renew_args) => renew(renew_args).await,
        ClientCommands::Upload(upload_args) => upload(upload_args).await,
        ClientCommands::Fetch(fetch_args) => fetch(fetch_args).await,
    }
}
// Error type reserved for upload failures. It has no variants and is not
// referenced by the code visible in this file; `upload` reports errors through
// `anyhow` instead.
// NOTE(review): looks like a placeholder — confirm before extending or removing.
#[derive(Debug, thiserror::Error)]
enum UploadError {}
async fn upload(
UploadArgs {
args: ClientArgs { api, .. },
files,
}: UploadArgs,
) -> anyhow::Result<()> {
let client = reqwest::Client::new();
let mut certs = Vec::new();
for path in files.into_iter() {
if let Some(cert) = load_cert(&path).await? {
certs.push(cert);
}
}
let uploads = certs.into_iter().map(|cert| {
let client = client.clone();
let path = PutCert;
let url = api.join(path.to_uri().path()).unwrap();
tokio::spawn(async move { upload_cert(client, url, cert).await })
});
for upload in uploads {
let _ = upload.await;
}
Ok(())
}
/// PUT a single certificate, in OpenSSH text form, to `url`.
///
/// Succeeds when the server answers `200 OK` or `201 Created`. Any other
/// status is logged and turned into an error naming the certificate's key id.
/// Encoding and transport errors are propagated.
#[instrument(skip(client, cert), ret)]
async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Result<()> {
    let request = client.put(url.clone()).body(cert.to_openssh()?);
    let status = request.send().await?.status();

    if status == StatusCode::OK || status == StatusCode::CREATED {
        return Ok(());
    }

    let id = cert.key_id();
    error!(%id, %status, "failed to upload cert");
    bail!("failed to upload {id}, error code {status}");
}
async fn fetch(
FetchArgs {
cert_dir,
prohibit_key_update,
min_delta_days: min_delta,
args: ClientArgs { api, interactive },
}: FetchArgs,
) -> anyhow::Result<()> {
let certs = read_certs_dir(&cert_dir).await?;
// let publics_keys = read_pubkey_dir(&cert_dir).await?;
let client = reqwest::Client::new();
let threshold_exp = min_delta.and_then(|min_delta| {
SystemTime::now().checked_add(Duration::from_secs(60 * 60 * 24 * min_delta as u64))
});
// let standalone_certs = publics_keys.into_iter().map(|(name, key)| )
let updates = certs
.into_iter()
.filter(|cert| {
let exp = cert.valid_before_time();
let must_check = threshold_exp.as_ref().map(|th| &exp < th);
trace!(?cert, ?must_check, "filter");
must_check.unwrap_or(true)
})
.map(|cert| {
let path = GetCert {
identifier: cert.key_id().to_string(),
};
let url = api.join(path.to_uri().path()).unwrap();
let client = client.clone();
tokio::spawn(async move { fetch_cert(client, url, cert).await })
});
let mut stdin = BufReader::new(stdin()).lines();
for cert in updates {
if let Ok(Some((cert, update))) = cert.await? {
if prohibit_key_update && cert.public_key() != update.public_key() {
debug!(?update, "skipping cert due to key change");
continue;
}
if interactive {
println!("certificate update: {}", cert.key_id());
println!(
"principals: {:?}, expiry: {}",
update.valid_principals(),
update.valid_before()
);
println!("update? : (y/n)");
let yes = stdin.next_line().await?;
if !matches!(yes, Some(line) if line.starts_with(['y', 'Y'])) {
break;
}
}
fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?;
let key_id = cert.key_id();
info!(
%key_id,
"updated certificate",
);
}
}
Ok(())
}
async fn renew(
RenewCommandArgs {
files,
ca_key,
execute,
}: RenewCommandArgs,
) -> anyhow::Result<()> {
for file in files.iter() {
let cert = load_cert(&file).await?;
if let Some(cert) = cert {
let command = renew_command(
&cert,
ca_key
.as_deref()
.map(|path| path.to_str())
.flatten()
.unwrap_or("ca"),
file.to_str(),
);
println!("{}", command);
if execute {
process::Command::new("sh")
.arg("-c")
.arg(&command)
.spawn()
.with_context(|| format!("{command}"))?;
}
} else {
bail!("{file:?} doesn't exist");
}
}
Ok(())
}
/// Download the certificate at `url` and decide whether it replaces `current`.
///
/// Returns `Ok(Some((current, remote)))` when the remote certificate passes
/// `Certificate::validate` against the fingerprint of the key that signed
/// `current` and carries a strictly higher serial. Returns `Ok(None)` when the
/// server does not answer `200 OK`, when validation fails, or when the remote
/// serial is not higher. Transport and parse errors are propagated.
#[instrument(skip(client, current))]
async fn fetch_cert(
    client: Client,
    url: Url,
    current: Certificate,
) -> anyhow::Result<Option<(Certificate, Certificate)>> {
    debug!("checking {}", current.key_id());

    let response = client.get(url.clone()).send().await?;
    let found = response.status() == StatusCode::OK;
    if !found {
        return Ok(None);
    }

    let body = response.text().await?;
    let candidate = Certificate::from_openssh(&body)?;

    // Only accept certificates issued by the CA that signed the local one.
    let trusted_ca = current.signature_key().fingerprint(Default::default());
    let signature_ok = candidate.validate(&[trusted_ca]).is_ok();
    if !signature_ok {
        info!("invalid signature {}, skipping", &url);
        return Ok(None);
    }

    let is_newer = candidate.serial() > current.serial();
    if !is_newer {
        debug!("{} is not newer than local version", &url);
        return Ok(None);
    }

    Ok(Some((current, candidate)))
}