// File viewer metadata: 248 lines, 6.2 KiB, Rust
use base64::{decode, encode};
|
|
use hex;
|
|
use std::error::Error;
|
|
use std::fmt;
|
|
use std::hash::{Hash, Hasher};
|
|
use std::io;
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::time::Instant;
|
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|
|
|
/// Number of raw bytes a key must contain after base64 decoding;
/// `Base64Backed::from_base64` rejects any other length.
// NOTE(review): the hex test vector at the bottom of this file decodes to 32 bytes,
// and `from_hex` performs no length check — confirm that 48 is the intended size.
const KEY_SIZE: usize = 48; // TODO: use a Vec instead of an array
|
|
|
|
/// Raw key bytes, tagged as either a public or a private key.
///
/// Equality and hashing take the variant into account, so a `PublicKey` and a
/// `PrivateKey` wrapping the same bytes are different values.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ECCKey {
    /// Bytes of a public key.
    PublicKey(Vec<u8>),
    /// Bytes of a private key.
    PrivateKey(Vec<u8>),
}
|
|
|
|
impl fmt::Display for ECCKey {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
write!(f, "{}", self.as_base64().unwrap())
|
|
}
|
|
}
|
|
|
|
pub trait HexBackend {
|
|
fn from_bytes(bytes: Vec<u8>) -> Self;
|
|
fn bytes(&self) -> &Vec<u8>;
|
|
fn from_hex<I: AsRef<str>>(key: I) -> io::Result<Self>
|
|
where
|
|
Self: Sized,
|
|
{
|
|
Ok(Self::from_bytes(hex::decode(key.as_ref()).map_err(
|
|
|_| io::Error::new(io::ErrorKind::InvalidData, "Failed to decode hexstring"),
|
|
)?))
|
|
}
|
|
fn as_hex(&self) -> io::Result<String> {
|
|
Ok(hex::encode(&self.bytes()))
|
|
}
|
|
}
|
|
|
|
/// Gives every `Base64Backed` type a `HexBackend` implementation by forwarding
/// the two required methods to the `Base64Backed` ones.
impl<T: Base64Backed> HexBackend for T {
    fn from_bytes(bytes: Vec<u8>) -> Self {
        // Both traits declare `from_bytes`, so the target trait must be named explicitly.
        <T as Base64Backed>::from_bytes(bytes)
    }

