158 lines
4.1 KiB
Rust
158 lines
4.1 KiB
Rust
use anyhow::{bail, Context};
|
|
use axum_extra::routing::TypedPath;
|
|
use clap::{Args, Parser, Subcommand};
|
|
use reqwest::{Client, StatusCode};
|
|
use ssh_key::Certificate;
|
|
use std::path::PathBuf;
|
|
use tokio::fs;
|
|
use tracing::{debug, error, info, instrument};
|
|
use tracing::{info_span, Instrument};
|
|
use url::Url;
|
|
|
|
use crate::api::PutCert;
|
|
use crate::certs::load_cert;
|
|
use crate::{
|
|
api::GetCert,
|
|
certs::{self, read_dir},
|
|
};
|
|
|
|
#[derive(Parser)]
|
|
pub struct ClientArgs {
|
|
/// Url for the API endpoint
|
|
#[clap(short = 'a', long = "api-endpoint")]
|
|
api: Url,
|
|
}
|
|
|
|
#[derive(Parser)]
|
|
pub struct FetchArgs {
|
|
#[clap(flatten)]
|
|
args: ClientArgs,
|
|
#[clap(short = 'c', long = "cert-dir", default_value = "~/.ssh")]
|
|
cert_dir: PathBuf,
|
|
}
|
|
|
|
#[derive(Parser)]
|
|
pub struct UploadArgs {
|
|
#[clap(flatten)]
|
|
args: ClientArgs,
|
|
/// Certificates to be uploaded
|
|
files: Vec<PathBuf>,
|
|
}
|
|
|
|
#[derive(Args)]
|
|
pub struct ClientCommand {
|
|
#[clap(subcommand)]
|
|
cmd: ClientCommands,
|
|
}
|
|
|
|
#[derive(Subcommand)]
|
|
pub enum ClientCommands {
|
|
Fetch(FetchArgs),
|
|
Upload(UploadArgs),
|
|
}
|
|
|
|
pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> {
|
|
match cmd {
|
|
ClientCommands::Fetch(args) => fetch(args).await,
|
|
ClientCommands::Upload(args) => upload(args).await,
|
|
}
|
|
}
|
|
|
|
async fn upload(
|
|
UploadArgs {
|
|
args: ClientArgs { api },
|
|
files,
|
|
}: UploadArgs,
|
|
) -> anyhow::Result<()> {
|
|
let client = reqwest::Client::new();
|
|
let mut certs = Vec::new();
|
|
for path in files.into_iter() {
|
|
if let Some(cert) = load_cert(&path).await? {
|
|
certs.push(cert);
|
|
}
|
|
}
|
|
let uploads = certs.into_iter().map(|cert| {
|
|
let client = client.clone();
|
|
let path = PutCert;
|
|
let url = api.join(path.to_uri().path()).unwrap();
|
|
tokio::spawn(async move { upload_cert(client, url, cert).await })
|
|
});
|
|
for upload in uploads {
|
|
let _ = upload.await;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
#[instrument(skip(client, cert), ret)]
|
|
async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Result<()> {
|
|
let resp = client
|
|
.put(url.clone())
|
|
.body(cert.to_openssh()?)
|
|
.send()
|
|
.await?;
|
|
let status = resp.status();
|
|
if ![StatusCode::OK, StatusCode::CREATED].contains(&status) {
|
|
let id = cert.key_id();
|
|
error!(%id, %status, "failed to upload cert");
|
|
bail!("failed to upload {id}, error code {status}");
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn fetch(
|
|
FetchArgs {
|
|
cert_dir,
|
|
args: ClientArgs { api },
|
|
}: FetchArgs,
|
|
) -> anyhow::Result<()> {
|
|
let certs = read_dir(&cert_dir).await?;
|
|
let client = reqwest::Client::new();
|
|
let updates = certs.into_iter().map(|cert| {
|
|
let path = GetCert {
|
|
identifier: cert.key_id().to_string(),
|
|
};
|
|
let url = api.join(path.to_uri().path()).unwrap();
|
|
let client = client.clone();
|
|
tokio::spawn(async move { fetch_cert(client, url, cert).await })
|
|
});
|
|
for cert in updates {
|
|
if let Ok(Some((cert, update))) = cert.await? {
|
|
fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?;
|
|
info!(
|
|
"updated {}: {} -> {}",
|
|
cert.key_id(),
|
|
cert.serial(),
|
|
update.serial()
|
|
);
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
#[instrument(skip(client, current))]
|
|
async fn fetch_cert(
|
|
client: Client,
|
|
url: Url,
|
|
current: Certificate,
|
|
) -> anyhow::Result<Option<(Certificate, Certificate)>> {
|
|
debug!("checking {}", current.key_id());
|
|
let resp = client.get(url.clone()).send().await?;
|
|
if resp.status() != StatusCode::OK {
|
|
return Ok(None);
|
|
}
|
|
let string_repr = resp.text().await?;
|
|
let remote_cert = Certificate::from_openssh(&string_repr)?;
|
|
if remote_cert
|
|
.validate(&[current.signature_key().fingerprint(Default::default())])
|
|
.is_err()
|
|
{
|
|
info!("invalid signature {}, skipping", &url);
|
|
return Ok(None);
|
|
}
|
|
if current.serial() >= remote_cert.serial() {
|
|
debug!("{} is not newer than local version", &url);
|
|
return Ok(None);
|
|
}
|
|
Ok(Some((current, remote_cert)))
|
|
}
|