From c29a99e6c440b132da446ca1a9982381b11e1913 Mon Sep 17 00:00:00 2001 From: shimun Date: Wed, 30 Nov 2022 23:30:03 +0100 Subject: [PATCH] added: paralell fetch --- src/certs.rs | 12 ++++-- src/client.rs | 106 +++++++++++++++++++++++++++++++++----------------- src/main.rs | 1 - 3 files changed, 79 insertions(+), 40 deletions(-) diff --git a/src/certs.rs b/src/certs.rs index c20d8d9..9084963 100644 --- a/src/certs.rs +++ b/src/certs.rs @@ -1,8 +1,8 @@ use anyhow::Context; use ssh_key::{Certificate, PublicKey}; -use std::path::{Path, PathBuf}; +use std::{path::{Path, PathBuf}, fmt::Debug}; use tokio::fs; -use tracing::trace; +use tracing::{trace, instrument}; pub async fn read_certs( ca: &PublicKey, @@ -11,7 +11,8 @@ pub async fn read_certs( read_dir(path.as_ref().join(ca_dir(ca))).await } -pub async fn read_dir(path: impl AsRef) -> anyhow::Result> { +#[instrument] +pub async fn read_dir(path: impl AsRef + Debug) -> anyhow::Result> { let mut dir = fs::read_dir(path.as_ref()) .await .context("read certs dir")?; @@ -58,13 +59,15 @@ fn ca_dir(ca: &PublicKey) -> String { ca.comment().to_string() } +#[instrument] fn cert_path(ca: &PublicKey, identifier: &str) -> String { let _ca_fingerprint = ca.fingerprint(Default::default()); format!("{}/{}-cert.pub", ca_dir(ca), identifier) } +#[instrument] pub async fn store_cert( - cert_dir: impl AsRef, + cert_dir: impl AsRef + Debug, ca: &PublicKey, cert: &Certificate, ) -> anyhow::Result { @@ -84,6 +87,7 @@ pub async fn load_cert( ) -> anyhow::Result> { let path = cert_dir.as_ref().join(cert_path(ca, identifier)); if !path.exists() { + trace!("no certificate at {:?}", path); return Ok(None); } let contents = fs::read(&path) diff --git a/src/client.rs b/src/client.rs index 0a40489..aa7bc48 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,13 +1,17 @@ use axum_extra::routing::TypedPath; use clap::{Args, Parser, Subcommand}; -use reqwest::StatusCode; +use reqwest::{Client, StatusCode}; use ssh_key::Certificate; use std::path::PathBuf; use tokio::fs; -use tracing::{debug, info}; +use tracing::{debug, info, instrument}; +use tracing::{info_span, Instrument}; use url::Url; -use crate::{api::GetCert, certs::read_dir}; +use crate::{ + api::GetCert, + certs::{self, read_dir}, +}; #[derive(Parser)] pub struct ClientArgs { @@ -22,9 +26,14 @@ pub struct FetchArgs { args: ClientArgs, #[clap(short = 'c', long = "cert-dir", default_value = "~/.ssh")] cert_dir: PathBuf, - /// CA public key - #[clap(long = "ca")] - ca: PathBuf, +} + +#[derive(Parser)] +pub struct UploadArgs { + #[clap(flatten)] + args: ClientArgs, + /// Certificates to be uploaded + files: Vec, } #[derive(Args)] @@ -36,55 +45,82 @@ pub struct ClientCommand { #[derive(Subcommand)] pub enum ClientCommands { Fetch(FetchArgs), - Upload, + Upload(UploadArgs), } pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> { match cmd { ClientCommands::Fetch(args) => fetch(args).await, - ClientCommands::Upload => unimplemented!(), + ClientCommands::Upload(args) => upload(args).await, } } +async fn upload( + UploadArgs { + args: ClientArgs { api }, + files, + }: UploadArgs, +) -> anyhow::Result<()> { + Ok(()) +} + +async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Result<()> { + Ok(()) +} + async fn fetch( FetchArgs { cert_dir, - ca: _, args: ClientArgs { api }, }: FetchArgs, ) -> anyhow::Result<()> { let certs = read_dir(&cert_dir).await?; let client = reqwest::Client::new(); - for cert in certs { + let updates = certs.into_iter().map(|cert| { let path = GetCert { identifier: cert.key_id().to_string(), }; - debug!("checking {}", cert.key_id()); - let url = api.join(path.to_uri().path())?; - let resp = client.get(url.clone()).send().await?; - if resp.status() != StatusCode::OK { - continue; + let url = api.join(path.to_uri().path()).unwrap(); + let client = client.clone(); + tokio::spawn(async move { fetch_cert(client, url, cert).await }) + }); + for cert in updates { + if let Ok(Some((cert, update))) = cert.await? { + fs::write(cert_dir.join(cert.key_id()), update.to_openssh()?).await?; + info!( + "updated {}: {} -> {}", + cert.key_id(), + cert.serial(), + update.serial() + ); } - let string_repr = resp.text().await?; - let remote_cert = Certificate::from_openssh(&string_repr)?; - if remote_cert - .validate(&[cert.signature_key().fingerprint(Default::default())]) - .is_err() - { - info!("invalid signature {}, skipping", &url); - continue; - } - if cert.serial() >= remote_cert.serial() { - debug!("{} is not newer than local version", &url); - continue; - } - fs::write(cert_dir.join(cert.key_id()), remote_cert.to_openssh()?).await?; - info!( - "updated {}: {} -> {}", - cert.key_id(), - cert.serial(), - remote_cert.serial() - ); } Ok(()) } + +#[instrument] +async fn fetch_cert( + client: Client, + url: Url, + current: Certificate, +) -> anyhow::Result> { + debug!("checking {}", current.key_id()); + let resp = client.get(url.clone()).send().await?; + if resp.status() != StatusCode::OK { + return Ok(None); + } + let string_repr = resp.text().await?; + let remote_cert = Certificate::from_openssh(&string_repr)?; + if remote_cert + .validate(&[current.signature_key().fingerprint(Default::default())]) + .is_err() + { + info!("invalid signature {}, skipping", &url); + return Ok(None); + } + if current.serial() >= remote_cert.serial() { + debug!("{} is not newer than local version", &url); + return Ok(None); + } + Ok(Some((current, remote_cert))) +} diff --git a/src/main.rs b/src/main.rs index 89b6f3f..b576320 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,6 @@ use api::ApiArgs; use clap::Parser; #[cfg(feature = "client")] use client::ClientCommand; -use tracing_subscriber; mod api; mod certs;