mod gen; mod listener; use crate::gen::*; use crate::listener::*; use base64; use hex; use std::collections::HashMap; use std::env; use std::fmt; use std::io::prelude::*; use std::io::{BufRead, BufReader, Error, ErrorKind, Result}; use std::net::{IpAddr, SocketAddr}; use std::os::unix::net::UnixStream; use std::path::PathBuf; use std::thread; use std::time::Duration; use time; pub type KV = (String, String); #[derive(Debug, PartialEq, Eq, Hash)] enum State { Interface(Vec), Peer(Vec), } #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct Peer { public_key: String, endpoint: Option, allowed_ips: Vec<(IpAddr, u8)>, last_handshake: Option, persistent_keepalive: Option, traffic: (u64, u64), parsed: time::Timespec, } impl Peer { fn from_kv(entries: &Vec) -> Result { let key = match entries .iter() .filter(|(key, _)| key == &"public_key") .map(|(_, value)| value) .next() { Some(key) => key, None => return Err(Error::new(ErrorKind::Other, "Peer is missing key")), }; Ok(Peer { public_key: base64::encode(&hex::decode(key).unwrap()), endpoint: entries .iter() .filter(|(key, _)| key == &"endpoint") .map(|(_, value)| value.parse::().unwrap()) .next(), allowed_ips: entries .iter() .filter(|(key, _)| key == &"allowed_ip") .map(|(_, value)| { let mut parts = value.split("/").into_iter(); match ( parts.next().and_then(|addr| addr.parse::().ok()), parts.next().and_then(|mask| mask.parse::().ok()), ) { (Some(addr), Some(mask)) => Some((addr, mask)), (Some(addr), None) if addr.is_ipv6() => Some((addr, 128)), (Some(addr), None) => Some((addr, 32)), _ => None, } }) .filter_map(|net| net) .collect::>(), last_handshake: entries .iter() .filter_map(|(key, value)| { let value = || value.parse::().unwrap(); match key.as_ref() { "last_handshake_time_sec" if value() != 0 => { Some(Duration::new(value(), 0)) } "last_handshake_time_nsec" if value() != 0 => { Some(Duration::from_nanos(value())) } _ => None, } }) .fold(None, |acc, add| { if let Some(dur) = acc { Some(dur + add) } else { Some(add) } }), persistent_keepalive: entries .iter() .filter(|(key, _)| key == &"persistent_keepalive") .map(|(_, value)| Duration::from_secs(value.parse::().unwrap())) .next(), traffic: (0, 0), parsed: time::get_time(), }) } pub fn last_handshake_rel(&self) -> Option { let time = self.parsed; Some(Duration::new(time.sec as u64, time.nsec as u32) - self.last_handshake?) } } impl State { pub fn kv(&self) -> &Vec { match self { State::Interface(kv) => kv, State::Peer(kv) => kv, } } fn kv_mut(&mut self) -> &mut Vec { match self { State::Interface(kv) => kv, State::Peer(kv) => kv, } } pub fn id<'a>(&'a self) -> Option { self.kv() .iter() .filter(|(key, _)| key == &"private_key" || key == &"public_key") .map(|(_, value)| base64::encode(&hex::decode(&value).unwrap())) .next() } pub fn addr(&self) -> Option { self.kv() .iter() .filter(|(key, _)| key == &"endpoint") .map(|(_, value)| value.parse::().unwrap()) .next() } pub fn last_handshake(&self) -> Option { self.kv() .iter() .filter(|(key, _)| key == &"last_handshake_time_nsec") .map(|(_, value)| value.parse::().unwrap()) .next() } pub fn push(&mut self, key: String, value: String) { self.kv_mut().push((key, value)); } pub fn delta(&self, other: Self) -> Vec { let kv = self.kv(); other .kv() .iter() .filter(|pair| !kv.contains(pair)) .map(|p| p.clone()) .collect::>() } } impl fmt::Display for State { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for (k, v) in self.kv() { write!(f, "({:10}= {})", k, v)?; } Ok(()) } } impl fmt::Display for Peer { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}", self) // write!(f, "peer {}\nshake {} ago\naddr {}\nkeepalive {}\n", self.public_key, self.last_handshake.map(|d|d.to_string()).unwrap_or("-"), self.endpoint.map(|d|d.to_string()).unwrap_or("-"), self.persistent_keepalive.map(|d|d.to_string()).unwrap_or("-")) } } struct Socket { pub path: PathBuf, } impl Socket { pub fn get(&self) -> Result> { let mut stream = UnixStream::connect(&self.path)?; stream.write_all(b"get=1\n")?; let mut state: Vec = vec![]; let mut cur = State::Interface(Vec::with_capacity(0)); for line in BufReader::new(stream).lines() { let line = line?; let mut iter = line.chars(); let key = iter.by_ref().take_while(|c| c != &'=').collect::(); let value = iter.collect::(); match key.as_ref() { "errno" if value != "0" => Err(Error::new( ErrorKind::Other, format!("Socket said error: {}", value), ))?, "public_key" | "private_key" => { state.push(cur); cur = if key == "private_key" { State::Interface(Vec::with_capacity(3)) } else { State::Peer(Vec::with_capacity(5)) }; cur.push(key, value); } _ => cur.push(key, value), } } Ok(state) } pub fn get_by_id(&self) -> Result> { let state = self.get()?; let mut ided = HashMap::new(); for s in state { if let Some(id) = s.id() { ided.insert(id.clone(), s); } } Ok(ided) } pub fn get_peers(&self) -> Result> { let by_id = self.get_by_id()?; Ok(by_id .iter() .filter_map(|(id, state)| { Peer::from_kv(state.kv()) .ok() .map(|peer| (id.to_owned(), peer)) }) .collect()) } } fn main() { let mut args = env::args(); args.next(); //Ignore program name let path = args .next() .map(PathBuf::from) .filter(|p| p.as_path().exists()); let interval = Duration::from_millis( args.next() .map(|i| { i.parse::() .expect("[interval] has to be a positive int") }) .unwrap_or(1000), ); let mut listeners: Vec> = vec![Box::new(LogListener)]; let events: PathBuf = "/etc/wireguard/events.sh".into(); if events.exists() { listeners.push(Box::new(ScriptListener::new(events))) } let timeout = env::vars() .collect::>() .get("WG_EVENT_GEN_TIMEOUT") .map(|timeout| { Duration::from_secs( timeout .parse::() .expect(&format!("Can't parse {} as timeout", timeout)), ) }) .unwrap_or(Duration::from_secs(30)); if let Some(path) = path { let sock = Socket { path }; let mut prev_state: Option> = None; loop { let state = match sock.get_peers() { Ok(state) => state, Err(err) => { eprintln!("Failed to read from socket: {}", err); continue; } }; if let Some(prev_state) = prev_state { gen::gen_events(&state, &prev_state, &listeners, timeout, interval); } prev_state = Some(state); thread::sleep(interval); } } else { println!(" does not exist"); } }