summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2025-11-14 22:43:21 +0100
committermetamuffin <metamuffin@disroot.org>2025-11-14 22:43:21 +0100
commita4d828cfa4ba9ff7ae4e21df49b2f3f9c695e4fa (patch)
tree05806e218a7ba3a53ab0a5df9b664be97e472edb /src/main.rs
parent057031d68d0ab93aa6cc668bb5e37ad575a8ff8f (diff)
downloadgnix-a4d828cfa4ba9ff7ae4e21df49b2f3f9c695e4fa.tar
gnix-a4d828cfa4ba9ff7ae4e21df49b2f3f9c695e4fa.tar.bz2
gnix-a4d828cfa4ba9ff7ae4e21df49b2f3f9c695e4fa.tar.zst
add back watch option
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs97
1 files changed, 86 insertions, 11 deletions
diff --git a/src/main.rs b/src/main.rs
index 3f437ba..b59c325 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -22,7 +22,9 @@ pub mod modules;
use crate::{
config::ConfigPackage,
- control_socket::{cs_client_reload, serve_control_socket},
+ control_socket::{
+ cs_client_reload, handle_control_socket_request, serve_control_socket, ControlSocketRequest,
+ },
generation::Generation,
};
use aes_gcm_siv::{aead::generic_array::GenericArray, Aes256GcmSiv, KeyInit};
@@ -46,9 +48,11 @@ use hyper::{
use hyper_util::rt::{TokioExecutor, TokioIo};
use log::{debug, error, info, warn, LevelFilter};
use modules::NodeContext;
+use notify::{RecursiveMode, Watcher};
use std::{
+ future::Future,
net::{IpAddr, SocketAddr},
- path::PathBuf,
+ path::{Path, PathBuf},
process::exit,
str::FromStr,
sync::Arc,
@@ -57,8 +61,9 @@ use tokio::{
net::TcpListener,
signal::ctrl_c,
spawn,
- sync::{RwLock, Semaphore},
+ sync::{mpsc::channel, RwLock, Semaphore},
};
+use users::{get_user_by_name, switch::set_current_uid};
pub struct State {
pub crypto_key: Aes256GcmSiv,
@@ -72,7 +77,7 @@ pub struct State {
#[derive(Parser)]
struct Args {
#[arg(long)]
- setuid: Option<String>,
+ user: Option<String>,
#[arg(long)]
control_socket: Option<PathBuf>,
#[arg(short, long)]
@@ -102,17 +107,33 @@ async fn main() -> anyhow::Result<()> {
let cs_path = args
.control_socket
.ok_or(anyhow!("reload needs control socket path"))?;
- match cs_client_reload(&cs_path, config_package).await {
- Ok(()) => {
- exit(0);
- }
- Err(e) => {
- eprintln!("Error: {e}");
- exit(1);
+ if args.watch {
+ watch_config(&args.config, || async {
+ let config_package = ConfigPackage::new(&args.config)?;
+ cs_client_reload(&cs_path, config_package).await?;
+ info!("Config updated");
+ Ok(())
+ })
+ .await?;
+ exit(1);
+ } else {
+ match cs_client_reload(&cs_path, config_package).await {
+ Ok(()) => {
+ exit(0);
+ }
+ Err(e) => {
+ eprintln!("Error: {e}");
+ exit(1);
+ }
}
}
}
+ if let Some(username) = args.user {
+ let user = get_user_by_name(&username).ok_or(anyhow!("user for setuid not found"))?;
+ set_current_uid(user.uid()).context("setuid")?;
+ }
+
let generation = Generation::new(config_package)?;
let state = Arc::new(State {
crypto_key: aes_gcm_siv::Aes256GcmSiv::new(GenericArray::from_slice(
@@ -125,6 +146,27 @@ async fn main() -> anyhow::Result<()> {
generation: RwLock::new(Arc::new(generation)),
});
+ if args.watch {
+ let state = state.clone();
+ spawn(async move {
+ if let Err(e) = watch_config(&args.config, || async {
+ let config_package = ConfigPackage::new(&args.config)?;
+ handle_control_socket_request(
+ state.clone(),
+ ControlSocketRequest::Config(config_package),
+ )
+ .await?;
+ info!("Config updated");
+ Ok(())
+ })
+ .await
+ {
+ error!("{e:?}");
+ exit(1)
+ }
+ });
+ }
+
if let Some(path) = args.control_socket {
let state = state.clone();
tokio::spawn(async move {
@@ -493,3 +535,36 @@ async fn service(
Ok(resp)
}
+
+async fn watch_config<R: Future<Output = Result<()>>>(
+ path: &Path,
+ mut reload: impl FnMut() -> R,
+) -> Result<()> {
+ let (tx, mut rx) = channel(1);
+ let mut w = notify::recommended_watcher(move |r| {
+ let _ = tx.blocking_send(r);
+ })?;
+ w.watch(path.parent().unwrap(), RecursiveMode::NonRecursive)?;
+
+ let path = path.canonicalize()?;
+ while let Some(r) = rx.recv().await {
+ match r {
+ Ok(ev) => {
+ if matches!(
+ ev.kind,
+ notify::EventKind::Access(notify::event::AccessKind::Close(
+ notify::event::AccessMode::Write
+ )) | notify::EventKind::Modify(notify::event::ModifyKind::Data(_))
+ ) && ev.paths.contains(&path)
+ {
+ if let Err(e) = reload().await {
+ error!("Error: {e:#}")
+ }
+ }
+ }
+ Err(e) => error!("watch error: {e:#}"),
+ }
+ }
+
+ Ok(())
+}