Separate crates for server and client binaries #1

Merged
shimun merged 28 commits from split_components into master 2023-07-09 19:27:21 +02:00
21 changed files with 1238 additions and 869 deletions

1291
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,43 +1,8 @@
[package] [workspace]
name = "ssh-cert-dist"
version = "0.1.0"
authors = ["shimun <shimun@shimun.net>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html members = [
"common",
[features] "server",
default = [ "client", "reload", "info", "authorized" ] "client",
reload = [] ]
authorized =[ "dep:jwt-compact" ]
index = []
info = [ "axum/json", "ssh-key/serde" ]
client = [ "dep:url", "dep:reqwest" ]
[dependencies]
anyhow = "1.0.66"
async-trait = "0.1.59"
axum = { version = "0.6.1", features = ["http2"] }
axum-extra = { version = "0.4.1", features = ["typed-routing"] }
chrono = "0.4.23"
clap = { version = "4.0.29", features = ["env", "derive"] }
jwt-compact = { version = "0.6.0", features = ["serde_cbor", "std", "clock"], optional = true }
rand = "0.8.5"
reqwest = { version = "0.11.13", optional = true }
serde = { version = "1.0.148", features = ["derive"] }
ssh-key = { version = "0.5.1", features = ["ed25519", "p256", "p384", "rsa", "signature"] }
thiserror = "1.0.37"
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.3.4", features = ["map-request-body", "trace"] }
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
tracing-subscriber = "0.3.16"
url = { version = "2.3.1", optional = true }
[patch.crates-io]
ssh-key = { git = "https://github.com/a-dma/SSH.git", branch = "u2f_signatures" }
[dev-dependencies]
tempfile = "3.3.0"

32
client/Cargo.toml Normal file
View File

@ -0,0 +1,32 @@
[package]
name = "ssh-cert-dist-client"
version = "0.1.0"
authors = ["shimun <shimun@shimun.net>"]
edition = "2021"
[[bin]]
name = "sshcd"
path = "src/main.rs"
[dependencies]
anyhow = "1.0.66"
async-trait = "0.1.59"
axum-extra = { version = "0.4.1", features = ["typed-routing"] }
chrono = "0.4.23"
clap = { version = "4.0.29", features = ["env", "derive"] }
rand = "0.8.5"
reqwest = { version = "0.11.13" }
serde = { version = "1.0.148", features = ["derive"] }
ssh-key = { version = "0.6.0-pre.0", features = ["ed25519", "p256", "p384", "rsa", "serde"] }
thiserror = "1.0.37"
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
tracing-subscriber = "0.3.16"
url = { version = "2.3.1" }
ssh-cert-dist-common = { path = "../common" }
[dev-dependencies]
tempfile = "3.3.0"
[profile.relese]
opt-level = 1

View File

