more reliable events

This commit is contained in:
shimunn 2019-01-19 15:29:08 +01:00
parent 215cb7ec8f
commit 32a86b45ab
3 changed files with 113 additions and 53 deletions

View File

@ -16,6 +16,7 @@ pub(crate) fn gen_events(
prev: &HashMap<String, Peer>, prev: &HashMap<String, Peer>,
listeners: &Vec<Box<EventListener>>, listeners: &Vec<Box<EventListener>>,
timeout: time::Duration, timeout: time::Duration,
poll_interval: time::Duration,
) { ) {
let side_by_side = { let side_by_side = {
state state
@ -29,29 +30,30 @@ pub(crate) fn gen_events(
}; };
for (_id, (prev, cur)) in side_by_side { for (_id, (prev, cur)) in side_by_side {
match (prev, cur) { match (prev, cur) {
(Some(prev), Some(cur)) if prev.endpoint != cur.endpoint => {
if let (Some(prev_addr), Some(_)) = (prev.endpoint, cur.endpoint) {
listeners.roaming(&cur, prev_addr);
}
}
(Some(prev), Some(cur)) => { (Some(prev), Some(cur)) => {
let timedout = |peer: &Peer| match peer.last_handshake_rel() { let timedout = |peer: &Peer| match peer.last_handshake_rel() {
Some(shake) if shake < timeout => false, Some(shake) if shake > timeout && shake + poll_interval < timeout => true,
Some(_) => false,
_ => true, _ => true,
}; };
//if _id == "HhRgEL2xsnEIqThSTUKLGaTXusorM1MFdjSSYvzBynY=" { dbg!((cur.last_handshake_rel(),timedout(&prev) , timedout(&cur))); }
if !timedout(&prev) && timedout(&cur) { if !timedout(&prev) && timedout(&cur) {
listeners.disconnected(&cur); listeners.disconnected(&cur);
continue;
} }
if timedout(&prev) && !timedout(&cur) { if timedout(&prev) && !timedout(&cur) {
listeners.connected(&cur); listeners.connected(&cur);
} }
if prev.endpoint != cur.endpoint {
if let (Some(prev_addr), Some(_)) = (prev.endpoint, cur.endpoint) {
listeners.roaming(&cur, prev_addr);
}
}
} }
(None, Some(cur)) => (), //listeners.added(&cur), (None, Some(cur)) => listeners.added(&cur),
(Some(prev), None) => (), //listeners.removed(&prev), (Some(prev), None) => listeners.removed(&prev),
(None, Some(_cur)) => (), (None, Some(_cur)) => (),
(Some(_prev), None) => (), (Some(_prev), None) => (),
fail => { fail => {
@ -165,19 +167,38 @@ mod test {
let mut cur: HashMap<String, Peer> = HashMap::new(); let mut cur: HashMap<String, Peer> = HashMap::new();
cur.insert(peer_cur.public_key.clone(), peer_cur.clone()); cur.insert(peer_cur.public_key.clone(), peer_cur.clone());
let (listener, calls) = listeners(); let (listener, calls) = listeners();
gen_events(&cur, &prev, &listener, time::Duration::from_secs(3)); let interval = time::Duration::from_secs(3);
gen_events(
&cur,
&prev,
&listener,
time::Duration::from_secs(3),
interval,
);
assert_eq!( assert_eq!(
vec![["add", &peer_cur.public_key].join(" ")], vec![["add", &peer_cur.public_key].join(" ")],
calls.borrow().clone() calls.borrow().clone()
); );
gen_events(&cur, &cur, &listener, time::Duration::from_secs(3)); gen_events(
&cur,
&cur,
&listener,
time::Duration::from_secs(3),
interval,
);
//Shouldn't gen any new events //Shouldn't gen any new events
assert!(calls.borrow().len() == 1); assert!(calls.borrow().len() == 1);
let (listener, calls) = listeners(); let (listener, calls) = listeners();
gen_events(&prev, &cur, &listener, time::Duration::from_secs(10)); gen_events(
&prev,
&cur,
&listener,
time::Duration::from_secs(10),
interval,
);
assert_eq!( assert_eq!(
vec![["rem", &peer.public_key].join(" ")], vec![["rem", &peer.public_key].join(" ")],
calls.borrow().clone() calls.borrow().clone()
@ -193,13 +214,19 @@ mod test {
prev.insert(peer_prev.public_key.clone(), peer_prev.clone()); prev.insert(peer_prev.public_key.clone(), peer_prev.clone());
gen_events(&prev, &cur, &listener, time::Duration::from_secs(10)); gen_events(
&prev,
assert_eq!( &cur,
vec![["rom", &peer.public_key].join(" ")], &listener,
calls.borrow().clone() time::Duration::from_secs(10),
interval,
); );
assert!(calls
.borrow()
.clone()
.contains(&["rom", &peer.public_key].join(" ")));
calls.borrow_mut().clear(); calls.borrow_mut().clear();
let mut peer_prev = peer.clone(); let mut peer_prev = peer.clone();
@ -209,7 +236,13 @@ mod test {
cur.insert(peer_cur.public_key.clone(), peer_cur.clone()); cur.insert(peer_cur.public_key.clone(), peer_cur.clone());
prev.insert(peer_prev.public_key.clone(), peer_prev.clone()); prev.insert(peer_prev.public_key.clone(), peer_prev.clone());
gen_events(&cur, &prev, &listener, time::Duration::from_secs(10)); gen_events(
&cur,
&prev,
&listener,
time::Duration::from_secs(10),
interval,
);
assert_eq!( assert_eq!(
vec![["con", &peer.public_key].join(" ")], vec![["con", &peer.public_key].join(" ")],
@ -219,7 +252,13 @@ mod test {
calls.borrow_mut().clear(); calls.borrow_mut().clear();
//Other way around should be a disconnect //Other way around should be a disconnect
gen_events(&prev, &cur, &listener, time::Duration::from_secs(3)); gen_events(
&prev,
&cur,
&listener,
time::Duration::from_secs(3),
interval,
);
assert_eq!( assert_eq!(
vec![["dis", &peer.public_key].join(" ")], vec![["dis", &peer.public_key].join(" ")],

View File

@ -22,7 +22,9 @@ pub trait EventListener {
impl EventListener for Vec<Box<EventListener>> { impl EventListener for Vec<Box<EventListener>> {
fn added<'a>(&self, peer: &'a Peer) { fn added<'a>(&self, peer: &'a Peer) {
self.iter().for_each(|l| l.added(&peer)); if cfg!(feature = "addrem") || cfg!(test) {
self.iter().for_each(|l| l.added(&peer));
}
} }
fn connected<'a>(&self, peer: &'a Peer) { fn connected<'a>(&self, peer: &'a Peer) {
@ -34,7 +36,9 @@ impl EventListener for Vec<Box<EventListener>> {
} }
fn removed<'a>(&self, peer: &'a Peer) { fn removed<'a>(&self, peer: &'a Peer) {
self.iter().for_each(|l| l.removed(&peer)); if cfg!(feature = "addrem") || cfg!(test) {
self.iter().for_each(|l| l.removed(&peer));
}
} }
fn roaming<'a>(&self, peer: &'a Peer, previous_addr: SocketAddr) { fn roaming<'a>(&self, peer: &'a Peer, previous_addr: SocketAddr) {
@ -83,7 +87,11 @@ impl ScriptListener {
fn mkcmd<'a>(&self, args: Vec<&'a str>) -> Command { fn mkcmd<'a>(&self, args: Vec<&'a str>) -> Command {
let mut cmd = Command::new("/bin/sh"); let mut cmd = Command::new("/bin/sh");
cmd.arg("-c"); cmd.arg("-c");
cmd.arg(format!("{} {}",(&self.script).to_str().unwrap(), args.join(" "))); cmd.arg(format!(
"{} {}",
(&self.script).to_str().unwrap(),
args.join(" ")
));
cmd cmd
} }

View File

@ -14,9 +14,9 @@ use std::io::{BufRead, BufReader, Error, ErrorKind, Result};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::os::unix::net::UnixStream; use std::os::unix::net::UnixStream;
use std::path::PathBuf; use std::path::PathBuf;
use time;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use time;
pub type KV = (String, String); pub type KV = (String, String);
@ -56,34 +56,37 @@ impl Peer {
last_handshake: entries last_handshake: entries
.iter() .iter()
.filter_map(|(key, value)| { .filter_map(|(key, value)| {
let value = || value.parse::<u64>().unwrap(); let value = || value.parse::<u64>().unwrap();
match key.as_ref() { match key.as_ref() {
"last_handshake_time_sec" if value() != 0 => Some(Duration::new(value(), 0)), "last_handshake_time_sec" if value() != 0 => {
"last_handshake_time_nsec" if value() != 0 => Some(Duration::from_nanos(value())), Some(Duration::new(value(), 0))
_ => None }
} "last_handshake_time_nsec" if value() != 0 => {
Some(Duration::from_nanos(value()))
}
_ => None,
}
}) })
.fold(None, |acc, add| { .fold(None, |acc, add| {
if let Some(dur) = acc { if let Some(dur) = acc {
Some(dur + add) Some(dur + add)
} else { } else {
Some(add) Some(add)
} }
}), }),
persistent_keepalive: entries persistent_keepalive: entries
.iter() .iter()
.filter(|(key, _)| key == &"persistent_keepalive") .filter(|(key, _)| key == &"persistent_keepalive")
.map(|(_, value)| Duration::from_secs(value.parse::<u64>().unwrap())) .map(|(_, value)| Duration::from_secs(value.parse::<u64>().unwrap()))
.next(), .next(),
parsed: time::get_time(), parsed: time::get_time(),
}) })
} }
pub fn last_handshake_rel(&self) -> Option<Duration> { pub fn last_handshake_rel(&self) -> Option<Duration> {
let time = self.parsed; let time = self.parsed;
Some(Duration::new(time.sec as u64, time.nsec as u32) - self.last_handshake?) Some(Duration::new(time.sec as u64, time.nsec as u32) - self.last_handshake?)
} }
} }
impl State { impl State {
@ -222,13 +225,14 @@ fn main() {
.next() .next()
.map(PathBuf::from) .map(PathBuf::from)
.filter(|p| p.as_path().exists()); .filter(|p| p.as_path().exists());
let interval = args let interval = Duration::from_millis(
.next() args.next()
.map(|i| { .map(|i| {
i.parse::<u64>() i.parse::<u64>()
.expect("[interval] has to be a positive int") .expect("[interval] has to be a positive int")
}) })
.unwrap_or(1000); .unwrap_or(1000),
);
let mut listeners: Vec<Box<EventListener>> = vec![Box::new(LogListener)]; let mut listeners: Vec<Box<EventListener>> = vec![Box::new(LogListener)];
let events: PathBuf = "/etc/wireguard/events.sh".into(); let events: PathBuf = "/etc/wireguard/events.sh".into();
@ -237,9 +241,18 @@ fn main() {
listeners.push(Box::new(ScriptListener::new(events))) listeners.push(Box::new(ScriptListener::new(events)))
} }
let timeout = env::vars().collect::<HashMap<String,String>>().get("WG_EVENT_GEN_TIMEOUT").map(|timeout| Duration::from_secs(timeout.parse::<u64>().expect(&format!("Can't parse {} as timeout", timeout)))).unwrap_or(Duration::from_secs(30)); let timeout = env::vars()
.collect::<HashMap<String, String>>()
.get("WG_EVENT_GEN_TIMEOUT")
.map(|timeout| {
Duration::from_secs(
timeout
.parse::<u64>()
.expect(&format!("Can't parse {} as timeout", timeout)),
)
})
.unwrap_or(Duration::from_secs(30));
if let Some(path) = path { if let Some(path) = path {
let sock = Socket { path }; let sock = Socket { path };
let mut prev_state: Option<HashMap<String, Peer>> = None; let mut prev_state: Option<HashMap<String, Peer>> = None;
@ -252,10 +265,10 @@ fn main() {
} }
}; };
if let Some(prev_state) = prev_state { if let Some(prev_state) = prev_state {
gen::gen_events(&state, &prev_state, &listeners, timeout); gen::gen_events(&state, &prev_state, &listeners, timeout, interval);
} }
prev_state = Some(state); prev_state = Some(state);
thread::sleep(Duration::from_millis(interval)); thread::sleep(interval);
} }
} else { } else {
println!("<path> does not exist"); println!("<path> does not exist");