diff --git a/wg-event-gen/src/gen.rs b/wg-event-gen/src/gen.rs new file mode 100644 index 0000000..da43ff2 --- /dev/null +++ b/wg-event-gen/src/gen.rs @@ -0,0 +1,75 @@ +use crate::listener::*; +use crate::*; +use std::collections::{HashMap, HashSet}; +use std::env; +use std::fmt; +use std::io::prelude::*; +use std::io::{BufRead, BufReader, Error, ErrorKind, Result}; +use std::net::SocketAddr; +use std::os::unix::net::UnixStream; +use std::path::PathBuf; +use std::{thread, time}; + +pub(crate) fn gen_events( + state: &HashMap, + prev: &HashMap, + listeners: &Vec>, + timeout: time::Duration, +) { + let side_by_side = { + state + .keys() + .map(String::as_ref) + .chain(prev.keys().map(String::as_ref)) + .collect::>() + .iter() + .map(|p| (p.to_owned(), (prev.get(*p), state.get(*p)))) + .collect::, Option<&Peer>)>>() + }; + for (_id, (prev, cur)) in side_by_side { + /*if id != "HhRgEL2xsnEIqThSTUKLGaTXusorM1MFdjSSYvzBynY=" { + continue; + } + println!("{} p {} c {}", _id, prev.is_some(), cur.is_some());*/ + match (prev, cur) { + (Some(prev), Some(cur)) if prev.endpoint != cur.endpoint => { + if let (Some(prev_addr), Some(_)) = (prev.endpoint, cur.endpoint) { + listeners.roaming(&cur, prev_addr); + } + } + (Some(prev), Some(cur)) => { + //shake > timeout && prev.shake < timeout -> listeners.iter().for_each(|l| l.disconnected(&cur)); + //shake < timeout && (prev.shake is none || prev.shake > timeout) -> listeners.iter().for_each(|l| l.connected(&cur)); + if let (Some(shake), Some(pshake)) = (cur.last_handshake, prev.last_handshake) { + if shake > timeout && pshake < timeout { + listeners.disconnected(&cur); + } + continue; + } + if let Some(shake) = cur.last_handshake { + if shake > timeout + && prev + .last_handshake + .map(|shake| shake > timeout) + .unwrap_or(true) + { + listeners.connected(&cur); + } + continue; + } + } + #[cfg(addrem)] + (None, Some(cur)) => listeners.added(&cur), + #[cfg(addrem)] + (Some(prev), None) => listeners.removed(&prev), + #[cfg(not(addrem))] + (None, Some(_cur)) => (), + #[cfg(not(addrem))] + (Some(_prev), None) => (), + fail => { + println!("{:?}", fail); + unreachable!() + } + } + } +} diff --git a/wg-event-gen/src/listener.rs b/wg-event-gen/src/listener.rs new file mode 100644 index 0000000..70dda7d --- /dev/null +++ b/wg-event-gen/src/listener.rs @@ -0,0 +1,123 @@ +use crate::Peer; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::process::Command; +use std::thread; + +pub trait EventListener { + fn added<'a>(&self, peer: &'a Peer) { + self.connected(peer); + } + + fn connected<'a>(&self, peer: &'a Peer); + + fn disconnected<'a>(&self, peer: &'a Peer); + + fn removed<'a>(&self, peer: &'a Peer) { + self.disconnected(peer) + } + + fn roaming<'a>(&self, peer: &'a Peer, previous_addr: SocketAddr); +} + +impl EventListener for Vec> { + fn added<'a>(&self, peer: &'a Peer) { + self.iter().for_each(|l| l.added(&peer)); + } + + fn connected<'a>(&self, peer: &'a Peer) { + self.iter().for_each(|l| l.connected(&peer)); + } + + fn disconnected<'a>(&self, peer: &'a Peer) { + self.iter().for_each(|l| l.disconnected(&peer)); + } + + fn removed<'a>(&self, peer: &'a Peer) { + self.iter().for_each(|l| l.removed(&peer)); + } + + fn roaming<'a>(&self, peer: &'a Peer, previous_addr: SocketAddr) { + self.iter().for_each(|l| l.roaming(&peer, previous_addr)); + } +} + +pub struct LogListener; + +impl EventListener for LogListener { + fn connected<'a>(&self, peer: &'a Peer) { + println!("{} connected!", peer.public_key); + } + + fn disconnected<'a>(&self, peer: &'a Peer) { + println!("{} disconnected!", peer.public_key); + } + + fn added<'a>(&self, peer: &'a Peer) { + println!("{} added!", peer.public_key); + } + + fn removed<'a>(&self, peer: &'a Peer) { + println!("{} removed!", peer.public_key); + } + + fn roaming<'a>(&self, peer: &'a Peer, previous_addr: SocketAddr) { + println!( + "{} roamed {} -> {}!", + peer.public_key, + previous_addr, + peer.endpoint.unwrap() + ); + } +} + +pub struct ScriptListener { + pub script: PathBuf, +} + +impl ScriptListener { + pub fn new(script: PathBuf) -> ScriptListener { + ScriptListener { script } + } + + fn mkcmd<'a>(&self, args: Vec<&'a str>) -> Command { + let mut cmd = Command::new("/bin/sh"); + cmd.arg("-c"); + cmd.arg(format!("\"{}\"", args.join(" "))); + cmd + } + + fn call_sub<'a>(&self, args: Vec<&'a str>) { + let mut cmd = self.mkcmd(args); + thread::spawn(move || { + cmd.spawn().expect("Failed to call Script hooḱ!"); + }); + } +} + +impl EventListener for ScriptListener { + fn connected<'a>(&self, peer: &'a Peer) { + self.call_sub(vec!["connected", &peer.public_key]); + } + + fn disconnected<'a>(&self, peer: &'a Peer) { + self.call_sub(vec!["disconnected", &peer.public_key]); + } + + fn added<'a>(&self, peer: &'a Peer) { + self.call_sub(vec!["added", &peer.public_key]); + } + + fn removed<'a>(&self, peer: &'a Peer) { + self.call_sub(vec!["removed", &peer.public_key]); + } + + fn roaming<'a>(&self, peer: &'a Peer, previous_addr: SocketAddr) { + self.call_sub(vec![ + "roaming", + &peer.public_key, + &previous_addr.to_string(), + &peer.endpoint.unwrap().to_string(), + ]); + } +} diff --git a/wg-event-gen/src/main.rs b/wg-event-gen/src/main.rs index e09ea50..45d2d0f 100644 --- a/wg-event-gen/src/main.rs +++ b/wg-event-gen/src/main.rs @@ -1,5 +1,12 @@ +mod gen; +mod listener; + +use crate::gen::*; +use crate::listener::*; + use std::collections::HashMap; use std::env; +use std::fmt; use std::io::prelude::*; use std::io::{BufRead, BufReader, Error, ErrorKind, Result}; use std::net::SocketAddr; @@ -15,8 +22,48 @@ enum State { Peer(Vec), } +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct Peer { + public_key: String, + endpoint: Option, + last_handshake: Option, + persistent_keepalive: Option, +} + +impl Peer { + fn from_kv(entries: &Vec) -> Result { + let key = match entries + .iter() + .filter(|(key, _)| key == &"public_key") + .map(|(_, value)| value) + .next() + { + Some(key) => key, + None => return Err(Error::new(ErrorKind::Other, "Peer is missing key")), + }; + Ok(Peer { + public_key: key.to_string(), + endpoint: entries + .iter() + .filter(|(key, _)| key == &"endpoint") + .map(|(_, value)| value.parse::().unwrap()) + .next(), + last_handshake: entries + .iter() + .filter(|(key, _)| key == &"last_handshake_time_nsec") + .map(|(_, value)| time::Duration::from_millis(value.parse::().unwrap())) + .next(), + persistent_keepalive: entries + .iter() + .filter(|(key, _)| key == &"persistent_keepalive") + .map(|(_, value)| time::Duration::from_secs(value.parse::().unwrap())) + .next(), + }) + } +} + impl State { - fn kv(&self) -> &Vec { + pub fn kv(&self) -> &Vec { match self { State::Interface(kv) => kv, State::Peer(kv) => kv, @@ -69,6 +116,22 @@ impl State { } } +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for (k, v) in self.kv() { + write!(f, "({:10}= {})", k, v)?; + } + Ok(()) + } +} + +impl fmt::Display for Peer { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self) + // write!(f, "peer {}\nshake {} ago\naddr {}\nkeepalive {}\n", self.public_key, self.last_handshake.map(|d|d.to_string()).unwrap_or("-"), self.endpoint.map(|d|d.to_string()).unwrap_or("-"), self.persistent_keepalive.map(|d|d.to_string()).unwrap_or("-")) + } +} + struct Socket { pub path: PathBuf, } @@ -114,42 +177,17 @@ impl Socket { } Ok(ided) } -} -trait EventListener { - fn added<'a>(&self, peer: &'a State) { - self.connected(peer); - } - - fn connected<'a>(&self, peer: &'a State); - - fn disconnected<'a>(&self, peer: &'a State); - - fn removed<'a>(&self, peer: &'a State) { - self.disconnected(peer) - } - - fn roaming<'a>(&self, peer: &'a State, previous_addr: SocketAddr); -} - -struct LogListener; - -impl EventListener for LogListener { - fn connected<'a>(&self, peer: &'a State) { - println!("{} connected!", peer.id().unwrap()); - } - - fn disconnected<'a>(&self, peer: &'a State) { - println!("{} disconnected!", peer.id().unwrap()); - } - - fn roaming<'a>(&self, peer: &'a State, previous_addr: SocketAddr) { - println!( - "{} roamed {} -> {}!", - peer.id().unwrap(), - previous_addr, - peer.addr().unwrap() - ); + pub fn get_peers(&self) -> Result> { + let by_id = self.get_by_id()?; + Ok(by_id + .iter() + .filter_map(|(id, state)| { + Peer::from_kv(state.kv()) + .ok() + .map(|peer| (id.to_owned(), peer)) + }) + .collect()) } } @@ -167,15 +205,15 @@ fn main() { .expect("[interval] has to be a positive int") }) .unwrap_or(1000); - let listeners = vec![LogListener]; + let listeners: Vec> = vec![Box::new(LogListener)]; - let timeout: u64 = 3 * 1000; + let timeout = time::Duration::from_secs(3); if let Some(path) = path { let sock = Socket { path }; - let mut prev_state: Option> = None; + let mut prev_state: Option> = None; loop { - let state = match sock.get_by_id() { + let state = match sock.get_peers() { Ok(state) => state, Err(err) => { eprintln!("Failed to read from socket: {}", err); @@ -183,28 +221,8 @@ fn main() { } }; if let Some(prev_state) = prev_state { - for (peer, state) in state.iter() { - if let Some(p_state) = prev_state.get(peer) { - if let (Some(addr), Some(p_addr)) = (state.addr(), p_state.addr()) { - if addr != p_addr { - listeners.iter().for_each(|l| l.roaming(state, p_addr)); - } - } - } else { - listeners.iter().for_each(|l| l.connected(state)); - } - if let Some(shake) = state.last_handshake() { - if (shake / 1000) > timeout && shake / 1000 < timeout + interval { - listeners.iter().for_each(|l| l.disconnected(state)); - } - } - } - prev_state - .iter() - .filter(|(k, _)| !state.contains_key(k.clone())) - .for_each(|(_, state)| listeners.iter().for_each(|l| l.disconnected(state))); + gen::gen_events(&state, &prev_state, &listeners, timeout); } - state.keys().for_each(|k| print!("{}, ", k)); println!(""); prev_state = Some(state); thread::sleep(time::Duration::from_millis(interval)); }