@ -1,20 +1,17 @@
use anyhow::bail; use anyhow::bail;
use axum_extra::routing::TypedPath; use axum_extra::routing::TypedPath;
use clap::{Args, Parser, Subcommand}; use clap::{Parser, Subcommand};
use reqwest::{Client, StatusCode}; use reqwest::{Client, StatusCode};
use ssh_key::Certificate; use ssh_key::Certificate;
use std::io::{stdin, stdout};
use std::path::PathBuf; use std::path::PathBuf;
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use tokio::fs; use tokio::fs;
use tokio::io::{stdin, AsyncBufReadExt, BufReader};
use tracing::{debug, error, info, instrument, trace}; use tracing::{debug, error, info, instrument, trace};
use url::Url; use url::Url;
use crate::api::PutCert; use ssh_cert_dist_common::*;
use crate::certs::load_cert;
use crate::env_key;
use crate::{api::GetCert, certs::read_dir};
#[derive(Parser)] #[derive(Parser)]
pub struct ClientArgs { pub struct ClientArgs {
@ -30,7 +27,9 @@ pub struct ClientArgs {
pub struct FetchArgs { pub struct FetchArgs {
#[clap(flatten)] #[clap(flatten)]
args: ClientArgs, args: ClientArgs,
#[clap(short = 'c', long = "cert-dir", env = env_key!("CERT_DIR") )] #[clap(short = 'k', long = "key-update", env = env_key!("KEY_UPDATE"))]
prohibit_key_update: bool,
#[clap(short = 'c', long = "cert-dir", env = env_key!("CERT_DIR"))]
cert_dir: PathBuf, cert_dir: PathBuf,
/// minimum time in days between now and expiry to consider checking /// minimum time in days between now and expiry to consider checking
#[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))] #[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))]
@ -46,7 +45,7 @@ pub struct UploadArgs {
files: Vec<PathBuf>, files: Vec<PathBuf>,
} }
#[derive(Args)] #[derive(Parser)]
pub struct ClientCommand { pub struct ClientCommand {
#[clap(subcommand)] #[clap(subcommand)]
cmd: ClientCommands, cmd: ClientCommands,
@ -112,15 +111,18 @@ async fn upload_cert(client: Client, url: Url, cert: Certificate) -> anyhow::Res
async fn fetch( async fn fetch(
FetchArgs { FetchArgs {
cert_dir, cert_dir,
prohibit_key_update,
min_delta_days: min_delta, min_delta_days: min_delta,
args: ClientArgs { api, interactive }, args: ClientArgs { api, interactive },
}: FetchArgs, }: FetchArgs,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let certs = read_dir(&cert_dir).await?; let certs = read_certs_dir(&cert_dir).await?;
// let publics_keys = read_pubkey_dir(&cert_dir).await?;
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let threshold_exp = min_delta.and_then(|min_delta| { let threshold_exp = min_delta.and_then(|min_delta| {
SystemTime::now().checked_add(Duration::from_secs(60 * 60 * 24 * min_delta as u64)) SystemTime::now().checked_add(Duration::from_secs(60 * 60 * 24 * min_delta as u64))
}); });
// let standalone_certs = publics_keys.into_iter().map(|(name, key)| )
let updates = certs let updates = certs
.into_iter() .into_iter()
.filter(|cert| { .filter(|cert| {
@ -137,8 +139,13 @@ async fn fetch(
let client = client.clone(); let client = client.clone();
tokio::spawn(async move { fetch_cert(client, url, cert).await }) tokio::spawn(async move { fetch_cert(client, url, cert).await })
}); });
let mut stdin = BufReader::new(stdin()).lines();
for cert in updates { for cert in updates {
if let Ok(Some((cert, update))) = cert.await? { if let Ok(Some((cert, update))) = cert.await? {
if prohibit_key_update && cert.public_key() != update.public_key() {
debug!(?update, "skipping cert due to key change");
continue;
}
if interactive { if interactive {
println!("certificate update: {}", cert.key_id()); println!("certificate update: {}", cert.key_id());
println!( println!(
@ -147,9 +154,8 @@ async fn fetch(
update.valid_before() update.valid_before()
); );
println!("update? : (y/n)"); println!("update? : (y/n)");
let mut yes = String::with_capacity(3); let yes = stdin.next_line().await?;
stdin().read_line(&mut yes)?; if !matches!(yes, Some(line) if line.starts_with(['y', 'Y'])) {
if !yes.starts_with(['y', 'Y']) {
break; break;
} }
} }

10
client/src/main.rs Normal file
View File

@ -0,0 +1,10 @@
use clap::Parser;
mod client;
#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
client::run(client::ClientCommand::parse()).await
}

24
common/Cargo.toml Normal file
View File

@ -0,0 +1,24 @@
[package]
name = "ssh-cert-dist-common"
version = "0.1.0"
authors = ["shimun <shimun@shimun.net>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.66"
async-trait = "0.1.59"
axum = { version = "0.6.1" }
axum-extra = { version = "0.4.1", features = ["typed-routing"] }
hex = { version = "0.4.3", features = ["serde"] }
serde = { version = "1.0.148", features = ["derive"] }
ssh-key = { version = "0.6.0-pre.0", features = ["ed25519", "p256", "p384", "rsa"] }
thiserror = "1.0.37"
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
tracing-subscriber = "0.3.16"
[dev-dependencies]
tempfile = "3.3.0"

View File

@ -24,11 +24,11 @@ pub async fn read_certs(
if !ca_dir.exists() { if !ca_dir.exists() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
read_dir(&ca_dir).await read_certs_dir(&ca_dir).await
} }
#[instrument] #[instrument]
pub async fn read_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<Certificate>> { pub async fn read_certs_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<Certificate>> {
let mut dir = fs::read_dir(path.as_ref()) let mut dir = fs::read_dir(path.as_ref())
.await .await
.with_context(|| format!("read certs dir '{:?}'", path.as_ref()))?; .with_context(|| format!("read certs dir '{:?}'", path.as_ref()))?;
@ -55,6 +55,26 @@ pub async fn read_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<Cert
Ok(certs) Ok(certs)
} }
pub async fn read_pubkey_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<PublicKey>> {
let mut dir = fs::read_dir(path.as_ref())
.await
.with_context(|| format!("read certs dir '{:?}'", path.as_ref()))?;
let mut pubs = Vec::new();
while let Some(entry) = dir.next_entry().await? {
//TODO: investigate why path().ends_with doesn't work
let file_name = entry.file_name().into_string().unwrap();
if !file_name.ends_with(".pub") || file_name.ends_with("-cert.pub") {
trace!("skipped {:?} due to missing '.pub' extension", entry.path());
continue;
}
let cert = load_public_key(entry.path()).await?;
if let Some(cert) = cert {
pubs.push(cert);
}
}
Ok(pubs)
}
fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> { fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> {
String::from_utf8(bytes).context("invalid utf-8") String::from_utf8(bytes).context("invalid utf-8")
} }
@ -122,3 +142,15 @@ pub async fn load_cert(file: impl AsRef<Path> + Debug) -> anyhow::Result<Option<
|| format!("parse {:?} as openssh certificate", &file), || format!("parse {:?} as openssh certificate", &file),
)?)) )?))
} }
pub async fn load_public_key(file: impl AsRef<Path> + Debug) -> anyhow::Result<Option<PublicKey>> {
let contents = match fs::read(&file).await {
Ok(contents) => contents,
Err(e) if e.kind() == ErrorKind::NotFound => return Ok(None),
Err(e) => return Err(e).with_context(|| format!("read {:?}", &file)),
};
let string_repr = parse_utf8(contents)?;
Ok(Some(PublicKey::from_openssh(&string_repr).with_context(
|| format!("parse {:?} as openssh public key", &file),
)?))
}

6
common/src/lib.rs Normal file
View File

@ -0,0 +1,6 @@
mod certs;
mod routes;
mod util;
pub use certs::*;
pub use routes::*;

41
common/src/routes.rs Normal file
View File

@ -0,0 +1,41 @@
use axum_extra::routing::TypedPath;
use serde::{Deserialize, Serialize};
use ssh_key::Fingerprint;
#[derive(TypedPath, Deserialize)]
#[typed_path("/certs")]
pub struct CertList;
#[derive(TypedPath, Deserialize)]
#[typed_path("/cert/:identifier")]
pub struct GetCert {
pub identifier: String,
}
#[derive(TypedPath, Deserialize)]
#[typed_path("/certs/:pubkey_hash")]
pub struct GetCertsPubkey {
pub pubkey_hash: Fingerprint,
}
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct CertIds {
pub ids: Vec<String>,
}
#[derive(TypedPath, Deserialize)]
#[typed_path("/cert/:identifier/info")]
pub struct GetCertInfo {
pub identifier: String,
}
#[derive(TypedPath, Deserialize)]
#[typed_path("/cert/:identifier")]
pub struct PostCertInfo {
pub identifier: String,
}
#[derive(TypedPath)]
#[typed_path("/cert")]
pub struct PutCert;

6
common/src/util.rs Normal file
View File

@ -0,0 +1,6 @@
#[macro_export]
macro_rules! env_key {
( $var:expr ) => {
concat!("SSH_CD_", $var)
};
}

38
flake.lock generated
View File

@ -7,11 +7,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1662220400, "lastModified": 1688534083,
"narHash": "sha256-9o2OGQqu4xyLZP9K6kNe1pTHnyPz0Wr3raGYnr9AIgY=", "narHash": "sha256-/bI5vsioXscQTsx+Hk9X5HfweeNZz/6kVKsbdqfwW7g=",
"owner": "nmattia", "owner": "nmattia",
"repo": "naersk", "repo": "naersk",
"rev": "6944160c19cb591eb85bbf9b2f2768a935623ed3", "rev": "abca1fb7a6cfdd355231fc220c3d0302dbb4369a",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -22,11 +22,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1669411043, "lastModified": 1688679045,
"narHash": "sha256-LfPd3+EY+jaIHTRIEOUtHXuanxm59YKgUacmSzaqMLc=", "narHash": "sha256-t3xGEfYIwhaLTPU8FLtN/pLPytNeDwbLI6a7XFFBlGo=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "5dc7114b7b256d217fe7752f1614be2514e61bb8", "rev": "3c7487575d9445185249a159046cc02ff364bff8",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -41,13 +41,31 @@
"utils": "utils" "utils": "utils"
} }
}, },
"utils": { "systems": {
"locked": { "locked": {
"lastModified": 1667395993, "lastModified": 1681028828,
"narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1687709756,
"narHash": "sha256-Y5wKlQSkgEK2weWdOu4J3riRd+kV/VCgHsqLNTTWQ/0=",
"owner": "numtide", "owner": "numtide",
"repo": "flake-utils", "repo": "flake-utils",
"rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", "rev": "dbabf0ca0c0c4bce6ea5eaf65af5cb694d2082c7",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -13,7 +13,7 @@
outputs = inputs @ { self, nixpkgs, utils, naersk, ... }: outputs = inputs @ { self, nixpkgs, utils, naersk, ... }:
let let
root = inputs.source or self; root = inputs.source or self;
pname = (builtins.fromTOML (builtins.readFile (root + "/Cargo.toml"))).package.name; pname = "ssh-cert-dist";
# toolchains: stable, beta, default(nightly) # toolchains: stable, beta, default(nightly)
toolchain = pkgs: toolchain = pkgs:
if inputs ? fenix then inputs.fenix.packages."${pkgs.system}".complete.toolchain if inputs ? fenix then inputs.fenix.packages."${pkgs.system}".complete.toolchain
@ -24,15 +24,30 @@
in in
rec { rec {
# `nix build` # `nix build`
packages.${pname} = (self.overlay pkgs pkgs).${pname}; packages."${pname}-server" = (self.overlay pkgs pkgs)."${pname}-server";
packages."${pname}-client" = (self.overlay pkgs pkgs)."${pname}-client";
packages."${pname}-client-snap" = pkgs.snapTools.makeSnap {
meta = {
name = pname;
architectures = [ "amd64" ];
confinement = "strict";
apps.hello.command = apps."${pname}-client".program;
};
};
packages.dockerImage = pkgs.runCommandLocal "docker-${pname}.tar.gz" { } "${apps.streamDockerImage.program} | gzip --fast > $out"; packages.dockerImage = pkgs.runCommandLocal "docker-${pname}.tar.gz" { } "${apps.streamDockerImage.program} | gzip --fast > $out";
packages.default = packages.${pname}; packages.default = packages."${pname}-client";
# `nix run` # `nix run`
apps.${pname} = utils.lib.mkApp { apps."${pname}-server" = utils.lib.mkApp {
drv = packages.${pname}; drv = packages."${pname}-server";
exePath = "/bin/sshcd-server";
};
apps."${pname}-client" = utils.lib.mkApp {
drv = packages."${pname}-client";
exePath = "/bin/sshcd";
}; };
# `nix run .#streamDockerImage | docker load` # `nix run .#streamDockerImage | docker load`
@ -41,12 +56,12 @@
name = pname; name = pname;
tag = self.shortRev or "latest"; tag = self.shortRev or "latest";
config = { config = {
Entrypoint = apps.default.program; Entrypoint = apps."${pname}-server".program;
}; };
}; };
exePath = ""; exePath = "";
}; };
apps.default = apps.${pname}; apps.default = apps."${pname}-client";
# `nix flake check` # `nix flake check`
checks = { checks = {
@ -78,7 +93,15 @@
rustc --version rustc --version
printf "\nbuild inputs: ${pkgs.lib.concatStringsSep ", " (map (bi: bi.name) (buildInputs ++ nativeBuildInputs))}" printf "\nbuild inputs: ${pkgs.lib.concatStringsSep ", " (map (bi: bi.name) (buildInputs ++ nativeBuildInputs))}"
function server() { function server() {
cargo watch -x "run --all-features -- server ''${@}" if [ ! -e "certs/ca.pub" ]; then
mkdir -p certs keys
ssh-keygen -t ed25519 -f certs/ca -q -N ""
ssh-keygen -t ed25519 -f keys/host -q -N ""
ssh-keygen -t ed25519 -f keys/client -q -N ""
ssh-keygen -s certs/ca -V +1000d -h -I host -n localhost,127.0.0.1 -h keys/host.pub
ssh-keygen -s certs/ca -V +1000d -I client -n "client,client@localhost" keys/client.pub -O force-command="echo Hello World"
fi
cargo watch -x "run --bin sshcd-server --all-features -- ''${@}"
} }
''; '';
}; };
@ -103,9 +126,15 @@
]; ];
in in
{ {
"${pname}" = "${pname}-server" =
naersk-lib.buildPackage { naersk-lib.buildPackage {
inherit pname root buildInputs nativeBuildInputs; name = "${pname}-server";
inherit root buildInputs nativeBuildInputs;
};
"${pname}-client" =
naersk-lib.buildPackage {
name = "${pname}-client";
inherit root buildInputs nativeBuildInputs;
}; };
}; };

View File

@ -13,15 +13,14 @@ in
Environment = "RUST_LOG=debug"; Environment = "RUST_LOG=debug";
ExecStart = toString (pkgs.writeShellApplication { ExecStart = toString (pkgs.writeShellApplication {
name = "ssh-cert-dist-${options.name}"; name = "ssh-cert-dist-${options.name}";
runtimeInputs = [ pkgs.ssh-cert-dist ]; runtimeInputs = [ cfg.package ];
text = '' text = ''
${optionalString options.fetch '' ${optionalString options.fetch ''
ssh-cert-dist client fetch --cert-dir '${path}' --api-endpoint '${cfg.endpoint}' sshcd fetch --cert-dir '${path}' --api-endpoint '${cfg.endpoint}'
''} ''}
${optionalString options.upload '' ${optionalString options.upload ''
ssh-cert-dist client upload --api-endpoint '${cfg.endpoint}' ${path}/* sshcd upload --api-endpoint '${cfg.endpoint}' ${path}/*
''} ''}
''; '';
}); });
}; };

View File

@ -14,7 +14,7 @@ in
}; };
package = mkOption { package = mkOption {
type = types.package; type = types.package;
default = pkgs.ssh-cert-dist; default = pkgs.ssh-cert-dist-server;
}; };
ca = mkOption { ca = mkOption {
type = types.path; type = types.path;
@ -57,7 +57,7 @@ in
chown ${cfg.user}:${cfg.group} ${cfg.dataDir} chown ${cfg.user}:${cfg.group} ${cfg.dataDir}
''}"; ''}";
User = cfg.user; User = cfg.user;
ExecStart = "${cfg.package}/bin/ssh-cert-dist server"; ExecStart = "${cfg.package}/bin/sshcd-server";
}; };
}; };
}; };

View File

@ -22,7 +22,7 @@
}; };
packageOption = mkOption { packageOption = mkOption {
type = types.package; type = types.package;
default = pkgs.ssh-cert-dist; default = pkgs.ssh-cert-dist-client;
}; };
in in

44
server/Cargo.toml Normal file
View File

@ -0,0 +1,44 @@
[package]
name = "ssh-cert-dist-server"
version = "0.1.0"
authors = ["shimun <shimun@shimun.net>"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = [ "reload", "info", "authorized" ]
reload = []
authorized =[ "dep:jwt-compact" ]
index = []
info = [ "axum/json", "ssh-key/serde" ]
[[bin]]
name = "sshcd-server"
path = "src/main.rs"
[dependencies]
anyhow = "1.0.66"
async-trait = "0.1.59"
axum = { version = "0.6.1", features = ["http2"] }
axum-extra = { version = "0.4.1", features = ["typed-routing"] }
chrono = "0.4.23"
clap = { version = "4.0.29", features = ["env", "derive"] }
jwt-compact = { version = "0.6.0", features = ["serde_cbor", "std", "clock"], optional = true }
rand = "0.8.5"
serde = { version = "1.0.148", features = ["derive"] }
ssh-key = { version = "0.6.0-pre.0", features = ["ed25519", "p256", "p384", "rsa"] }
thiserror = "1.0.37"
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
tower = { version = "0.4.13" }
tower-http = { version = "0.3.4", features = ["map-request-body", "trace", "util"] }
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
tracing-subscriber = "0.3.16"
ssh-cert-dist-common = { path = "../common" }
shell-escape = "0.1.5"
[dev-dependencies]
tempfile = "3.3.0"
[profile.release]
opt-level = 1

View File

@ -1,35 +1,37 @@
mod extract; mod extract;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::{self, PathBuf}; use std::path::{self, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, SystemTime}; use std::time::{SystemTime, UNIX_EPOCH};
use crate::certs::{load_cert_by_id, read_certs, read_pubkey, store_cert};
use crate::env_key;
use anyhow::Context; use anyhow::Context;
use axum::body; use axum::body;
use axum::extract::rejection::QueryRejection;
use axum::extract::{Query, State}; use axum::extract::{Query, State};
use chrono::Duration;
use shell_escape::escape;
use ssh_cert_dist_common::*;
use axum::{http::StatusCode, response::IntoResponse, Json, Router}; use axum::{http::StatusCode, response::IntoResponse, Json, Router};
use axum_extra::routing::{ use axum_extra::routing::RouterExt;
RouterExt, // for `Router::typed_*`
TypedPath,
};
use clap::{Args, Parser}; use clap::{Args, Parser};
use jwt_compact::alg::{Hs256, Hs256Key}; use jwt_compact::alg::{Hs256, Hs256Key};
use jwt_compact::{AlgorithmExt, Token, UntrustedToken}; use jwt_compact::{AlgorithmExt};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ssh_key::private::Ed25519Keypair; use ssh_key::{Certificate, Fingerprint, PublicKey};
use ssh_key::{certificate, Certificate, PrivateKey, PublicKey};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_http::{trace::TraceLayer, ServiceBuilderExt}; use tower_http::{trace::TraceLayer, ServiceBuilderExt};
use tracing::{debug, info, trace}; use tracing::{debug, info, trace};
use self::extract::{CertificateBody, SignatureBody}; use self::extract::{CertificateBody, SignatureBody, JWTAuthenticated, JWTString};
#[derive(Parser)] #[derive(Parser)]
pub struct ApiArgs { pub struct ApiArgs {
@ -77,7 +79,7 @@ impl Default for ApiArgs {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct ApiState { pub struct ApiState {
certs: Arc<Mutex<HashMap<String, Certificate>>>, certs: Arc<Mutex<HashMap<String, Certificate>>>,
cert_dir: PathBuf, cert_dir: PathBuf,
ca: PublicKey, ca: PublicKey,
@ -140,6 +142,7 @@ pub async fn run(
let app = Router::new() let app = Router::new()
.typed_get(get_certs_identifier) .typed_get(get_certs_identifier)
.typed_get(get_certs_pubkey)
.typed_put(put_cert_update) .typed_put(put_cert_update)
.typed_get(get_cert_info) .typed_get(get_cert_info)
.typed_post(post_certs_identifier); .typed_post(post_certs_identifier);
@ -177,12 +180,23 @@ pub enum ApiError {
AuthenticationRequired(String), AuthenticationRequired(String),
#[error("invalid ssh signature")] #[error("invalid ssh signature")]
InvalidSignature, InvalidSignature,
#[error("malformed ssh signature: {0}")]
ParseSignature(anyhow::Error),
#[error("malformed ssh certificate: {0}")]
ParseCertificate(anyhow::Error),
#[error("{0}")]
JWTParse(#[from] jwt_compact::ParseError),
#[error("{0}")]
JWTVerify(#[from] jwt_compact::ValidationError),
#[error("{0}")]
Query(#[from] QueryRejection),
} }
type ApiResult<T> = Result<T, ApiError>; type ApiResult<T> = Result<T, ApiError>;
impl IntoResponse for ApiError { impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
trace!({ error = ?self }, "returned error for request");
( (
match self { match self {
Self::CertificateNotFound => StatusCode::NOT_FOUND, Self::CertificateNotFound => StatusCode::NOT_FOUND,
@ -202,10 +216,7 @@ async fn fallback_404() -> ApiResult<()> {
Err(ApiError::CertificateNotFound) Err(ApiError::CertificateNotFound)
} }
#[derive(TypedPath, Deserialize)] #[cfg(feature = "index")]
#[typed_path("/certs")]
pub struct CertList;
async fn list_certs( async fn list_certs(
_: CertList, _: CertList,
State(ApiState { certs, .. }): State<ApiState>, State(ApiState { certs, .. }): State<ApiState>,
@ -221,16 +232,26 @@ async fn list_certs(
)) ))
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "aud", rename = "get")] #[serde(tag = "aud", rename = "get")]
struct AuthClaims { struct AuthClaims {
identifier: String, identifier: String,
} }
#[derive(TypedPath, Deserialize)] async fn request_client_auth(enabled: bool, identifier: &str, jwt_key: &Hs256Key) -> ApiResult<()> {
#[typed_path("/certs/:identifier")] use jwt_compact::{Claims, Header, TimeOptions};
pub struct GetCert { if enabled {
pub identifier: String, let claims = Claims::new(AuthClaims {
identifier: identifier.into(),
})
.set_duration(&TimeOptions::default(), chrono::Duration::seconds(120));
let challenge = Hs256
.compact_token(Header::default(), &claims, &jwt_key)
.context("jwt sign")?;
return Err(ApiError::AuthenticationRequired(challenge));
} else {
Ok(())
}
} }
/// Retrieve an certificate for identifier /// Retrieve an certificate for identifier
@ -247,16 +268,7 @@ async fn get_certs_identifier(
.. ..
}): State<ApiState>, }): State<ApiState>,
) -> ApiResult<String> { ) -> ApiResult<String> {
use jwt_compact::{AlgorithmExt, Claims, Header, TimeOptions}; request_client_auth(client_auth, &identifier, &jwt_key).await?;
if client_auth {
let claims = Claims::new(AuthClaims { identifier })
.set_duration(&TimeOptions::default(), chrono::Duration::seconds(120));
let challenge = Hs256
.compact_token(Header::default(), &claims, &jwt_key)
.context("jwt sign")?;
return Err(ApiError::AuthenticationRequired(challenge));
}
let certs = certs.lock().await; let certs = certs.lock().await;
let cert = certs let cert = certs
.get(&identifier) .get(&identifier)
@ -264,10 +276,22 @@ async fn get_certs_identifier(
Ok(cert.to_openssh().context("to openssh")?) Ok(cert.to_openssh().context("to openssh")?)
} }
#[derive(TypedPath, Deserialize)] async fn get_certs_pubkey(
#[typed_path("/certs/:identifier/info")] GetCertsPubkey { pubkey_hash }: GetCertsPubkey,
pub struct GetCertInfo { State(ApiState {
pub identifier: String, certs,
jwt_key: _,
client_auth: _,
..
}): State<ApiState>,
) -> ApiResult<Json<CertIds>> {
let certs = certs.lock().await;
let ids = certs
.values()
.filter(|cert| &cert.public_key().fingerprint(pubkey_hash.algorithm()) == &pubkey_hash)
.map(|cert| cert.key_id().to_string())
.collect::<Vec<_>>();
Ok(Json(CertIds { ids }))
} }
#[cfg(feature = "info")] #[cfg(feature = "info")]
@ -275,19 +299,59 @@ pub struct GetCertInfo {
struct CertInfo { struct CertInfo {
principals: Vec<String>, principals: Vec<String>,
ca: PublicKey, ca: PublicKey,
ca_hash: Fingerprint,
identity: PublicKey, identity: PublicKey,
identity_hash: Fingerprint,
key_id: String, key_id: String,
expiry: SystemTime, expiry: SystemTime,
renew_command: String,
} }
impl From<&Certificate> for CertInfo { impl From<&Certificate> for CertInfo {
fn from(cert: &Certificate) -> Self { fn from(cert: &Certificate) -> Self {
let validity = cert
.valid_before_time()
.duration_since(cert.valid_after_time())
.unwrap_or(Duration::zero().to_std().unwrap());
let expiry = cert.valid_before_time().checked_add(validity).unwrap();
let expiry_date = expiry.duration_since(UNIX_EPOCH).unwrap();
let host_key = if cert.cert_type().is_host() {
" -h"
} else {
""
};
let opts = cert
.critical_options()
.iter()
.map(|(opt, val)| {
if val.is_empty() {
opt.clone()
} else {
format!("{opt}={val}")
}
})
.map(|arg| format!("-O {}", escape(arg.into())))
.collect::<Vec<_>>()
.join(" ");
let opts = opts.trim();
let renew_command = format!(
"ssh-keygen -s ./ca_key {host_key} -I {} -n {} -z {} -V {:#x}:{:#x} {opts} {}.pub",
escape(cert.key_id().into()),
escape(cert.valid_principals().join(",").into()),
cert.serial() + 1,
cert.valid_after(),
expiry_date.as_secs(),
escape(cert.key_id().into())
);
CertInfo { CertInfo {
principals: cert.valid_principals().to_vec(), principals: cert.valid_principals().to_vec(),
ca: cert.signature_key().clone().into(), ca: cert.signature_key().clone().into(),
ca_hash: cert.signature_key().fingerprint(ssh_key::HashAlg::Sha256),
identity: cert.public_key().clone().into(), identity: cert.public_key().clone().into(),
identity_hash: cert.public_key().fingerprint(ssh_key::HashAlg::Sha256),
key_id: cert.key_id().to_string(), key_id: cert.key_id().to_string(),
expiry: cert.valid_before_time(), expiry: cert.valid_before_time(),
renew_command,
} }
} }
} }
@ -312,33 +376,31 @@ async fn get_cert_info(
unimplemented!() unimplemented!()
} }
#[derive(TypedPath, Deserialize)]
#[typed_path("/certs/:identifier")]
pub struct PostCertInfo {
pub identifier: String,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct PostCertsQuery { struct PostCertsQuery {
challenge: String, challenge: String,
} }
impl From<Query<PostCertsQuery>> for JWTString {
fn from(Query(PostCertsQuery { challenge }): Query<PostCertsQuery>) -> Self {
Self::from(challenge)
}
}
/// POST with signed challenge /// POST with signed challenge
async fn post_certs_identifier( async fn post_certs_identifier(
PostCertInfo { identifier }: PostCertInfo, PostCertInfo { identifier }: PostCertInfo,
State(ApiState { certs, jwt_key, .. }): State<ApiState>, State(ApiState { certs, .. }): State<ApiState>,
JWTAuthenticated {
data: auth_claims, ..
}: JWTAuthenticated<AuthClaims, Query<PostCertsQuery>>,
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>, Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
SignatureBody(sig): SignatureBody, SignatureBody(sig): SignatureBody,
) -> ApiResult<String> { ) -> ApiResult<String> {
let certs = certs.lock().await; let certs = certs.lock().await;
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?; let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
let token: Token<AuthClaims> = Hs256 if auth_claims.identifier != identifier {
.validate_integrity(
&UntrustedToken::new(&challenge).context("jwt parse")?,
&jwt_key,
)
.map_err(|_| ApiError::InvalidSignature)?;
if token.claims().custom.identifier != identifier {
return Err(ApiError::InvalidSignature); return Err(ApiError::InvalidSignature);
} }
let pubkey: PublicKey = cert.public_key().clone().into(); let pubkey: PublicKey = cert.public_key().clone().into();
@ -353,10 +415,6 @@ async fn post_certs_identifier(
Ok(cert.to_openssh().context("to openssh")?) Ok(cert.to_openssh().context("to openssh")?)
} }
#[derive(TypedPath)]
#[typed_path("/cert")]
pub struct PutCert;
/// Upload an cert with an higher serial than the previous /// Upload an cert with an higher serial than the previous
async fn put_cert_update( async fn put_cert_update(
_: PutCert, _: PutCert,
@ -405,12 +463,14 @@ async fn put_cert_update(
let identity = cert.key_id(); let identity = cert.key_id();
info!(%identity, ?principals, "updating certificate"); info!(%identity, ?principals, "updating certificate");
certs.lock().await.insert(cert.key_id().to_string(), cert); certs.lock().await.insert(cert.key_id().to_string(), cert);
Ok(format!("{} -> {}", prev_serial, serial)) Ok(format!("{prev_serial} -> {serial}"))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use ssh_key::{certificate, private::Ed25519Keypair, PrivateKey};
use std::env::temp_dir; use std::env::temp_dir;
use std::time::Duration;
use super::*; use super::*;
@ -423,14 +483,23 @@ mod tests {
} }
fn ca_pub() -> PublicKey { fn ca_pub() -> PublicKey {
PublicKey::new(ca_key().public.into(), "TEST CA") PublicKey::new(
ca_key().public.into(),
format!(
"TEST CA {}",
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs()
),
)
} }
fn user_key() -> Ed25519Keypair { fn user_key() -> Ed25519Keypair {
Ed25519Keypair::from_seed(&[1u8; 32]) Ed25519Keypair::from_seed(&[1u8; 32])
} }
fn user_cert(ca: Ed25519Keypair, user_key: PublicKey) -> Certificate { fn user_cert(ca: Ed25519Keypair, user_key: PublicKey, validity: Duration) -> Certificate {
let ca_private: PrivateKey = ca.into(); let ca_private: PrivateKey = ca.into();
let unix_time = |time: SystemTime| -> u64 { let unix_time = |time: SystemTime| -> u64 {
time.duration_since(SystemTime::UNIX_EPOCH) time.duration_since(SystemTime::UNIX_EPOCH)
@ -441,7 +510,7 @@ mod tests {
[0u8; 16], [0u8; 16],
user_key, user_key,
unix_time(SystemTime::now()), unix_time(SystemTime::now()),
unix_time(SystemTime::now() + Duration::from_secs(30)), unix_time(SystemTime::now() + validity),
); );
builder builder
@ -449,7 +518,7 @@ mod tests {
.unwrap() .unwrap()
.key_id("test_cert") .key_id("test_cert")
.unwrap() .unwrap()
.comment("A TEST CERT") .comment(&format!("A TEST CERT, VALID FOR {}s", validity.as_secs()))
.unwrap(); .unwrap();
builder.sign(&ca_private).unwrap() builder.sign(&ca_private).unwrap()
@ -463,24 +532,49 @@ mod tests {
cert_dir: dbg!(temp_dir()), cert_dir: dbg!(temp_dir()),
validation_args: Default::default(), validation_args: Default::default(),
client_auth: false, client_auth: false,
jwt_key: Hs256Key::new(&[0u8; 16]), jwt_key: Hs256Key::new([0u8; 16]),
} }
} }
#[test] #[test]
fn test_certificate() { fn test_certificate() {
let valid_cert = user_cert(ca_key(), user_key().public.into()); let valid_cert = user_cert(ca_key(), user_key().public.into(), Duration::from_secs(30));
let ca_pub: PublicKey = ca_pub(); let ca_pub: PublicKey = ca_pub();
assert!(valid_cert assert!(valid_cert
.validate(&[ca_pub.fingerprint(Default::default())]) .validate(&[ca_pub.fingerprint(Default::default())])
.is_ok()); .is_ok());
} }
#[tokio::test]
async fn update_cert() {
let state = api_state();
let ca = ca_key();
let user: PublicKey = user_key().public.into();
let (cert_first, cert_newer, cert_outdated) = {
(
user_cert(ca.clone(), user.clone(), Duration::from_secs(300)),
user_cert(ca.clone(), user.clone(), Duration::from_secs(600)),
user_cert(ca.clone(), user.clone(), Duration::from_secs(30)),
)
};
let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_first)).await;
assert!(dbg!(res).is_ok());
let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_newer)).await;
assert!(res.is_ok());
let res = put_cert_update(
PutCert,
State(state.clone()),
CertificateBody(cert_outdated),
)
.await;
assert!(res.is_err());
}
#[tokio::test] #[tokio::test]
async fn routes() -> anyhow::Result<()> { async fn routes() -> anyhow::Result<()> {
let state = api_state(); let state = api_state();
let valid_cert = user_cert(ca_key(), user_key().public.into()); let valid_cert = user_cert(ca_key(), user_key().public.into(), Duration::from_secs(30));
let invalid_cert = user_cert(ca_key2(), user_key().public.into()); let invalid_cert = user_cert(ca_key2(), user_key().public.into(), Duration::from_secs(30));
let res = put_cert_update( let res = put_cert_update(
PutCert, PutCert,
State(state.clone()), State(state.clone()),
@ -537,6 +631,9 @@ mod tests {
identifier: "test_cert".into(), identifier: "test_cert".into(),
}, },
State(state.clone()), State(state.clone()),
JWTAuthenticated::new(AuthClaims {
identifier: "test_cert".into(),
}),
Query(PostCertsQuery { challenge }), Query(PostCertsQuery { challenge }),
SignatureBody(sig), SignatureBody(sig),
) )

123
server/src/api/extract.rs Normal file
View File

@ -0,0 +1,123 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use super::{ApiError, ApiState};
use anyhow::Context;
use axum::{
async_trait,
body::BoxBody,
extract::{FromRequest, FromRequestParts},
http::Request,
};
use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken};
use serde::{de::DeserializeOwned, Serialize};
use ssh_key::{Certificate, SshSig};
use tracing::trace;
#[derive(Debug, Clone)]
pub struct CertificateBody(pub Certificate);
// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body
#[async_trait]
impl<S> FromRequest<S, BoxBody> for CertificateBody
where
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
let body = String::from_request(req, state)
.await
.context("failed to extract body")?;
let cert = Certificate::from_openssh(&body)
.with_context(|| format!("failed to parse '{}'", body))
.map_err(ApiError::ParseCertificate)?;
trace!(%body, "extracted certificate");
Ok(Self(cert))
}
}
#[derive(Debug, Clone)]
pub struct SignatureBody(pub SshSig);
#[async_trait]
impl<S> FromRequest<S, BoxBody> for SignatureBody
where
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
let body = String::from_request(req, state)
.await
.context("failed to extract body")?;
let sig = SshSig::from_pem(&body)
.with_context(|| format!("failed to parse '{}'", body))
.map_err(ApiError::ParseSignature)?;
trace!(%body, "extracted signature");
Ok(Self(sig))
}
}
pub struct JWTString(String);
impl From<String> for JWTString {
fn from(s: String) -> Self {
Self(s)
}
}
// TODO: be generic over ApiState -> AsRef<Target=Hs256>, AsRef<Target=A> where A: AlgorithmExt
#[derive(Debug)]
pub struct JWTAuthenticated<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
pub data: T,
_marker: PhantomData<Q>,
}
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> JWTAuthenticated<T, Q>
where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
pub fn new(data: T) -> Self {
Self {
data,
_marker: Default::default(),
}
}
}
#[async_trait]
impl<
T: Serialize + DeserializeOwned + Clone + Debug,
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
> FromRequestParts<ApiState> for JWTAuthenticated<T, Q>
where
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
{
type Rejection = ApiError;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &ApiState,
) -> Result<Self, Self::Rejection> {
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?;
let verified: Token<T> = Hs256
.validate_integrity(&token, &state.jwt_key)
.map_err(ApiError::JWTVerify)?;
Ok(Self {
data: verified.claims().custom.clone(),
_marker: Default::default(),
})
}
}

10
server/src/main.rs Normal file
View File

@ -0,0 +1,10 @@
use clap::Parser;
mod api;
#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
api::run(api::ApiArgs::parse()).await
}

View File

@ -1,51 +0,0 @@
use anyhow::Context;
use axum::{
async_trait, body::BoxBody, extract::FromRequest, http::Request, response::IntoResponse,
};
use ssh_key::{Certificate, SshSig};
use tracing::trace;
use super::ApiError;
#[derive(Debug, Clone)]
pub struct CertificateBody(pub Certificate);
// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body
#[async_trait]
impl<S> FromRequest<S, BoxBody> for CertificateBody
where
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
let body = String::from_request(req, state)
.await
.context("failed to extract body")?;
let cert = Certificate::from_openssh(&body)
.with_context(|| format!("failed to parse '{}'", body))?;
trace!(%body, "extracted certificate");
Ok(Self(cert))
}
}
#[derive(Debug, Clone)]
pub struct SignatureBody(pub SshSig);
#[async_trait]
impl<S> FromRequest<S, BoxBody> for SignatureBody
where
S: Send + Sync,
{
type Rejection = ApiError;
async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
let body = String::from_request(req, state)
.await
.context("failed to extract body")?;
let sig = SshSig::from_pem(&body).with_context(|| format!("failed to parse '{}'", body))?;
trace!(%body, "extracted signature");
Ok(Self(sig))
}
}

View File

@ -1,35 +0,0 @@
use api::ApiArgs;
use clap::Parser;
#[cfg(feature = "client")]
use client::ClientCommand;
mod api;
mod certs;
#[cfg(feature = "client")]
mod client;
#[macro_export]
macro_rules! env_key {
( $var:expr ) => {
concat!("SSH_CD_", $var)
};
}
#[derive(Parser)]
enum Command {
Server(ApiArgs),
#[cfg(feature = "client")]
Client(ClientCommand),
}
#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
match Command::parse() {
Command::Server(args) => api::run(args).await?,
#[cfg(feature = "client")]
Command::Client(args) => client::run(args).await?,
}
Ok(())
}