use core::net::SocketAddr; use dbus::{channel::MatchingReceiver, message::MatchRule}; use dbus_crossroads::{Context, Crossroads, MethodErr}; use defguard_wireguard_rs::{ host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration, WGApi, WireguardInterfaceApi, }; use log::{debug, error, info, warn}; use serde::{Deserialize, Serialize}; use std::{ collections::{BTreeSet, HashMap}, fs::File, io::{ErrorKind, Read, Write}, marker::PhantomData, net::ToSocketAddrs, ops::DerefMut, str::FromStr, sync::Arc, time::SystemTime, }; use thiserror::Error; use tokio::{ net::TcpListener, runtime::Builder, signal::unix::{signal, SignalKind}, sync::{broadcast, RwLock}, task, }; use xdg::BaseDirectories; use crate::{daemon_dbus::*, daemon_network::*}; #[derive(Debug, Error)] pub enum DaemonError { #[error("{0}")] Io(#[from] std::io::Error), #[error("{0}")] XdgBase(#[from] xdg::BaseDirectoriesError), // TODO hier wärs nett zu unterscheiden was decoded wurde #[error("{0}")] Decoding(#[from] serde_json::Error), #[error("{0}")] IpMaskParse(#[from] defguard_wireguard_rs::net::IpAddrParseError), #[error("{0}")] WgInterfaceError(#[from] defguard_wireguard_rs::error::WireguardInterfaceError), #[error("{0}")] DbusError(#[from] dbus::Error), } #[derive(Serialize, Deserialize, Clone)] pub enum Endpoint { Ip(SocketAddr), Domain(String, u16), } // subset of defguard_wireguard_rs::host::Peer, with hostname added #[derive(Serialize, Deserialize)] pub struct PeerConfig { pub psk: Option, pub ips: Vec<(IpAddrMask, Option)>, // if false: the hostnames are kept around for sharing, but we personally do not use them pub use_hostnames: bool, pub endpoint: Option, pub last_changed: SystemTime, pub known_to: Vec, pub mäsch_endpoint: SocketAddr, } fn default_wg_port() -> u16 { 51820 } #[derive(Serialize, Deserialize)] pub struct Network { pub privkey: String, // this really should be a different type, but this is what defguard takes... pub address: String, #[serde(default = "default_wg_port")] pub listen_port: u16, pub peers: HashMap, pub mäsch_port: u16, } #[derive(Serialize, Deserialize, Default)] pub struct Config { pub networks: HashMap, } // TODO das überschreibt änderungen an /etc/hosts während der runtime :( pub struct State { pub conf: Config, pub nw_handles: HashMap)>, pub hostfile: Option<(String, BTreeSet)>, } impl Drop for State { fn drop(&mut self) { for (wg_api, _) in self.nw_handles.values() { let _ = wg_api.remove_interface(); } } } pub fn daemon() -> Result<(), DaemonError> { let config_path = BaseDirectories::with_prefix("mäsch")?.place_state_file("daemon.json")?; let config: Config = match File::open(config_path) { Ok(f) => serde_json::from_reader(f)?, Err(e) => match e.kind() { ErrorKind::NotFound => Config::default(), _ => Err(e)?, }, }; info!("read config"); let hostfile = match File::open("/etc/hosts") { Ok(mut f) => { let mut r = String::new(); f.read_to_string(&mut r)?; let seen_hostnames: BTreeSet = r .lines() .map(|l| { l.split_whitespace() .take_while(|dom| dom.chars().next().unwrap() != '#') .skip(1) }) .flatten() .map(|dom| dom.to_owned()) .collect(); Some((r, seen_hostnames)) } Err(e) => { warn!("failed to read /etc/hosts: {e}"); None } }; let state = Arc::new(RwLock::new(State { conf: config, nw_handles: HashMap::new(), hostfile, })); let rt = Builder::new_current_thread().enable_all().build()?; rt.block_on(run_networks(state))?; Ok(()) } async fn run_networks(state: Arc>) -> Result<(), DaemonError> { let mut state_rw_guard = state.write().await; let state_rw = state_rw_guard.deref_mut(); // load existing configurations for (name, nw) in &state_rw.conf.networks { let wg_api = add_network( &mut state_rw.hostfile, name.clone(), nw.privkey.clone(), nw.address.clone(), nw.listen_port, &nw.peers, ) .await?; let addr = IpAddrMask::from_str(&nw.address)?.ip; let h = task::spawn(print_error(run_network( state.clone(), TcpListener::bind((addr, nw.mäsch_port)).await?, name.clone(), ))); state_rw.nw_handles.insert(name.clone(), (wg_api, h)); debug!("loaded configuration for {0}", name); } info!("loaded all existing configurations"); drop(state_rw_guard); // set up dbus interface let mut cr = Crossroads::new(); let state_ref = state.clone(); let if_token = cr.register("de.a.maesch", move |b| { b.signal::<(String, String), _>("Proposal", ("network", "peer_data")); b.method_with_cr_async( "AddNetwork", ("name", "key", "ip", "listen_port", "maesch_port"), ("success",), move |ctx, _, args: (String, String, String, u16, u16)| { debug!("Received AddNetwork"); handle_add_network(ctx, state_ref.clone(), args) }, ); }); cr.insert("/de/a/maesch", &[if_token], ()); // drive dbus interface let (res, c) = dbus_tokio::connection::new_system_sync()?; cr.set_async_support(Some(( c.clone(), Box::new(|x| { tokio::spawn(x); }), ))); let _ = tokio::spawn(print_error(async { res.await; Result::::Err("lost connection to dbus!") })); let receive_token = c.start_receive( MatchRule::new_method_call(), Box::new(move |msg, conn| { cr.handle_message(msg, conn).unwrap(); true }), ); c.request_name("de.a.maesch", false, true, false).await?; // wait for SIGTERM/SIGINT let mut sigterm_fut = signal(SignalKind::terminate())?; let mut sigint_fut = signal(SignalKind::interrupt())?; let mut sighup_fut = signal(SignalKind::hangup())?; tokio::select! { _ = sigterm_fut.recv() => info!("Received SIGTERM"), _ = sigint_fut.recv() => info!("Received SIGINT"), _ = sighup_fut.recv() => info!("Received SIGHUP"), }; // clean exit c.stop_receive(receive_token); let mut state_rw_guard = state.write().await; for (_, (wg_api, h)) in state_rw_guard.nw_handles.drain() { let _ = wg_api.remove_interface(); h.abort(); // could also join the handles... don't think that would do too much, though } Ok(()) } pub async fn print_error>>( f: F, ) -> () { match f.await { Err(e) => error!("oh no: {e}"), _ => (), }; }