Compare commits
28 Commits
53f21fa668
...
master
Author | SHA1 | Date | |
---|---|---|---|
dbca99308e | |||
e8830e812b | |||
7509d63582
|
|||
c37d40389a
|
|||
9f6a5e03c9
|
|||
675dd4faf6
|
|||
6cb7ce4a78
|
|||
b8505790f2
|
|||
d4c579c4c8
|
|||
9d405a6324
|
|||
1183ba0d73
|
|||
df85bad9a4
|
|||
cbb99138a9
|
|||
591858ef05 | |||
70de99fd25
|
|||
e7c3a9f116
|
|||
f47c57c1c0
|
|||
dffbcceeba
|
|||
2688c81aed
|
|||
e696663aec
|
|||
ba77091de7
|
|||
4ff3cbe9d9
|
|||
bccaa6935f
|
|||
c299a4e132
|
|||
50ba6c9934
|
|||
f069dae3ee
|
|||
17bb56dd5f
|
|||
e3b920fcd5
|
5
.woodpecker.yml
Normal file
5
.woodpecker.yml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
pipeline:
|
||||||
|
test:
|
||||||
|
image: rust
|
||||||
|
commands:
|
||||||
|
- cargo test
|
1314
Cargo.lock
generated
1314
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,7 @@ authors = ["shimun <shimun@shimun.net>"]
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "ssh-cert-dist"
|
name = "sshcd"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
@@ -17,13 +17,14 @@ clap = { version = "4.0.29", features = ["env", "derive"] }
|
|||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.11.13" }
|
reqwest = { version = "0.11.13" }
|
||||||
serde = { version = "1.0.148", features = ["derive"] }
|
serde = { version = "1.0.148", features = ["derive"] }
|
||||||
ssh-key = { version = "0.5.1", features = ["ed25519", "p256", "p384", "rsa", "signature"] }
|
ssh-key = { version = "0.6.0-rc.2", features = ["ed25519", "p256", "p384", "rsa", "serde"] }
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.37"
|
||||||
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
|
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
|
||||||
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
|
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
|
||||||
tracing-subscriber = "0.3.16"
|
tracing-subscriber = "0.3.16"
|
||||||
url = { version = "2.3.1" }
|
url = { version = "2.3.1" }
|
||||||
ssh-cert-dist-common = { path = "../common" }
|
ssh-cert-dist-common = { path = "../common" }
|
||||||
|
clap_complete_command = "0.5.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.3.0"
|
tempfile = "3.3.0"
|
||||||
|
@@ -1,9 +1,11 @@
|
|||||||
use anyhow::bail;
|
use anyhow::{bail, Context};
|
||||||
use axum_extra::routing::TypedPath;
|
use axum_extra::routing::TypedPath;
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{CommandFactory, Parser, Subcommand, ValueHint};
|
||||||
|
use clap_complete_command::Shell;
|
||||||
use reqwest::{Client, StatusCode};
|
use reqwest::{Client, StatusCode};
|
||||||
use ssh_key::Certificate;
|
use ssh_key::Certificate;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use std::process;
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
use tokio::fs;
|
use tokio::fs;
|
||||||
use tokio::io::{stdin, AsyncBufReadExt, BufReader};
|
use tokio::io::{stdin, AsyncBufReadExt, BufReader};
|
||||||
@@ -16,7 +18,7 @@ use ssh_cert_dist_common::*;
|
|||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
pub struct ClientArgs {
|
pub struct ClientArgs {
|
||||||
/// Url for the API endpoint
|
/// Url for the API endpoint
|
||||||
#[clap(short = 'a', long = "api-endpoint", env = env_key!("API"))]
|
#[clap(short = 'a', long = "api-endpoint",value_hint = ValueHint::Url, env = env_key!("API"))]
|
||||||
api: Url,
|
api: Url,
|
||||||
/// Require interaction before writing certificates
|
/// Require interaction before writing certificates
|
||||||
#[clap(short = 'i', long = "interactive", env = env_key!("INTERACTIVE"))]
|
#[clap(short = 'i', long = "interactive", env = env_key!("INTERACTIVE"))]
|
||||||
@@ -29,7 +31,7 @@ pub struct FetchArgs {
|
|||||||
args: ClientArgs,
|
args: ClientArgs,
|
||||||
#[clap(short = 'k', long = "key-update", env = env_key!("KEY_UPDATE"))]
|
#[clap(short = 'k', long = "key-update", env = env_key!("KEY_UPDATE"))]
|
||||||
prohibit_key_update: bool,
|
prohibit_key_update: bool,
|
||||||
#[clap(short = 'c', long = "cert-dir", env = env_key!("CERT_DIR"))]
|
#[clap(short = 'c', long = "cert-dir",value_hint = ValueHint::DirPath, env = env_key!("CERT_DIR"))]
|
||||||
cert_dir: PathBuf,
|
cert_dir: PathBuf,
|
||||||
/// minimum time in days between now and expiry to consider checking
|
/// minimum time in days between now and expiry to consider checking
|
||||||
#[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))]
|
#[clap(short = 'd', long = "days", default_value = "60", env = env_key!("MIN_DELTA_DAYS"))]
|
||||||
@@ -41,11 +43,25 @@ pub struct UploadArgs {
|
|||||||
#[clap(flatten)]
|
#[clap(flatten)]
|
||||||
args: ClientArgs,
|
args: ClientArgs,
|
||||||
/// Certificates to be uploaded
|
/// Certificates to be uploaded
|
||||||
#[clap(env = env_key!("FILES"))]
|
#[clap(value_hint = ValueHint::FilePath, env = env_key!("FILES"))]
|
||||||
files: Vec<PathBuf>,
|
files: Vec<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
|
pub struct RenewCommandArgs {
|
||||||
|
/// Execute the renew command
|
||||||
|
#[clap(short = 'x')]
|
||||||
|
execute: bool,
|
||||||
|
/// Path to the CA private key
|
||||||
|
#[clap(long="ca",value_hint = ValueHint::DirPath, env = env_key!("CA_KEY"))]
|
||||||
|
ca_key: Option<PathBuf>,
|
||||||
|
/// Certificates to generate commands for
|
||||||
|
#[clap(value_hint = ValueHint::FilePath,env = env_key!("FILES"))]
|
||||||
|
files: Vec<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(name = "sshcd")]
|
||||||
pub struct ClientCommand {
|
pub struct ClientCommand {
|
||||||
#[clap(subcommand)]
|
#[clap(subcommand)]
|
||||||
cmd: ClientCommands,
|
cmd: ClientCommands,
|
||||||
@@ -55,12 +71,23 @@ pub struct ClientCommand {
|
|||||||
pub enum ClientCommands {
|
pub enum ClientCommands {
|
||||||
Fetch(FetchArgs),
|
Fetch(FetchArgs),
|
||||||
Upload(UploadArgs),
|
Upload(UploadArgs),
|
||||||
|
RenewCommand(RenewCommandArgs),
|
||||||
|
#[clap(hide = true)]
|
||||||
|
Completions {
|
||||||
|
#[arg(long = "shell", value_enum)]
|
||||||
|
shell: Shell,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> {
|
pub async fn run(ClientCommand { cmd }: ClientCommand) -> anyhow::Result<()> {
|
||||||
match cmd {
|
match cmd {
|
||||||
ClientCommands::Fetch(args) => fetch(args).await,
|
ClientCommands::Fetch(args) => fetch(args).await,
|
||||||
ClientCommands::Upload(args) => upload(args).await,
|
ClientCommands::Upload(args) => upload(args).await,
|
||||||
|
ClientCommands::RenewCommand(args) => renew(args).await,
|
||||||
|
ClientCommands::Completions { shell } => {
|
||||||
|
shell.generate(&mut ClientCommand::command(), &mut std::io::stdout());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,11 +143,13 @@ async fn fetch(
|
|||||||
args: ClientArgs { api, interactive },
|
args: ClientArgs { api, interactive },
|
||||||
}: FetchArgs,
|
}: FetchArgs,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let certs = read_dir(&cert_dir).await?;
|
let certs = read_certs_dir(&cert_dir).await?;
|
||||||
|
// let publics_keys = read_pubkey_dir(&cert_dir).await?;
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let threshold_exp = min_delta.and_then(|min_delta| {
|
let threshold_exp = min_delta.and_then(|min_delta| {
|
||||||
SystemTime::now().checked_add(Duration::from_secs(60 * 60 * 24 * min_delta as u64))
|
SystemTime::now().checked_add(Duration::from_secs(60 * 60 * 24 * min_delta as u64))
|
||||||
});
|
});
|
||||||
|
// let standalone_certs = publics_keys.into_iter().map(|(name, key)| )
|
||||||
let updates = certs
|
let updates = certs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|cert| {
|
.filter(|cert| {
|
||||||
@@ -168,6 +197,40 @@ async fn fetch(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn renew(
|
||||||
|
RenewCommandArgs {
|
||||||
|
files,
|
||||||
|
ca_key,
|
||||||
|
execute,
|
||||||
|
}: RenewCommandArgs,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
for file in files.iter() {
|
||||||
|
let cert = load_cert(&file).await?;
|
||||||
|
if let Some(cert) = cert {
|
||||||
|
let command = renew_command(
|
||||||
|
&cert,
|
||||||
|
ca_key
|
||||||
|
.as_deref()
|
||||||
|
.map(|path| path.to_str())
|
||||||
|
.flatten()
|
||||||
|
.unwrap_or("ca"),
|
||||||
|
file.to_str(),
|
||||||
|
);
|
||||||
|
println!("{}", command);
|
||||||
|
if execute {
|
||||||
|
process::Command::new("sh")
|
||||||
|
.arg("-c")
|
||||||
|
.arg(&command)
|
||||||
|
.spawn()
|
||||||
|
.with_context(|| format!("{command}"))?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bail!("{file:?} doesn't exist");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[instrument(skip(client, current))]
|
#[instrument(skip(client, current))]
|
||||||
async fn fetch_cert(
|
async fn fetch_cert(
|
||||||
client: Client,
|
client: Client,
|
||||||
|
@@ -11,12 +11,15 @@ anyhow = "1.0.66"
|
|||||||
async-trait = "0.1.59"
|
async-trait = "0.1.59"
|
||||||
axum = { version = "0.6.1" }
|
axum = { version = "0.6.1" }
|
||||||
axum-extra = { version = "0.4.1", features = ["typed-routing"] }
|
axum-extra = { version = "0.4.1", features = ["typed-routing"] }
|
||||||
|
chrono = "0.4.26"
|
||||||
|
hex = { version = "0.4.3", features = ["serde"] }
|
||||||
serde = { version = "1.0.148", features = ["derive"] }
|
serde = { version = "1.0.148", features = ["derive"] }
|
||||||
ssh-key = { version = "0.5.1", features = ["ed25519", "p256", "p384", "rsa", "signature"] }
|
ssh-key = { version = "0.6.0-rc.2", features = ["ed25519", "p256", "p384", "rsa"] }
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.37"
|
||||||
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
|
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
|
||||||
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
|
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
|
||||||
tracing-subscriber = "0.3.16"
|
tracing-subscriber = "0.3.16"
|
||||||
|
shell-escape = "0.1.5"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.3.0"
|
tempfile = "3.3.0"
|
||||||
|
@@ -24,11 +24,11 @@ pub async fn read_certs(
|
|||||||
if !ca_dir.exists() {
|
if !ca_dir.exists() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
read_dir(&ca_dir).await
|
read_certs_dir(&ca_dir).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
#[instrument]
|
||||||
pub async fn read_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<Certificate>> {
|
pub async fn read_certs_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<Certificate>> {
|
||||||
let mut dir = fs::read_dir(path.as_ref())
|
let mut dir = fs::read_dir(path.as_ref())
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("read certs dir '{:?}'", path.as_ref()))?;
|
.with_context(|| format!("read certs dir '{:?}'", path.as_ref()))?;
|
||||||
@@ -55,6 +55,26 @@ pub async fn read_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<Cert
|
|||||||
Ok(certs)
|
Ok(certs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn read_pubkey_dir(path: impl AsRef<Path> + Debug) -> anyhow::Result<Vec<PublicKey>> {
|
||||||
|
let mut dir = fs::read_dir(path.as_ref())
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("read certs dir '{:?}'", path.as_ref()))?;
|
||||||
|
let mut pubs = Vec::new();
|
||||||
|
while let Some(entry) = dir.next_entry().await? {
|
||||||
|
//TODO: investigate why path().ends_with doesn't work
|
||||||
|
let file_name = entry.file_name().into_string().unwrap();
|
||||||
|
if !file_name.ends_with(".pub") || file_name.ends_with("-cert.pub") {
|
||||||
|
trace!("skipped {:?} due to missing '.pub' extension", entry.path());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let cert = load_public_key(entry.path()).await?;
|
||||||
|
if let Some(cert) = cert {
|
||||||
|
pubs.push(cert);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(pubs)
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> {
|
fn parse_utf8(bytes: Vec<u8>) -> anyhow::Result<String> {
|
||||||
String::from_utf8(bytes).context("invalid utf-8")
|
String::from_utf8(bytes).context("invalid utf-8")
|
||||||
}
|
}
|
||||||
@@ -122,3 +142,15 @@ pub async fn load_cert(file: impl AsRef<Path> + Debug) -> anyhow::Result<Option<
|
|||||||
|| format!("parse {:?} as openssh certificate", &file),
|
|| format!("parse {:?} as openssh certificate", &file),
|
||||||
)?))
|
)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn load_public_key(file: impl AsRef<Path> + Debug) -> anyhow::Result<Option<PublicKey>> {
|
||||||
|
let contents = match fs::read(&file).await {
|
||||||
|
Ok(contents) => contents,
|
||||||
|
Err(e) if e.kind() == ErrorKind::NotFound => return Ok(None),
|
||||||
|
Err(e) => return Err(e).with_context(|| format!("read {:?}", &file)),
|
||||||
|
};
|
||||||
|
let string_repr = parse_utf8(contents)?;
|
||||||
|
Ok(Some(PublicKey::from_openssh(&string_repr).with_context(
|
||||||
|
|| format!("parse {:?} as openssh public key", &file),
|
||||||
|
)?))
|
||||||
|
}
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
mod certs;
|
mod certs;
|
||||||
|
mod renew;
|
||||||
mod routes;
|
mod routes;
|
||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
pub use certs::*;
|
pub use certs::*;
|
||||||
|
pub use renew::*;
|
||||||
pub use routes::*;
|
pub use routes::*;
|
||||||
|
49
common/src/renew.rs
Normal file
49
common/src/renew.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
use std::borrow::Cow;
|
||||||
|
use std::time::UNIX_EPOCH;
|
||||||
|
|
||||||
|
use chrono::Duration;
|
||||||
|
use shell_escape::escape;
|
||||||
|
use ssh_key::Certificate;
|
||||||
|
|
||||||
|
/// Generates an command to renew the given certs
|
||||||
|
pub fn renew_command(cert: &Certificate, ca_path: &str, file_name: Option<&str>) -> String {
|
||||||
|
let validity = cert
|
||||||
|
.valid_before_time()
|
||||||
|
.duration_since(cert.valid_after_time())
|
||||||
|
.unwrap_or(Duration::zero().to_std().unwrap());
|
||||||
|
let expiry = cert.valid_before_time().checked_add(validity).unwrap();
|
||||||
|
let expiry_date = expiry.duration_since(UNIX_EPOCH).unwrap();
|
||||||
|
let host_key = if cert.cert_type().is_host() {
|
||||||
|
" -h"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
let opts = cert
|
||||||
|
.critical_options()
|
||||||
|
.iter()
|
||||||
|
.map(|(opt, val)| {
|
||||||
|
if val.is_empty() {
|
||||||
|
opt.clone()
|
||||||
|
} else {
|
||||||
|
format!("{opt}={val}")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map(|arg| format!("-O {}", escape(arg.into())))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" ");
|
||||||
|
let opts = opts.trim();
|
||||||
|
let renew_command = format!(
|
||||||
|
"ssh-keygen -s {ca_path} {host_key} -I {} -n {} -z {} -V {}:{} {opts} {}",
|
||||||
|
escape(cert.key_id().into()),
|
||||||
|
escape(cert.valid_principals().join(",").into()),
|
||||||
|
cert.serial() + 1,
|
||||||
|
cert.valid_after(),
|
||||||
|
expiry_date.as_secs(),
|
||||||
|
escape(
|
||||||
|
file_name
|
||||||
|
.map(|name| name.trim_end_matches("-cert.pub")).map(Cow::Borrowed)
|
||||||
|
.unwrap_or_else(|| escape(format!("{}.pub", cert.key_id()).into()))
|
||||||
|
)
|
||||||
|
);
|
||||||
|
renew_command
|
||||||
|
}
|
@@ -1,24 +1,37 @@
|
|||||||
use axum_extra::routing::TypedPath;
|
use axum_extra::routing::TypedPath;
|
||||||
use serde::Deserialize;
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use ssh_key::Fingerprint;
|
||||||
|
|
||||||
#[derive(TypedPath, Deserialize)]
|
#[derive(TypedPath, Deserialize)]
|
||||||
#[typed_path("/certs")]
|
#[typed_path("/certs")]
|
||||||
pub struct CertList;
|
pub struct CertList;
|
||||||
|
|
||||||
#[derive(TypedPath, Deserialize)]
|
#[derive(TypedPath, Deserialize)]
|
||||||
#[typed_path("/certs/:identifier")]
|
#[typed_path("/cert/:identifier")]
|
||||||
pub struct GetCert {
|
pub struct GetCert {
|
||||||
pub identifier: String,
|
pub identifier: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(TypedPath, Deserialize)]
|
#[derive(TypedPath, Deserialize)]
|
||||||
#[typed_path("/certs/:identifier/info")]
|
#[typed_path("/certs/:pubkey_hash")]
|
||||||
|
pub struct GetCertsPubkey {
|
||||||
|
pub pubkey_hash: Fingerprint,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||||
|
pub struct CertIds {
|
||||||
|
pub ids: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(TypedPath, Deserialize)]
|
||||||
|
#[typed_path("/cert/:identifier/info")]
|
||||||
pub struct GetCertInfo {
|
pub struct GetCertInfo {
|
||||||
pub identifier: String,
|
pub identifier: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(TypedPath, Deserialize)]
|
#[derive(TypedPath, Deserialize)]
|
||||||
#[typed_path("/certs/:identifier")]
|
#[typed_path("/cert/:identifier")]
|
||||||
pub struct PostCertInfo {
|
pub struct PostCertInfo {
|
||||||
pub identifier: String,
|
pub identifier: String,
|
||||||
}
|
}
|
||||||
|
38
flake.lock
generated
38
flake.lock
generated
@@ -7,11 +7,11 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1662220400,
|
"lastModified": 1698420672,
|
||||||
"narHash": "sha256-9o2OGQqu4xyLZP9K6kNe1pTHnyPz0Wr3raGYnr9AIgY=",
|
"narHash": "sha256-/TdeHMPRjjdJub7p7+w55vyABrsJlt5QkznPYy55vKA=",
|
||||||
"owner": "nmattia",
|
"owner": "nmattia",
|
||||||
"repo": "naersk",
|
"repo": "naersk",
|
||||||
"rev": "6944160c19cb591eb85bbf9b2f2768a935623ed3",
|
"rev": "aeb58d5e8faead8980a807c840232697982d47b9",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -22,11 +22,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1669411043,
|
"lastModified": 1705496572,
|
||||||
"narHash": "sha256-LfPd3+EY+jaIHTRIEOUtHXuanxm59YKgUacmSzaqMLc=",
|
"narHash": "sha256-rPIe9G5EBLXdBdn9ilGc0nq082lzQd0xGGe092R/5QE=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "5dc7114b7b256d217fe7752f1614be2514e61bb8",
|
"rev": "842d9d80cfd4560648c785f8a4e6f3b096790e19",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@@ -41,13 +41,31 @@
|
|||||||
"utils": "utils"
|
"utils": "utils"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"utils": {
|
"systems": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1667395993,
|
"lastModified": 1681028828,
|
||||||
"narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=",
|
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-systems",
|
||||||
|
"repo": "default",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"utils": {
|
||||||
|
"inputs": {
|
||||||
|
"systems": "systems"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1709126324,
|
||||||
|
"narHash": "sha256-q6EQdSeUZOG26WelxqkmR7kArjgWCdw5sfJVHPH/7j8=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f",
|
"rev": "d465f4819400de7c8d874d50b982301f28a84605",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
26
flake.nix
26
flake.nix
@@ -43,9 +43,11 @@
|
|||||||
# `nix run`
|
# `nix run`
|
||||||
apps."${pname}-server" = utils.lib.mkApp {
|
apps."${pname}-server" = utils.lib.mkApp {
|
||||||
drv = packages."${pname}-server";
|
drv = packages."${pname}-server";
|
||||||
|
exePath = "/bin/sshcd-server";
|
||||||
};
|
};
|
||||||
apps."${pname}-client" = utils.lib.mkApp {
|
apps."${pname}-client" = utils.lib.mkApp {
|
||||||
drv = packages."${pname}-client";
|
drv = packages."${pname}-client";
|
||||||
|
exePath = "/bin/sshcd";
|
||||||
};
|
};
|
||||||
|
|
||||||
# `nix run .#streamDockerImage | docker load`
|
# `nix run .#streamDockerImage | docker load`
|
||||||
@@ -91,7 +93,15 @@
|
|||||||
rustc --version
|
rustc --version
|
||||||
printf "\nbuild inputs: ${pkgs.lib.concatStringsSep ", " (map (bi: bi.name) (buildInputs ++ nativeBuildInputs))}"
|
printf "\nbuild inputs: ${pkgs.lib.concatStringsSep ", " (map (bi: bi.name) (buildInputs ++ nativeBuildInputs))}"
|
||||||
function server() {
|
function server() {
|
||||||
cargo watch -x "run --bin ssh-cert-dist-server --all-features -- ''${@}"
|
if [ ! -e "certs/ca.pub" ]; then
|
||||||
|
mkdir -p certs keys
|
||||||
|
ssh-keygen -t ed25519 -f certs/ca -q -N ""
|
||||||
|
ssh-keygen -t ed25519 -f keys/host -q -N ""
|
||||||
|
ssh-keygen -t ed25519 -f keys/client -q -N ""
|
||||||
|
ssh-keygen -s certs/ca -V +1000d -h -I host -n localhost,127.0.0.1 -h keys/host.pub
|
||||||
|
ssh-keygen -s certs/ca -V +1000d -I client -n "client,client@localhost" keys/client.pub -O force-command="echo Hello World"
|
||||||
|
fi
|
||||||
|
cargo watch -x "run --bin sshcd-server --all-features -- ''${@}"
|
||||||
}
|
}
|
||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
@@ -113,17 +123,31 @@
|
|||||||
];
|
];
|
||||||
nativeBuildInputs = with prev; [
|
nativeBuildInputs = with prev; [
|
||||||
pkg-config
|
pkg-config
|
||||||
|
installShellFiles
|
||||||
];
|
];
|
||||||
|
installCompletions = cmd: ''
|
||||||
|
mkdir completions
|
||||||
|
for shell in bash zsh fish; do
|
||||||
|
$out/bin/${cmd} completions --shell $shell > completions/${cmd}.$shell
|
||||||
|
installShellCompletion --cmd ${cmd} --$shell completions/${cmd}.$shell
|
||||||
|
done
|
||||||
|
'';
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
"${pname}-server" =
|
"${pname}-server" =
|
||||||
naersk-lib.buildPackage {
|
naersk-lib.buildPackage {
|
||||||
name = "${pname}-server";
|
name = "${pname}-server";
|
||||||
inherit root buildInputs nativeBuildInputs;
|
inherit root buildInputs nativeBuildInputs;
|
||||||
|
# postInstall = ''
|
||||||
|
# ${installCompletions}
|
||||||
|
# '';
|
||||||
};
|
};
|
||||||
"${pname}-client" =
|
"${pname}-client" =
|
||||||
naersk-lib.buildPackage {
|
naersk-lib.buildPackage {
|
||||||
name = "${pname}-client";
|
name = "${pname}-client";
|
||||||
|
postInstall = ''
|
||||||
|
${installCompletions "sshcd"}
|
||||||
|
'';
|
||||||
inherit root buildInputs nativeBuildInputs;
|
inherit root buildInputs nativeBuildInputs;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@@ -11,22 +11,35 @@ in
|
|||||||
Unit.Description = "ssh-cert-dist service for ${path}";
|
Unit.Description = "ssh-cert-dist service for ${path}";
|
||||||
Service = {
|
Service = {
|
||||||
Environment = "RUST_LOG=debug";
|
Environment = "RUST_LOG=debug";
|
||||||
ExecStart = toString (pkgs.writeShellApplication {
|
ExecStart = "${pkgs.writeShellApplication {
|
||||||
name = "ssh-cert-dist-${options.name}";
|
name = "sshcd";
|
||||||
runtimeInputs = [ cfg.package ];
|
runtimeInputs = [ cfg.package ];
|
||||||
text = ''
|
text = ''
|
||||||
${optionalString options.fetch ''
|
${optionalString options.fetch ''
|
||||||
ssh-cert-dist fetch --cert-dir '${path}' --api-endpoint '${cfg.endpoint}'
|
sshcd fetch --cert-dir '${path}' --api-endpoint '${cfg.endpoint}'
|
||||||
''}
|
''}
|
||||||
${optionalString options.upload ''
|
${optionalString options.upload ''
|
||||||
ssh-cert-dist upload --api-endpoint '${cfg.endpoint}' ${path}/*
|
sshcd upload --api-endpoint '${cfg.endpoint}' ${path}/*
|
||||||
''}
|
''}
|
||||||
'';
|
'';
|
||||||
});
|
}}/bin/sshcd";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
cfg.directories);
|
cfg.directories);
|
||||||
|
config.systemd.user.timers = mkIf cfg.enable (mapAttrs'
|
||||||
|
(path: options: {
|
||||||
|
inherit (options) name; value = {
|
||||||
|
Unit.Description = "ssh-cert-dist service for ${path}";
|
||||||
|
Timer = {
|
||||||
|
OnCalendar = options.interval;
|
||||||
|
Persistent = true;
|
||||||
|
Unit = "${options.name}.service";
|
||||||
|
};
|
||||||
|
Install.WantedBy = [ "timers.target" ];
|
||||||
|
};
|
||||||
|
})
|
||||||
|
cfg.directories);
|
||||||
config.home.sessionVariables = mkIf (cfg.enable && cfg.endpoint != null) {
|
config.home.sessionVariables = mkIf (cfg.enable && cfg.endpoint != null) {
|
||||||
SSH_CD_API = cfg.endpoint;
|
SSH_CD_API = cfg.endpoint;
|
||||||
};
|
};
|
||||||
|
@@ -57,7 +57,7 @@ in
|
|||||||
chown ${cfg.user}:${cfg.group} ${cfg.dataDir}
|
chown ${cfg.user}:${cfg.group} ${cfg.dataDir}
|
||||||
''}";
|
''}";
|
||||||
User = cfg.user;
|
User = cfg.user;
|
||||||
ExecStart = "${cfg.package}/bin/ssh-cert-dist-server";
|
ExecStart = "${cfg.package}/bin/sshcd-server";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@@ -13,6 +13,11 @@
|
|||||||
type = types.bool;
|
type = types.bool;
|
||||||
default = false;
|
default = false;
|
||||||
};
|
};
|
||||||
|
interval = mkOption {
|
||||||
|
type = types.str;
|
||||||
|
default = "daily";
|
||||||
|
description = "https://www.freedesktop.org/software/systemd/man/systemd.time.html";
|
||||||
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
endpointOption = mkOption {
|
endpointOption = mkOption {
|
||||||
|
@@ -13,6 +13,9 @@ authorized =[ "dep:jwt-compact" ]
|
|||||||
index = []
|
index = []
|
||||||
info = [ "axum/json", "ssh-key/serde" ]
|
info = [ "axum/json", "ssh-key/serde" ]
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "sshcd-server"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.66"
|
anyhow = "1.0.66"
|
||||||
@@ -24,11 +27,11 @@ clap = { version = "4.0.29", features = ["env", "derive"] }
|
|||||||
jwt-compact = { version = "0.6.0", features = ["serde_cbor", "std", "clock"], optional = true }
|
jwt-compact = { version = "0.6.0", features = ["serde_cbor", "std", "clock"], optional = true }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
serde = { version = "1.0.148", features = ["derive"] }
|
serde = { version = "1.0.148", features = ["derive"] }
|
||||||
ssh-key = { version = "0.5.1", features = ["ed25519", "p256", "p384", "rsa", "signature"] }
|
ssh-key = { version = "0.6.0-rc.2", features = ["ed25519", "p256", "p384", "rsa"] }
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.37"
|
||||||
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
|
tokio = { version = "1.22.0", features = ["io-std", "test-util", "tracing", "macros", "fs"] }
|
||||||
tower = { version = "0.4.13", features = ["util"] }
|
tower = { version = "0.4.13" }
|
||||||
tower-http = { version = "0.3.4", features = ["map-request-body", "trace"] }
|
tower-http = { version = "0.3.4", features = ["map-request-body", "trace", "util"] }
|
||||||
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
|
tracing = { version = "0.1.37", features = ["release_max_level_debug"] }
|
||||||
tracing-subscriber = "0.3.16"
|
tracing-subscriber = "0.3.16"
|
||||||
ssh-cert-dist-common = { path = "../common" }
|
ssh-cert-dist-common = { path = "../common" }
|
||||||
|
@@ -1,21 +1,26 @@
|
|||||||
mod extract;
|
mod extract;
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::path::{self, PathBuf};
|
use std::path::{self, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
|
||||||
use axum::body;
|
use axum::body;
|
||||||
|
use axum::extract::rejection::QueryRejection;
|
||||||
use axum::extract::{Query, State};
|
use axum::extract::{Query, State};
|
||||||
|
|
||||||
use ssh_cert_dist_common::*;
|
use ssh_cert_dist_common::*;
|
||||||
|
|
||||||
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
|
use axum::{http::StatusCode, response::IntoResponse, Json, Router};
|
||||||
use axum_extra::routing::RouterExt;
|
use axum_extra::routing::RouterExt;
|
||||||
use clap::{Args, Parser};
|
use clap::{Args, Parser};
|
||||||
use jwt_compact::alg::{Hs256, Hs256Key};
|
use jwt_compact::alg::{Hs256, Hs256Key};
|
||||||
use jwt_compact::{AlgorithmExt, Token, UntrustedToken};
|
use jwt_compact::AlgorithmExt;
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use ssh_key::{Certificate, Fingerprint, PublicKey};
|
use ssh_key::{Certificate, Fingerprint, PublicKey};
|
||||||
@@ -24,9 +29,10 @@ use tower::ServiceBuilder;
|
|||||||
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
use tower_http::{trace::TraceLayer, ServiceBuilderExt};
|
||||||
use tracing::{debug, info, trace};
|
use tracing::{debug, info, trace};
|
||||||
|
|
||||||
use self::extract::{CertificateBody, SignatureBody};
|
use self::extract::{CertificateBody, JWTAuthenticated, JWTString, SignatureBody};
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
|
#[command(name = "sshcd-server")]
|
||||||
pub struct ApiArgs {
|
pub struct ApiArgs {
|
||||||
#[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))]
|
#[clap(short = 'a', long = "address", env = env_key!("SOCKET_ADDRESS"))]
|
||||||
address: SocketAddr,
|
address: SocketAddr,
|
||||||
@@ -72,7 +78,7 @@ impl Default for ApiArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct ApiState {
|
pub struct ApiState {
|
||||||
certs: Arc<Mutex<HashMap<String, Certificate>>>,
|
certs: Arc<Mutex<HashMap<String, Certificate>>>,
|
||||||
cert_dir: PathBuf,
|
cert_dir: PathBuf,
|
||||||
ca: PublicKey,
|
ca: PublicKey,
|
||||||
@@ -173,12 +179,23 @@ pub enum ApiError {
|
|||||||
AuthenticationRequired(String),
|
AuthenticationRequired(String),
|
||||||
#[error("invalid ssh signature")]
|
#[error("invalid ssh signature")]
|
||||||
InvalidSignature,
|
InvalidSignature,
|
||||||
|
#[error("malformed ssh signature: {0}")]
|
||||||
|
ParseSignature(anyhow::Error),
|
||||||
|
#[error("malformed ssh certificate: {0}")]
|
||||||
|
ParseCertificate(anyhow::Error),
|
||||||
|
#[error("{0}")]
|
||||||
|
JWTParse(#[from] jwt_compact::ParseError),
|
||||||
|
#[error("{0}")]
|
||||||
|
JWTVerify(#[from] jwt_compact::ValidationError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Query(#[from] QueryRejection),
|
||||||
}
|
}
|
||||||
|
|
||||||
type ApiResult<T> = Result<T, ApiError>;
|
type ApiResult<T> = Result<T, ApiError>;
|
||||||
|
|
||||||
impl IntoResponse for ApiError {
|
impl IntoResponse for ApiError {
|
||||||
fn into_response(self) -> axum::response::Response {
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
trace!({ error = ?self }, "returned error for request");
|
||||||
(
|
(
|
||||||
match self {
|
match self {
|
||||||
Self::CertificateNotFound => StatusCode::NOT_FOUND,
|
Self::CertificateNotFound => StatusCode::NOT_FOUND,
|
||||||
@@ -214,7 +231,7 @@ async fn list_certs(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "aud", rename = "get")]
|
#[serde(tag = "aud", rename = "get")]
|
||||||
struct AuthClaims {
|
struct AuthClaims {
|
||||||
identifier: String,
|
identifier: String,
|
||||||
@@ -291,13 +308,6 @@ struct CertInfo {
|
|||||||
|
|
||||||
impl From<&Certificate> for CertInfo {
|
impl From<&Certificate> for CertInfo {
|
||||||
fn from(cert: &Certificate) -> Self {
|
fn from(cert: &Certificate) -> Self {
|
||||||
let validity = cert.valid_after_time().duration_since(cert.valid_before_time()).unwrap();
|
|
||||||
let validity_days = validity.as_secs() / ((60*60) * 24);
|
|
||||||
let host_key = if cert.cert_type().is_host() {
|
|
||||||
" -h"
|
|
||||||
} else { "" };
|
|
||||||
let opts = cert.critical_options().iter().map(|(opt, val)| if val.is_empty() { opt.clone() } else { format!("{opt}={val}") }).map(|arg| format!("-O {arg}")).join(" ");
|
|
||||||
let renew_command = format!("ssh-keygen -s ./ca_key {host_key} -I {} -n {} -V {validity_days}d {opts}", cert.key_id(), cert.valid_principals().join(","));
|
|
||||||
CertInfo {
|
CertInfo {
|
||||||
principals: cert.valid_principals().to_vec(),
|
principals: cert.valid_principals().to_vec(),
|
||||||
ca: cert.signature_key().clone().into(),
|
ca: cert.signature_key().clone().into(),
|
||||||
@@ -306,7 +316,7 @@ impl From<&Certificate> for CertInfo {
|
|||||||
identity_hash: cert.public_key().fingerprint(ssh_key::HashAlg::Sha256),
|
identity_hash: cert.public_key().fingerprint(ssh_key::HashAlg::Sha256),
|
||||||
key_id: cert.key_id().to_string(),
|
key_id: cert.key_id().to_string(),
|
||||||
expiry: cert.valid_before_time(),
|
expiry: cert.valid_before_time(),
|
||||||
renew_command
|
renew_command: renew_command(cert, "./ca", None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -336,22 +346,25 @@ struct PostCertsQuery {
|
|||||||
challenge: String,
|
challenge: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<Query<PostCertsQuery>> for JWTString {
|
||||||
|
fn from(Query(PostCertsQuery { challenge }): Query<PostCertsQuery>) -> Self {
|
||||||
|
Self::from(challenge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// POST with signed challenge
|
/// POST with signed challenge
|
||||||
async fn post_certs_identifier(
|
async fn post_certs_identifier(
|
||||||
PostCertInfo { identifier }: PostCertInfo,
|
PostCertInfo { identifier }: PostCertInfo,
|
||||||
State(ApiState { certs, jwt_key, .. }): State<ApiState>,
|
State(ApiState { certs, .. }): State<ApiState>,
|
||||||
|
JWTAuthenticated {
|
||||||
|
data: auth_claims, ..
|
||||||
|
}: JWTAuthenticated<AuthClaims, Query<PostCertsQuery>>,
|
||||||
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
|
Query(PostCertsQuery { challenge }): Query<PostCertsQuery>,
|
||||||
SignatureBody(sig): SignatureBody,
|
SignatureBody(sig): SignatureBody,
|
||||||
) -> ApiResult<String> {
|
) -> ApiResult<String> {
|
||||||
let certs = certs.lock().await;
|
let certs = certs.lock().await;
|
||||||
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
|
let cert = certs.get(&identifier).ok_or(ApiError::InvalidSignature)?;
|
||||||
let token: Token<AuthClaims> = Hs256
|
if auth_claims.identifier != identifier {
|
||||||
.validate_integrity(
|
|
||||||
&UntrustedToken::new(&challenge).context("jwt parse")?,
|
|
||||||
&jwt_key,
|
|
||||||
)
|
|
||||||
.map_err(|_| ApiError::InvalidSignature)?;
|
|
||||||
if token.claims().custom.identifier != identifier {
|
|
||||||
return Err(ApiError::InvalidSignature);
|
return Err(ApiError::InvalidSignature);
|
||||||
}
|
}
|
||||||
let pubkey: PublicKey = cert.public_key().clone().into();
|
let pubkey: PublicKey = cert.public_key().clone().into();
|
||||||
@@ -462,7 +475,8 @@ mod tests {
|
|||||||
user_key,
|
user_key,
|
||||||
unix_time(SystemTime::now()),
|
unix_time(SystemTime::now()),
|
||||||
unix_time(SystemTime::now() + validity),
|
unix_time(SystemTime::now() + validity),
|
||||||
);
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
builder
|
builder
|
||||||
.valid_principal("git")
|
.valid_principal("git")
|
||||||
@@ -509,7 +523,7 @@ mod tests {
|
|||||||
)
|
)
|
||||||
};
|
};
|
||||||
let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_first)).await;
|
let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_first)).await;
|
||||||
assert!(res.is_ok());
|
assert!(dbg!(res).is_ok());
|
||||||
let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_newer)).await;
|
let res = put_cert_update(PutCert, State(state.clone()), CertificateBody(cert_newer)).await;
|
||||||
assert!(res.is_ok());
|
assert!(res.is_ok());
|
||||||
let res = put_cert_update(
|
let res = put_cert_update(
|
||||||
@@ -582,6 +596,9 @@ mod tests {
|
|||||||
identifier: "test_cert".into(),
|
identifier: "test_cert".into(),
|
||||||
},
|
},
|
||||||
State(state.clone()),
|
State(state.clone()),
|
||||||
|
JWTAuthenticated::new(AuthClaims {
|
||||||
|
identifier: "test_cert".into(),
|
||||||
|
}),
|
||||||
Query(PostCertsQuery { challenge }),
|
Query(PostCertsQuery { challenge }),
|
||||||
SignatureBody(sig),
|
SignatureBody(sig),
|
||||||
)
|
)
|
||||||
|
@@ -1,6 +1,16 @@
|
|||||||
use super::ApiError;
|
use std::fmt::Debug;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use super::{ApiError, ApiState};
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::{async_trait, body::BoxBody, extract::FromRequest, http::Request};
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
body::BoxBody,
|
||||||
|
extract::{FromRequest, FromRequestParts},
|
||||||
|
http::Request,
|
||||||
|
};
|
||||||
|
use jwt_compact::{alg::Hs256, AlgorithmExt, Token, UntrustedToken};
|
||||||
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
use ssh_key::{Certificate, SshSig};
|
use ssh_key::{Certificate, SshSig};
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
|
|
||||||
@@ -21,7 +31,8 @@ where
|
|||||||
.context("failed to extract body")?;
|
.context("failed to extract body")?;
|
||||||
|
|
||||||
let cert = Certificate::from_openssh(&body)
|
let cert = Certificate::from_openssh(&body)
|
||||||
.with_context(|| format!("failed to parse '{}'", body))?;
|
.with_context(|| format!("failed to parse '{}'", body))
|
||||||
|
.map_err(ApiError::ParseCertificate)?;
|
||||||
trace!(%body, "extracted certificate");
|
trace!(%body, "extracted certificate");
|
||||||
Ok(Self(cert))
|
Ok(Self(cert))
|
||||||
}
|
}
|
||||||
@@ -42,8 +53,71 @@ where
|
|||||||
.await
|
.await
|
||||||
.context("failed to extract body")?;
|
.context("failed to extract body")?;
|
||||||
|
|
||||||
let sig = SshSig::from_pem(&body).with_context(|| format!("failed to parse '{}'", body))?;
|
let sig = SshSig::from_pem(&body)
|
||||||
|
.with_context(|| format!("failed to parse '{}'", body))
|
||||||
|
.map_err(ApiError::ParseSignature)?;
|
||||||
trace!(%body, "extracted signature");
|
trace!(%body, "extracted signature");
|
||||||
Ok(Self(sig))
|
Ok(Self(sig))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct JWTString(String);
|
||||||
|
|
||||||
|
impl From<String> for JWTString {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: be generic over ApiState -> AsRef<Target=Hs256>, AsRef<Target=A> where A: AlgorithmExt
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct JWTAuthenticated<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
|
||||||
|
> where
|
||||||
|
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
|
||||||
|
{
|
||||||
|
pub data: T,
|
||||||
|
_marker: PhantomData<Q>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
|
||||||
|
> JWTAuthenticated<T, Q>
|
||||||
|
where
|
||||||
|
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
|
||||||
|
{
|
||||||
|
pub fn new(data: T) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
_marker: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<
|
||||||
|
T: Serialize + DeserializeOwned + Clone + Debug,
|
||||||
|
Q: FromRequestParts<ApiState> + Debug + Into<JWTString>,
|
||||||
|
> FromRequestParts<ApiState> for JWTAuthenticated<T, Q>
|
||||||
|
where
|
||||||
|
ApiError: From<<Q as FromRequestParts<ApiState>>::Rejection>,
|
||||||
|
{
|
||||||
|
type Rejection = ApiError;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut axum::http::request::Parts,
|
||||||
|
state: &ApiState,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let JWTString(token) = Q::from_request_parts(parts, state).await?.into();
|
||||||
|
let token = UntrustedToken::new(&token).map_err(ApiError::JWTParse)?;
|
||||||
|
let verified: Token<T> = Hs256
|
||||||
|
.validate_integrity(&token, &state.jwt_key)
|
||||||
|
.map_err(ApiError::JWTVerify)?;
|
||||||
|
Ok(Self {
|
||||||
|
data: verified.claims().custom.clone(),
|
||||||
|
_marker: Default::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user