summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/config.rs36
-rw-r--r--src/main.rs27
2 files changed, 55 insertions, 8 deletions
diff --git a/src/config.rs b/src/config.rs
index a7abf3b..be41420 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,4 +1,7 @@
+use crate::State;
use anyhow::Context;
+use inotify::{EventMask, Inotify, WatchMask};
+use log::{error, info};
use serde::{
de::{value, Error, SeqAccess, Visitor},
Deserialize, Deserializer, Serialize,
@@ -10,6 +13,7 @@ use std::{
marker::PhantomData,
net::SocketAddr,
path::PathBuf,
+ sync::Arc,
};
#[derive(Debug, Serialize, Deserialize)]
@@ -168,3 +172,35 @@ impl Default for Limits {
}
}
}
+
+pub fn setup_file_watch(config_path: String, state: Arc<State>) {
+ std::thread::spawn(move || {
+ let mut inotify = Inotify::init().unwrap();
+ inotify
+ .watches()
+ .add(
+ ".",
+ WatchMask::MODIFY | WatchMask::CREATE | WatchMask::DELETE,
+ )
+ .unwrap();
+ let mut buffer = [0u8; 4096];
+ loop {
+ let events = inotify
+ .read_events_blocking(&mut buffer)
+ .expect("Failed to read inotify events");
+
+ for event in events {
+ if event.mask.contains(EventMask::MODIFY) {
+ info!("reloading config");
+ match Config::load(&config_path) {
+ Ok(conf) => {
+ let mut r = state.config.blocking_write();
+ *r = Arc::new(conf)
+ }
+ Err(e) => error!("config has errors: {e}"),
+ }
+ }
+ }
+ }
+ });
+}
diff --git a/src/main.rs b/src/main.rs
index 059bdf0..0fb9957 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -18,6 +18,7 @@ use crate::{
};
use anyhow::{anyhow, bail, Context, Result};
use bytes::Bytes;
+use config::setup_file_watch;
use error::ServiceError;
use futures::future::try_join_all;
use helper::TokioIo;
@@ -37,11 +38,15 @@ use std::{
fs::File, io::BufReader, net::SocketAddr, ops::ControlFlow, path::Path, process::exit,
sync::Arc,
};
-use tokio::{net::TcpListener, signal::ctrl_c, sync::Semaphore};
+use tokio::{
+ net::TcpListener,
+ signal::ctrl_c,
+ sync::{RwLock, Semaphore},
+};
use tokio_rustls::TlsAcceptor;
pub struct State {
- pub config: Config,
+ pub config: RwLock<Arc<Config>>,
pub l_incoming: Semaphore,
pub l_outgoing: Semaphore,
#[cfg(feature = "mond")]
@@ -72,9 +77,11 @@ async fn main() -> anyhow::Result<()> {
l_outgoing: Semaphore::new(config.limits.max_outgoing_connections),
#[cfg(feature = "mond")]
reporting: Reporting::new(&config),
- config,
+ config: RwLock::new(Arc::new(config)),
});
+ setup_file_watch(config_path.to_owned(), state.clone());
+
{
let state = state.clone();
tokio::spawn(async move {
@@ -99,7 +106,8 @@ async fn main() -> anyhow::Result<()> {
}
async fn serve_http(state: Arc<State>) -> Result<()> {
- let http_config = match &state.config.http {
+ let config = state.config.read().await.clone();
+ let http_config = match &config.http {
Some(n) => n,
None => return Ok(()),
};
@@ -124,7 +132,8 @@ async fn serve_http(state: Arc<State>) -> Result<()> {
}
async fn serve_https(state: Arc<State>) -> Result<()> {
- let https_config = match &state.config.https {
+ let config = state.config.read().await.clone();
+ let https_config = match &config.https {
Some(n) => n,
None => return Ok(()),
};
@@ -136,7 +145,7 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
.with_no_client_auth()
.with_single_cert(certs, key)?;
cfg.alpn_protocols = vec![
- //b"h2".to_vec(),
+ // b"h2".to_vec(),
b"http/1.1".to_vec(),
];
Arc::new(cfg)
@@ -175,7 +184,8 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r
service_fn(|req| {
let state = state.clone();
async move {
- match service(state, req, addr).await {
+ let config = state.config.read().await.clone();
+ match service(state, config, req, addr).await {
Ok(r) => Ok(r),
Err(ServiceError::Hyper(e)) => Err(e),
Err(error) => Ok({
@@ -220,6 +230,7 @@ fn load_private_key(path: &Path) -> anyhow::Result<rustls::PrivateKey> {
async fn service(
state: Arc<State>,
+ config: Arc<Config>,
req: Request<Incoming>,
addr: SocketAddr,
) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> {
@@ -234,7 +245,7 @@ async fn service(
.map(String::from)
.unwrap_or(String::from(""));
let host = remove_port(&host);
- let route = state.config.hosts.get(host).ok_or(ServiceError::NoHost)?;
+ let route = config.hosts.get(host).ok_or(ServiceError::NoHost)?;
#[cfg(feature = "mond")]
state.reporting.hosts.get(host).unwrap().requests_in.inc();