use base64::{decode, encode}; use hex; use std::error::Error; use std::fmt; use std::hash::{Hash, Hasher}; use std::io; use std::net::{IpAddr, SocketAddr}; use std::time::Instant; use std::time::{Duration, SystemTime, UNIX_EPOCH}; const KEY_SIZE: usize = 48; //TODO: use VEC instead of array #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum ECCKey { PublicKey(Vec), PrivateKey(Vec), } impl fmt::Display for ECCKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.as_base64().unwrap()) } } pub trait HexBackend { fn from_bytes(bytes: Vec) -> Self; fn bytes(&self) -> &Vec; fn from_hex>(key: I) -> io::Result where Self: Sized, { Ok(Self::from_bytes(hex::decode(key.as_ref()).map_err( |_| io::Error::new(io::ErrorKind::InvalidData, "Failed to decode hexstring"), )?)) } fn as_hex(&self) -> io::Result { Ok(hex::encode(&self.bytes())) } } impl HexBackend for T { fn from_bytes(bytes: Vec) -> Self { ::from_bytes(bytes) } fn bytes(&self) -> &Vec { ::bytes(self) } } pub trait Base64Backed { fn from_bytes(bytes: Vec) -> Self; fn bytes(&self) -> &Vec; fn from_base64>(key: I) -> io::Result where Self: Sized, { let key = match decode(key.as_ref()) { Ok(key) => key, _ => { return Err(io::Error::new( io::ErrorKind::InvalidData, "Failed to decode base64", )); } }; /*.map_err(|err| { })?;*/ if key.len() != KEY_SIZE { return Err(io::Error::new( io::ErrorKind::Other, format!( "Mismatched key size. Expected: {}, Got {}", KEY_SIZE, key.len() ), )); } Ok(Self::from_bytes(key)) } fn as_base64(&self) -> io::Result { Ok(encode(self.bytes())) } } impl Base64Backed for ECCKey { fn bytes(&self) -> &Vec { match self { ECCKey::PublicKey(bytes) => &bytes, ECCKey::PrivateKey(bytes) => &bytes, } } fn from_bytes(bytes: Vec) -> ECCKey { ECCKey::PublicKey(bytes) } } impl ECCKey { pub fn public_key(&self) -> Option { //TODO: Determine whether Self is a private key and only the return public part Some(self.clone()) } } #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct SharedKey(Vec); impl fmt::Display for SharedKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.as_base64().unwrap()) } } impl Base64Backed for SharedKey { fn bytes(&self) -> &Vec { &self.0 } fn from_bytes(bytes: Vec) -> SharedKey { SharedKey(bytes) } } #[derive(Debug, Builder, PartialEq, Eq, Clone)] pub struct Interface { pub key: ECCKey, pub port: usize, pub fwmark: Option, } impl Hash for Interface { fn hash(&self, state: &mut H) { self.key.public_key().hash(state); } } #[derive(Debug, Builder, PartialEq, Eq, Clone)] pub struct Peer { pub key: ECCKey, #[builder(default = "None")] pub shared_key: Option, #[builder(default = "None")] pub endpoint: Option, #[builder(default = "Vec::new()")] pub allowed_ips: Vec<(IpAddr, u8)>, #[builder(default = "None")] pub last_handshake: Option, #[builder(default = "None")] pub persistent_keepalive: Option, #[builder(default = "(0u64,0u64)")] pub traffic: (u64, u64), #[builder(default = "Instant::now()")] pub parsed: Instant, } impl Hash for Peer { fn hash(&self, state: &mut H) { self.key.hash(state); } } impl fmt::Display for Peer { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn dis_opt<'a, T: fmt::Display + 'a>(opt: &Option) -> String { opt.as_ref() .map(|s| s.to_string()) .unwrap_or(" ".to_string()) } write!( f, "peer {} {}{}{}", self.key, dis_opt(&self.shared_key), dis_opt(&self.endpoint), self.allowed_ips .iter() .map(|(ip, sub)| format!(" {}/{}", ip, sub)) .collect::>() .join(",") ) } } impl PeerBuilder { fn validate(&self) -> Result<(), String> { if let Some(ref key) = self.key { Ok(()) } else { Err("No key supplied".into()) } } pub fn is_whole(&self) -> bool { self.validate().is_ok() } pub fn has_key(&self) -> bool { self.key.is_some() } pub fn add_allowed_ip(&mut self, ip: (IpAddr, u8)) { if let Some(ref mut ips) = &mut self.allowed_ips { ips.push(ip); } else { self.allowed_ips = Some(vec![ip]); } } pub fn add_last_handshake(&mut self, d: Duration) { if !self.last_handshake.is_some() { self.last_handshake = Some(Some(UNIX_EPOCH + d)); } else { self.last_handshake = self .last_handshake .map(|shake| shake.map(|shake| shake + d)); } } pub fn add_traffic(&mut self, txrx: (u64, u64)) { if let Some(ref mut traffic) = &mut self.traffic { traffic.0 += txrx.0; traffic.1 += txrx.1; } else { self.traffic = Some(txrx); } } } pub trait WireguardController { fn peers<'a>(&'a mut self) -> io::Result> + 'a>>; fn interface(&mut self) -> io::Result; fn update_peer(&mut self, peer: &Peer) -> io::Result<()>; } #[cfg(test)] mod test { use super::*; #[test] fn key_encoding() { let key_encoded = "08df3bebd54217eb769d607f8673e1c3c53bb55d6ac689348a9227c8c4dd8857"; let key = ECCKey::from_hex(key_encoded).unwrap(); assert_eq!(&key.as_hex().unwrap(), key_encoded); } }