refactor: split
This commit is contained in:
189
client/src/client.rs
Normal file
189
client/src/client.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use anyhow::bail;
|
||||
use axum_extra::routing::TypedPath;
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use reqwest::{Client, StatusCode};
|
||||
use ssh_key::Certificate;
|
||||
use std::io::stdin;
|
||||
use std::path::PathBuf;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tokio::fs;
|
||||
use tracing::{debug, error, info, instrument, trace};
|
||||
|
||||
use url::Url;
|
||||
|
||||
use ssh_cert_dist_common::*;
|
||||
|
||||
#[derive(Parser)]
|
||||
pub struct ClientArgs {
|
||||
/// Url for the API endpoint
|
||||
#[clap(short = 'a', long = "api-endpoint", env = env_key!("API"))]
|
||||
api: Url,
|
||||
/// Require interaction before writing certificates
|
||||
#[clap(short = 'i', long = "interactive", env = env_key!("INTERACTIVE"))]
|
||||
interactive: bool,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
pub struct FetchArgs {
|
||||
#[clap(flatten)]
|
||||
args: ClientArgs,
|
||||
#[clap(short = 'c', long = "cert-dir", env = env_key!("CERT_DIR") )]
|
||||
cert_dir: PathBuf,
|
||||
/// minimum time in days between now and expiry to consider checking
|
||||
#[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))]
|
||||
min_delta_days: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
pub struct UploadArgs {
|
||||
#[clap(flatten)]
|
||||
args: ClientArgs,
|
||||
/// Certificates to be uploaded
|
||||
#[clap(env = env_key!("FILES"))]
|
||||
files: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
pub struct ClientCommand {
|
||||
#[clap(subcommand)]
|
||||
cmd: ClientCommands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
pub enum ClientCommands {
|
||||
Fetch(FetchArgs),
|
||||
Upload(UploadArgs),
|
||||
}
|
||||
|
||||
pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> {
|
||||
match cmd {
|
||||
ClientCommands::Fetch(args) => fetch(args).await,
|
||||
ClientCommands::Upload(args) => upload(args).await,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum UploadError {}
|
||||
|
||||
async fn upload(
|
||||
UploadArgs {
|
||||
args: ClientArgs { api, .. },
|
||||
files,
|
||||
}: UploadArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
let mut certs = Vec::new();
|
||||
for path in files.into_iter() {
|
||||
if let Some(cert) = load_cert(&path).await? {
|
||||
certs.push(cert);
|
||||
}
|
||||
}
|
||||
let uploads = certs.into_iter().map(|cert| {
|
||||
let client = client.clone();
|
||||
let path = PutCert;
|
||||
let url = api.join(path.to_uri().path()).unwrap();
|
||||
tokio::spawn(async move { upload_cert(client, url, cert).await })
|
||||
});
|
||||
for upload in uploads {
|
||||
let _ = upload.await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(client, cert), ret)]
|
||||
async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Result<()> {
|
||||
let resp = client
|
||||
.put(url.clone())
|
||||
.body(cert.to_openssh()?)
|
||||
.send()
|
||||
.await?;
|
||||
let status = resp.status();
|
||||
if ![StatusCode::OK, StatusCode::CREATED].contains(&status) {
|
||||
let id = cert.key_id();
|
||||
error!(%id, %status, "failed to upload cert");
|
||||
bail!("failed to upload {id}, error code {status}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch(
|
||||
FetchArgs {
|
||||
cert_dir,
|
||||
min_delta_days: min_delta,
|
||||
args: ClientArgs { api, interactive },
|
||||
}: FetchArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
let certs = read_dir(&cert_dir).await?;
|
||||
let client = reqwest::Client::new();
|
||||
let threshold_exp = min_delta.and_then(|min_delta| {
|
||||
SystemTime::now().checked_add(Duration::from_secs(60 * 60 * 24 * min_delta as u64))
|
||||
});
|
||||
let updates = certs
|
||||
.into_iter()
|
||||
.filter(|cert| {
|
||||
let exp = cert.valid_before_time();
|
||||
let must_check = threshold_exp.as_ref().map(|th| &exp < th);
|
||||
trace!(?cert, ?must_check, "filter");
|
||||
must_check.unwrap_or(true)
|
||||
})
|
||||
.map(|cert| {
|
||||
let path = GetCert {
|
||||
identifier: cert.key_id().to_string(),
|
||||
};
|
||||
let url = api.join(path.to_uri().path()).unwrap();
|
||||
let client = client.clone();
|
||||
tokio::spawn(async move { fetch_cert(client, url, cert).await })
|
||||
});
|
||||
for cert in updates {
|
||||
if let Ok(Some((cert, update))) = cert.await? {
|
||||
if interactive {
|
||||
println!("certificate update: {}", cert.key_id());
|
||||
println!(
|
||||
"principals: {:?}, expiry: {}",
|
||||
update.valid_principals(),
|
||||
update.valid_before()
|
||||
);
|
||||
println!("update? : (y/n)");
|
||||
let mut yes = String::with_capacity(3);
|
||||
stdin().read_line(&mut yes)?;
|
||||
if !yes.starts_with(['y', 'Y']) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?;
|
||||
let key_id = cert.key_id();
|
||||
info!(
|
||||
%key_id,
|
||||
"updated certificate",
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(client, current))]
|
||||
async fn fetch_cert(
|
||||
client: Client,
|
||||
url: Url,
|
||||
current: Certificate,
|
||||
) -> anyhow::Result<Option<(Certificate, Certificate)>> {
|
||||
debug!("checking {}", current.key_id());
|
||||
let resp = client.get(url.clone()).send().await?;
|
||||
if resp.status() != StatusCode::OK {
|
||||
return Ok(None);
|
||||
}
|
||||
let string_repr = resp.text().await?;
|
||||
let remote_cert = Certificate::from_openssh(&string_repr)?;
|
||||
if remote_cert
|
||||
.validate(&[current.signature_key().fingerprint(Default::default())])
|
||||
.is_err()
|
||||
{
|
||||
info!("invalid signature {}, skipping", &url);
|
||||
return Ok(None);
|
||||
}
|
||||
if current.serial() >= remote_cert.serial() {
|
||||
debug!("{} is not newer than local version", &url);
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some((current, remote_cert)))
|
||||
}
|
10
client/src/main.rs
Normal file
10
client/src/main.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use clap::Parser;
|
||||
|
||||
mod client;
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
client::run(client::ClientCommand::parse()).await
|
||||
}
|
Reference in New Issue
Block a user