diff --git a/wg-event-gen/src/gen.rs b/wg-event-gen/src/gen.rs index 218aff4..7bab254 100644 --- a/wg-event-gen/src/gen.rs +++ b/wg-event-gen/src/gen.rs @@ -8,6 +8,7 @@ use std::io::{BufRead, BufReader, Error, ErrorKind, Result}; use std::net::SocketAddr; use std::os::unix::net::UnixStream; use std::path::PathBuf; +use std::rc::Rc; use std::{thread, time}; pub(crate) fn gen_events( @@ -29,8 +30,8 @@ pub(crate) fn gen_events( for (_id, (prev, cur)) in side_by_side { /*if id != "HhRgEL2xsnEIqThSTUKLGaTXusorM1MFdjSSYvzBynY=" { continue; - } - println!("{} p {} c {}", _id, prev.is_some(), cur.is_some());*/ + }*/ + println!("{} p {} c {}", _id, prev.is_some(), cur.is_some()); match (prev, cur) { (Some(prev), Some(cur)) if prev.endpoint != cur.endpoint => { if let (Some(prev_addr), Some(_)) = (prev.endpoint, cur.endpoint) { @@ -47,24 +48,19 @@ pub(crate) fn gen_events( continue; } if let Some(shake) = cur.last_handshake { - if shake > timeout - && prev - .last_handshake - .map(|shake| shake > timeout) - .unwrap_or(true) + if shake > timeout && prev + .last_handshake + .map(|shake| shake > timeout) + .unwrap_or(true) { listeners.connected(&cur); } continue; } } - #[cfg(addrem)] (None, Some(cur)) => listeners.added(&cur), - #[cfg(addrem)] (Some(prev), None) => listeners.removed(&prev), - #[cfg(not(addrem))] (None, Some(_cur)) => (), - #[cfg(not(addrem))] (Some(_prev), None) => (), fail => { println!("{:?}", fail); @@ -74,7 +70,6 @@ pub(crate) fn gen_events( } } - #[cfg(test)] mod test { use super::*; @@ -92,14 +87,14 @@ mod test { use std::{thread, time}; struct TestListener { - calls: RefCell>, + calls: Rc>>, } impl TestListener { fn new() -> TestListener { - Self::from(RefCell::new(vec![])) + Self::from(Rc::new(RefCell::new(vec![]))) } - fn from(calls: RefCell>) -> TestListener { + fn from(calls: Rc>>) -> TestListener { TestListener { calls: calls } } } @@ -135,42 +130,97 @@ mod test { } } - fn listeners() -> (Vec>, RefCell>) { - let calls: RefCell> = RefCell::new(vec![]); - ( - vec![ - Box::new(TestListener::from(calls.clone())), - Box::new(LogListener), - ], - calls.clone(), - ) + fn listeners() -> (Vec>, Rc>>) { + let l = TestListener::new(); + let calls = l.calls.clone(); + (vec![Box::new(l)], calls) + } + + #[test] + fn test_setup() { + let (listeners, calls) = listeners(); + let peer = peer(); + listeners.connected(&peer); + assert_eq!( + vec![["con", &peer.public_key].join(" ")], + calls.borrow().clone() + ); } fn b2h(b: &str) -> String { hex::encode(base64::decode(b).unwrap()) } - #[test] - fn connected() { + fn peer() -> Peer { let bkey = "HhRgEL2xsnEIqThSTUKLGaTXusorM1MFdjSSYvzBynY="; let key = b2h(bkey); - let prev: HashMap = HashMap::new(); + Peer::from_kv(&vec![ + ("public_key".to_string(), key.clone()), + ( + "last_handshake_time_nsec".to_string(), + (1000 * 1000 * 1).to_string(), + ), + ("endpoint".to_string(), "1.1.1.1:22222".to_string()), + ]).unwrap() + } + + #[test] + fn connected() { + let peer = peer(); + let mut peer_cur = peer.clone(); + let mut prev: HashMap = HashMap::new(); let mut cur: HashMap = HashMap::new(); - cur.insert( - key.clone(), - Peer::from_kv(&vec![ - ("public_key".to_string(), key.clone()), - ( - "last_handshake_time_nsec".to_string(), - (1000 * 1000 * 1).to_string(), - ), - ("endpoint".to_string(), "1.1.1.1:22222".to_string()), - ]) - .unwrap(), + cur.insert(peer_cur.public_key.clone(), peer_cur.clone()); + let (listener, calls) = listeners(); + gen_events(&cur, &prev, &listener, time::Duration::from_secs(3)); + assert_eq!( + vec![["add", &peer_cur.public_key].join(" ")], + calls.borrow().clone() + ); + + gen_events(&cur, &cur, &listener, time::Duration::from_secs(3)); + + //Shouldn't gen any new events + assert!(calls.borrow().len() == 1); + + let (listener, calls) = listeners(); + gen_events(&prev, &cur, &listener, time::Duration::from_secs(3)); + assert_eq!( + vec![["rem", &peer.public_key].join(" ")], + calls.borrow().clone() + ); + + calls.borrow_mut().clear(); + + let mut peer_prev = peer.clone(); + + peer_prev.endpoint = Some("2.2.2.2:33333".parse::().unwrap()); + + peer_prev.last_handshake = Some(time::Duration::from_secs(1000)); + + prev.insert(peer_prev.public_key.clone(), peer_prev.clone()); + + gen_events(&prev, &cur, &listener, time::Duration::from_secs(3)); + + assert_eq!( + vec![["rom", &peer.public_key].join(" ")], + calls.borrow().clone() + ); + + calls.borrow_mut().clear(); + + let mut peer_prev = peer.clone(); + + peer_cur.last_handshake = Some(time::Duration::from_secs(1)); + + cur.insert(peer_cur.public_key.clone(), peer_cur); + + gen_events(&cur, &prev, &listener, time::Duration::from_secs(3)); + + assert_eq!( + vec![["rom", &peer.public_key].join(" ")], + calls.borrow().clone() ); - let (listeners, calls) = listeners(); - gen_events(&cur, &prev, &listeners, time::Duration::from_secs(3)); - assert_eq!(vec![["con", bkey].join(" ")], calls.borrow().clone()); } } diff --git a/wg-event-gen/src/main.rs b/wg-event-gen/src/main.rs index 2083709..3f78ebe 100644 --- a/wg-event-gen/src/main.rs +++ b/wg-event-gen/src/main.rs @@ -24,7 +24,7 @@ enum State { Peer(Vec), } -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct Peer { public_key: String, endpoint: Option, @@ -188,8 +188,7 @@ impl Socket { Peer::from_kv(state.kv()) .ok() .map(|peer| (id.to_owned(), peer)) - }) - .collect()) + }).collect()) } } @@ -205,8 +204,7 @@ fn main() { .map(|i| { i.parse::() .expect("[interval] has to be a positive int") - }) - .unwrap_or(1000); + }).unwrap_or(1000); let mut listeners: Vec> = vec![Box::new(LogListener)]; let events: PathBuf = "/etc/wireguard/events.sh".into();