    fn bytes(&self) -> &Vec<u8> {
        // Same ambiguity as above: pick the `Base64Backed` accessor explicitly.
        <T as Base64Backed>::bytes(self)
    }
}
|
|
|
|
pub trait Base64Backed {
|
|
fn from_bytes(bytes: Vec<u8>) -> Self;
|
|
fn bytes(&self) -> &Vec<u8>;
|
|
fn from_base64<I: AsRef<str>>(key: I) -> io::Result<Self>
|
|
where
|
|
Self: Sized,
|
|
{
|
|
let key = match decode(key.as_ref()) {
|
|
Ok(key) => key,
|
|
_ => {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidData,
|
|
"Failed to decode base64",
|
|
));
|
|
}
|
|
}; /*.map_err(|err| {
|
|
|
|
})?;*/
|
|
if key.len() != KEY_SIZE {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::Other,
|
|
format!(
|
|
"Mismatched key size. Expected: {}, Got {}",
|
|
KEY_SIZE,
|
|
key.len()
|
|
),
|
|
));
|
|
}
|
|
Ok(Self::from_bytes(key))
|
|
}
|
|
|
|
fn as_base64(&self) -> io::Result<String> {
|
|
Ok(encode(self.bytes()))
|
|
}
|
|
}
|
|
|
|
impl Base64Backed for ECCKey {
|
|
fn bytes(&self) -> &Vec<u8> {
|
|
match self {
|
|
ECCKey::PublicKey(bytes) => &bytes,
|
|
ECCKey::PrivateKey(bytes) => &bytes,
|
|
}
|
|
}
|
|
|
|
fn from_bytes(bytes: Vec<u8>) -> ECCKey {
|
|
ECCKey::PublicKey(bytes)
|
|
}
|
|
}
|
|
|
|
impl ECCKey {
    /// Returns the public key associated with `self`.
    ///
    /// At present this is just a copy of `self`, whatever the variant, and the
    /// result is always `Some`.
    pub fn public_key(&self) -> Option<ECCKey> {
        // TODO: Determine whether Self is a private key and only then return the public part
        let copy = self.clone();
        Some(copy)
    }
}
|
|
|
|
/// Raw bytes of a shared key.
///
/// The wrapped vector is private to this module.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SharedKey(Vec<u8>);
|
|
|
|
impl fmt::Display for SharedKey {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
write!(f, "{}", self.as_base64().unwrap())
|
|
}
|
|
}
|
|
|
|
impl Base64Backed for SharedKey {
|
|
fn bytes(&self) -> &Vec<u8> {
|
|
&self.0
|
|
}
|
|
fn from_bytes(bytes: Vec<u8>) -> SharedKey {
|
|
SharedKey(bytes)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Builder, PartialEq, Eq, Clone)]
|
|
pub struct Interface {
|
|
pub key: ECCKey,
|
|
pub port: usize,
|
|
pub fwmark: Option<String>,
|
|
}
|
|
|
|
impl Hash for Interface {
|
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
|
self.key.public_key().hash(state);
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Builder, PartialEq, Eq, Clone)]
|
|
pub struct Peer {
|
|
pub key: ECCKey,
|
|
#[builder(default = "None")]
|
|
pub shared_key: Option<SharedKey>,
|
|
#[builder(default = "None")]
|
|
pub endpoint: Option<SocketAddr>,
|
|
#[builder(default = "Vec::new()")]
|
|
pub allowed_ips: Vec<(IpAddr, u8)>,
|
|
#[builder(default = "None")]
|
|
pub last_handshake: Option<SystemTime>,
|
|
#[builder(default = "None")]
|
|
pub persistent_keepalive: Option<Duration>,
|
|
#[builder(default = "(0u64,0u64)")]
|
|
pub traffic: (u64, u64),
|
|
#[builder(default = "Instant::now()")]
|
|
pub parsed: Instant,
|
|
}
|
|
|
|
impl Hash for Peer {
|
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
|
self.key.hash(state);
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for Peer {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
fn dis_opt<'a, T: fmt::Display + 'a>(opt: &Option<T>) -> String {
|
|
opt.as_ref()
|
|
.map(|s| s.to_string())
|
|
.unwrap_or(" ".to_string())
|
|
}
|
|
write!(
|
|
f,
|
|
"peer {} {}{}{}",
|
|
self.key,
|
|
dis_opt(&self.shared_key),
|
|
dis_opt(&self.endpoint),
|
|
self.allowed_ips
|
|
.iter()
|
|
.map(|(ip, sub)| format!(" {}/{}", ip, sub))
|
|
.collect::<Vec<_>>()
|
|
.join(",")
|
|
)
|
|
}
|
|
}
|
|
|
|
impl PeerBuilder {
|
|
fn validate(&self) -> Result<(), String> {
|
|
if let Some(ref key) = self.key {
|
|
Ok(())
|
|
} else {
|
|
Err("No key supplied".into())
|
|
}
|
|
}
|
|
|
|
pub fn is_whole(&self) -> bool {
|
|
self.validate().is_ok()
|
|
}
|
|
|
|
pub fn has_key(&self) -> bool {
|
|
self.key.is_some()
|
|
}
|
|
|
|
pub fn add_allowed_ip(&mut self, ip: (IpAddr, u8)) {
|
|
if let Some(ref mut ips) = &mut self.allowed_ips {
|
|
ips.push(ip);
|
|
} else {
|
|
self.allowed_ips = Some(vec![ip]);
|
|
}
|
|
}
|
|
|
|
pub fn add_last_handshake(&mut self, d: Duration) {
|
|
if !self.last_handshake.is_some() {
|
|
self.last_handshake = Some(Some(UNIX_EPOCH + d));
|
|
} else {
|
|
self.last_handshake = self
|
|
.last_handshake
|
|
.map(|shake| shake.map(|shake| shake + d));
|
|
}
|
|
}
|
|
|
|
pub fn add_traffic(&mut self, txrx: (u64, u64)) {
|
|
if let Some(ref mut traffic) = &mut self.traffic {
|
|
traffic.0 += txrx.0;
|
|
traffic.1 += txrx.1;
|
|
} else {
|
|
self.traffic = Some(txrx);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub trait WireguardController {
|
|
fn peers<'a>(&'a mut self) -> io::Result<Box<Iterator<Item = io::Result<Peer>> + 'a>>;
|
|
|
|
fn interface(&mut self) -> io::Result<Interface>;
|
|
|
|
fn update_peer(&mut self, peer: &Peer) -> io::Result<()>;
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;

    /// A hex string survives a decode/encode round trip unchanged.
    #[test]
    fn key_encoding() {
        let hex_input = "08df3bebd54217eb769d607f8673e1c3c53bb55d6ac689348a9227c8c4dd8857";
        let decoded = ECCKey::from_hex(hex_input).unwrap();
        let round_tripped = decoded.as_hex().unwrap();
        assert_eq!(round_tripped, hex_input);
    }
}
|