diff --git a/wg-event-gen/src/gen.rs b/wg-event-gen/src/gen.rs index 99eceb2..a34821f 100644 --- a/wg-event-gen/src/gen.rs +++ b/wg-event-gen/src/gen.rs @@ -16,6 +16,7 @@ pub(crate) fn gen_events( prev: &HashMap, listeners: &Vec>, timeout: time::Duration, + poll_interval: time::Duration, ) { let side_by_side = { state @@ -29,29 +30,30 @@ pub(crate) fn gen_events( }; for (_id, (prev, cur)) in side_by_side { match (prev, cur) { - (Some(prev), Some(cur)) if prev.endpoint != cur.endpoint => { - if let (Some(prev_addr), Some(_)) = (prev.endpoint, cur.endpoint) { - listeners.roaming(&cur, prev_addr); - } - } (Some(prev), Some(cur)) => { let timedout = |peer: &Peer| match peer.last_handshake_rel() { - Some(shake) if shake < timeout => false, + Some(shake) if shake > timeout && shake + poll_interval < timeout => true, + Some(_) => false, _ => true, }; - - //if _id == "HhRgEL2xsnEIqThSTUKLGaTXusorM1MFdjSSYvzBynY=" { dbg!((cur.last_handshake_rel(),timedout(&prev) , timedout(&cur))); } - + if !timedout(&prev) && timedout(&cur) { listeners.disconnected(&cur); + continue; } if timedout(&prev) && !timedout(&cur) { listeners.connected(&cur); } + + if prev.endpoint != cur.endpoint { + if let (Some(prev_addr), Some(_)) = (prev.endpoint, cur.endpoint) { + listeners.roaming(&cur, prev_addr); + } + } } - (None, Some(cur)) => (), //listeners.added(&cur), - (Some(prev), None) => (), //listeners.removed(&prev), + (None, Some(cur)) => listeners.added(&cur), + (Some(prev), None) => listeners.removed(&prev), (None, Some(_cur)) => (), (Some(_prev), None) => (), fail => { @@ -165,19 +167,38 @@ mod test { let mut cur: HashMap = HashMap::new(); cur.insert(peer_cur.public_key.clone(), peer_cur.clone()); let (listener, calls) = listeners(); - gen_events(&cur, &prev, &listener, time::Duration::from_secs(3)); + let interval = time::Duration::from_secs(3); + gen_events( + &cur, + &prev, + &listener, + time::Duration::from_secs(3), + interval, + ); assert_eq!( vec![["add", &peer_cur.public_key].join(" ")], calls.borrow().clone() ); - gen_events(&cur, &cur, &listener, time::Duration::from_secs(3)); + gen_events( + &cur, + &cur, + &listener, + time::Duration::from_secs(3), + interval, + ); //Shouldn't gen any new events assert!(calls.borrow().len() == 1); let (listener, calls) = listeners(); - gen_events(&prev, &cur, &listener, time::Duration::from_secs(10)); + gen_events( + &prev, + &cur, + &listener, + time::Duration::from_secs(10), + interval, + ); assert_eq!( vec![["rem", &peer.public_key].join(" ")], calls.borrow().clone() @@ -193,13 +214,19 @@ mod test { prev.insert(peer_prev.public_key.clone(), peer_prev.clone()); - gen_events(&prev, &cur, &listener, time::Duration::from_secs(10)); - - assert_eq!( - vec![["rom", &peer.public_key].join(" ")], - calls.borrow().clone() + gen_events( + &prev, + &cur, + &listener, + time::Duration::from_secs(10), + interval, ); + assert!(calls + .borrow() + .clone() + .contains(&["rom", &peer.public_key].join(" "))); + calls.borrow_mut().clear(); let mut peer_prev = peer.clone(); @@ -209,7 +236,13 @@ mod test { cur.insert(peer_cur.public_key.clone(), peer_cur.clone()); prev.insert(peer_prev.public_key.clone(), peer_prev.clone()); - gen_events(&cur, &prev, &listener, time::Duration::from_secs(10)); + gen_events( + &cur, + &prev, + &listener, + time::Duration::from_secs(10), + interval, + ); assert_eq!( vec![["con", &peer.public_key].join(" ")], @@ -219,7 +252,13 @@ mod test { calls.borrow_mut().clear(); //Other way around should be a disconnect - gen_events(&prev, &cur, &listener, time::Duration::from_secs(3)); + gen_events( + &prev, + &cur, + &listener, + time::Duration::from_secs(3), + interval, + ); assert_eq!( vec![["dis", &peer.public_key].join(" ")], diff --git a/wg-event-gen/src/listener.rs b/wg-event-gen/src/listener.rs index ecfbb1f..8a24217 100644 --- a/wg-event-gen/src/listener.rs +++ b/wg-event-gen/src/listener.rs @@ -22,7 +22,9 @@ pub trait EventListener { impl EventListener for Vec> { fn added<'a>(&self, peer: &'a Peer) { - self.iter().for_each(|l| l.added(&peer)); + if cfg!(feature = "addrem") || cfg!(test) { + self.iter().for_each(|l| l.added(&peer)); + } } fn connected<'a>(&self, peer: &'a Peer) { @@ -34,7 +36,9 @@ impl EventListener for Vec> { } fn removed<'a>(&self, peer: &'a Peer) { - self.iter().for_each(|l| l.removed(&peer)); + if cfg!(feature = "addrem") || cfg!(test) { + self.iter().for_each(|l| l.removed(&peer)); + } } fn roaming<'a>(&self, peer: &'a Peer, previous_addr: SocketAddr) { @@ -83,7 +87,11 @@ impl ScriptListener { fn mkcmd<'a>(&self, args: Vec<&'a str>) -> Command { let mut cmd = Command::new("/bin/sh"); cmd.arg("-c"); - cmd.arg(format!("{} {}",(&self.script).to_str().unwrap(), args.join(" "))); + cmd.arg(format!( + "{} {}", + (&self.script).to_str().unwrap(), + args.join(" ") + )); cmd } diff --git a/wg-event-gen/src/main.rs b/wg-event-gen/src/main.rs index 2add0c5..bf15e92 100644 --- a/wg-event-gen/src/main.rs +++ b/wg-event-gen/src/main.rs @@ -14,9 +14,9 @@ use std::io::{BufRead, BufReader, Error, ErrorKind, Result}; use std::net::SocketAddr; use std::os::unix::net::UnixStream; use std::path::PathBuf; -use time; use std::thread; use std::time::Duration; +use time; pub type KV = (String, String); @@ -56,34 +56,37 @@ impl Peer { last_handshake: entries .iter() .filter_map(|(key, value)| { - let value = || value.parse::().unwrap(); - match key.as_ref() { - "last_handshake_time_sec" if value() != 0 => Some(Duration::new(value(), 0)), - "last_handshake_time_nsec" if value() != 0 => Some(Duration::from_nanos(value())), - _ => None - } + let value = || value.parse::().unwrap(); + match key.as_ref() { + "last_handshake_time_sec" if value() != 0 => { + Some(Duration::new(value(), 0)) + } + "last_handshake_time_nsec" if value() != 0 => { + Some(Duration::from_nanos(value())) + } + _ => None, + } }) .fold(None, |acc, add| { - if let Some(dur) = acc { - Some(dur + add) - } else { - Some(add) - } + if let Some(dur) = acc { + Some(dur + add) + } else { + Some(add) + } }), persistent_keepalive: entries .iter() .filter(|(key, _)| key == &"persistent_keepalive") .map(|(_, value)| Duration::from_secs(value.parse::().unwrap())) .next(), - parsed: time::get_time(), + parsed: time::get_time(), }) } - + pub fn last_handshake_rel(&self) -> Option { - let time = self.parsed; - Some(Duration::new(time.sec as u64, time.nsec as u32) - self.last_handshake?) + let time = self.parsed; + Some(Duration::new(time.sec as u64, time.nsec as u32) - self.last_handshake?) } - } impl State { @@ -222,13 +225,14 @@ fn main() { .next() .map(PathBuf::from) .filter(|p| p.as_path().exists()); - let interval = args - .next() - .map(|i| { - i.parse::() - .expect("[interval] has to be a positive int") - }) - .unwrap_or(1000); + let interval = Duration::from_millis( + args.next() + .map(|i| { + i.parse::() + .expect("[interval] has to be a positive int") + }) + .unwrap_or(1000), + ); let mut listeners: Vec> = vec![Box::new(LogListener)]; let events: PathBuf = "/etc/wireguard/events.sh".into(); @@ -237,9 +241,18 @@ fn main() { listeners.push(Box::new(ScriptListener::new(events))) } - let timeout = env::vars().collect::>().get("WG_EVENT_GEN_TIMEOUT").map(|timeout| Duration::from_secs(timeout.parse::().expect(&format!("Can't parse {} as timeout", timeout)))).unwrap_or(Duration::from_secs(30)); - - + let timeout = env::vars() + .collect::>() + .get("WG_EVENT_GEN_TIMEOUT") + .map(|timeout| { + Duration::from_secs( + timeout + .parse::() + .expect(&format!("Can't parse {} as timeout", timeout)), + ) + }) + .unwrap_or(Duration::from_secs(30)); + if let Some(path) = path { let sock = Socket { path }; let mut prev_state: Option> = None; @@ -252,10 +265,10 @@ fn main() { } }; if let Some(prev_state) = prev_state { - gen::gen_events(&state, &prev_state, &listeners, timeout); + gen::gen_events(&state, &prev_state, &listeners, timeout, interval); } prev_state = Some(state); - thread::sleep(Duration::from_millis(interval)); + thread::sleep(interval); } } else { println!(" does not exist");