summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2025-11-15 19:31:01 +0100
committermetamuffin <metamuffin@disroot.org>2025-11-15 19:31:01 +0100
commit6524e4b26d869c9d48a4abedfb4efcc089b9588d (patch)
treeb9bd389dfe22d6e07845c8125a1936f8401c303d /src
parentcfe57a5e8f0cb9d76526776d3d274a638b138ea3 (diff)
downloadgnix-6524e4b26d869c9d48a4abedfb4efcc089b9588d.tar
gnix-6524e4b26d869c9d48a4abedfb4efcc089b9588d.tar.bz2
gnix-6524e4b26d869c9d48a4abedfb4efcc089b9588d.tar.zst
Establish listeners before switching users
Diffstat (limited to 'src')
-rw-r--r--src/control_socket.rs7
-rw-r--r--src/main.rs322
2 files changed, 157 insertions, 172 deletions
diff --git a/src/control_socket.rs b/src/control_socket.rs
index d0fa3b9..da057ec 100644
--- a/src/control_socket.rs
+++ b/src/control_socket.rs
@@ -29,7 +29,7 @@ pub enum ControlSocketResponse {
Error(String),
}
-pub async fn serve_control_socket(state: Arc<State>, path: &Path) -> Result<()> {
+pub async fn create_unix_listener(path: &Path) -> Result<UnixListener> {
if path.exists() {
let meta = path.metadata().context("metadata of old socket")?;
// TODO proper check
@@ -39,9 +39,10 @@ pub async fn serve_control_socket(state: Arc<State>, path: &Path) -> Result<()>
.context("removing old control socket")?;
}
}
+ Ok(UnixListener::bind(path).context("creating control socket")?)
+}
- let listener = UnixListener::bind(path).context("creating control socket")?;
-
+pub async fn serve_control_socket(state: Arc<State>, listener: UnixListener) -> Result<()> {
loop {
let Ok((conn, _addr)) = listener.accept().await else {
continue;
diff --git a/src/main.rs b/src/main.rs
index 9a623d0..c37059d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -21,9 +21,10 @@ pub mod h3_support;
pub mod modules;
use crate::{
- config::ConfigPackage,
+ config::{Config, ConfigPackage},
control_socket::{
- cs_client_reload, handle_control_socket_request, serve_control_socket, ControlSocketRequest,
+ create_unix_listener, cs_client_reload, handle_control_socket_request,
+ serve_control_socket, ControlSocketRequest,
},
generation::Generation,
};
@@ -32,7 +33,6 @@ use anyhow::{anyhow, Context, Result};
use bytes::Bytes;
use clap::Parser;
use error::ServiceError;
-use futures::future::try_join_all;
use h3::server::RequestStream;
use h3_quinn::SendStream;
use h3_support::H3RequestBody;
@@ -52,13 +52,13 @@ use notify::{RecursiveMode, Watcher};
use std::{
future::Future,
net::{IpAddr, SocketAddr},
- path::{Path, PathBuf},
+ path::PathBuf,
process::exit,
str::FromStr,
sync::Arc,
};
use tokio::{
- net::TcpListener,
+ net::{TcpListener, UdpSocket},
signal::ctrl_c,
spawn,
sync::{mpsc::channel, RwLock, Semaphore},
@@ -105,6 +105,7 @@ async fn main() -> anyhow::Result<()> {
.install_default()
.unwrap();
+ info!("Reading configuration and certificates...");
let config_package = ConfigPackage::new(&args.config)?;
if args.reload {
@@ -112,7 +113,7 @@ async fn main() -> anyhow::Result<()> {
.control_socket
.ok_or(anyhow!("reload needs control socket path"))?;
if args.watch {
- watch_config(&args.config, || async {
+ watch_config(args.config.clone(), || async {
let config_package = ConfigPackage::new(&args.config)?;
cs_client_reload(&cs_path, config_package).await?;
info!("Config updated");
@@ -133,11 +134,61 @@ async fn main() -> anyhow::Result<()> {
}
}
+ let mut http_listeners = Vec::new();
+ let mut https_listeners = Vec::new();
+ let mut https_listeners_h3 = Vec::new();
+ let mut control_socket_listeners = Vec::new();
+
+ let init_config: Config =
+ serde_yml::from_str(&config_package.config).context("parsing config YAML")?;
+ if let Some(path) = args.control_socket {
+ control_socket_listeners.push(
+ create_unix_listener(&path)
+ .await
+ .context(anyhow!("creating Unix listener for {path:?}"))?,
+ );
+ info!("Control Socket listener bound to {path:?}/unix");
+ }
+
+ if let Some(http_config) = &init_config.http {
+ for addr in &http_config.bind {
+ http_listeners.push(
+ TcpListener::bind(addr)
+ .await
+ .context(anyhow!("creating TCP listener for {addr}"))?,
+ );
+ info!("HTTP listener bound to {addr}/tcp");
+ }
+ }
+ if let Some(https_config) = &init_config.https {
+ for addr in &https_config.bind {
+ if !https_config.disable_h1 || !https_config.disable_h2 {
+ https_listeners.push(
+ TcpListener::bind(addr)
+ .await
+ .context(anyhow!("creating TCP listener for {addr}"))?,
+ );
+ info!("HTTPS (h1+h2) listener bound to {addr}/tcp");
+ }
+ if !https_config.disable_h3 {
+ https_listeners_h3.push(
+ UdpSocket::bind(addr)
+ .await
+ .context(anyhow!("creating UDP listener for {addr}"))?,
+ );
+ info!("HTTPS (h3) listener bound to {addr}/udp");
+ }
+ }
+ }
+ drop(init_config);
+
if let Some(username) = args.user {
+ info!("Switching user to {username:?}...");
let user = get_user_by_name(&username).ok_or(anyhow!("user for setuid not found"))?;
set_current_uid(user.uid()).context("setuid")?;
}
+ info!("Initializing request handlers...");
let generation = Generation::new(config_package)?;
let state = Arc::new(State {
crypto_key: aes_gcm_siv::Aes256GcmSiv::new(GenericArray::from_slice(
@@ -152,147 +203,115 @@ async fn main() -> anyhow::Result<()> {
if args.watch {
let state = state.clone();
- spawn(async move {
- if let Err(e) = watch_config(&args.config, || async {
- let config_package = ConfigPackage::new(&args.config)?;
- handle_control_socket_request(
- state.clone(),
- ControlSocketRequest::Config(config_package),
- )
- .await?;
- info!("Config updated");
- Ok(())
- })
- .await
- {
- error!("Config Watch: {e:#}");
- exit(1)
- }
- });
+ let config = args.config.clone();
+ spawn(exit_on_error(
+ "Config Watch",
+ watch_config(args.config.clone(), move || {
+ let config = config.clone();
+ let state = state.clone();
+ async move {
+ let config_package = ConfigPackage::new(&config)?;
+ handle_control_socket_request(
+ state.clone(),
+ ControlSocketRequest::Config(config_package),
+ )
+ .await?;
+ info!("Config updated");
+ Ok(())
+ }
+ }),
+ ));
}
- if let Some(path) = args.control_socket {
- let state = state.clone();
- tokio::spawn(async move {
- if let Err(e) = serve_control_socket(state, &path).await {
- error!("Control Socket: {e:#}");
- exit(1)
- }
- });
+ for li in control_socket_listeners {
+ spawn(exit_on_error(
+ "Control Socket",
+ serve_control_socket(state.clone(), li),
+ ));
}
- {
- let state = state.clone();
- tokio::spawn(async move {
- if let Err(e) = serve_http(state).await {
- error!("HTTP: {e:#}");
- exit(1)
- }
- });
+ for li in http_listeners {
+ spawn(exit_on_error("HTTP", serve_http(state.clone(), li)));
}
- {
- let state = state.clone();
- tokio::spawn(async move {
- if let Err(e) = serve_https(state).await {
- error!("HTTPS (h1+h2): {e:#}");
- exit(1)
- }
- });
+ for li in https_listeners {
+ spawn(exit_on_error(
+ "HTTPS (h1+h2)",
+ serve_https(state.clone(), li),
+ ));
}
- {
- let state = state.clone();
- tokio::spawn(async move {
- if let Err(e) = serve_h3(state).await {
- error!("HTTPS (h3): {e:#}");
- exit(1)
- }
- });
+ for li in https_listeners_h3 {
+ spawn(exit_on_error(
+ "HTTPS (h3)",
+ serve_https_h3(state.clone(), li),
+ ));
}
+ info!("Ready to accept connections");
ctrl_c().await.unwrap();
- Ok(())
+ exit(0)
}
-async fn serve_http(state: Arc<State>) -> Result<()> {
- let Some(bind_addrs) = state
- .generation
- .read()
- .await
- .config
- .http
- .as_ref()
- .map(|h| h.bind.clone())
- else {
- return Ok(()); // http disabled
- };
+async fn exit_on_error(name: &str, fut: impl Future<Output = Result<()>>) {
+ if let Err(e) = fut.await {
+ error!("{name}: {e:#}");
+ exit(1)
+ }
+}
- let listen_futs: Result<Vec<()>> = try_join_all(bind_addrs.iter().map(|e| async {
- let listener = TcpListener::bind(*e).await?;
- let listen_addr = listener.local_addr()?;
- info!(
- "HTTP listener bound to {}/tcp",
- listener.local_addr().unwrap()
- );
- loop {
- let (stream, addr) = listener.accept().await.context("accepting connection")?;
+async fn serve_http(state: Arc<State>, listener: TcpListener) -> Result<()> {
+ let listen_addr = listener.local_addr()?;
+ loop {
+ let (stream, addr) = listener.accept().await.context("accepting connection")?;
+ debug!("connection from {addr}");
+ let stream = TokioIo::new(stream);
+ let state = state.clone();
+ spawn(async move { serve_stream(state, stream, addr, false, listen_addr).await });
+ }
+}
+
+async fn serve_https(state: Arc<State>, listener: TcpListener) -> Result<()> {
+ let listen_addr = listener.local_addr()?;
+ loop {
+ let (stream, addr) = listener.accept().await.context("accepting connection")?;
+ let generation = state.generation.read().await.clone();
+ let Some(tls_acceptor) = generation.tls_acceptor.clone() else {
+ warn!("HTTPS (h1+h2) listener is missing TLS configuration");
+ break;
+ };
+ let state = state.clone();
+ spawn(async move {
debug!("connection from {addr}");
- let stream = TokioIo::new(stream);
- let state = state.clone();
- tokio::spawn(
- async move { serve_stream(state, stream, addr, false, listen_addr).await },
- );
- }
- }))
- .await;
- listen_futs?;
+ match tls_acceptor.accept(stream).await {
+ Ok(stream) => {
+ serve_stream(state, TokioIo::new(stream), addr, true, listen_addr).await
+ }
+ Err(e) => warn!("error accepting tls: {e}"),
+ };
+ });
+ }
+ info!(
+ "HTTPS (h1+h2) listener for {} shutting down",
+ listener.local_addr().unwrap()
+ );
Ok(())
}
-async fn serve_https(state: Arc<State>) -> Result<()> {
- let Some(bind_addrs) = state
- .generation
- .read()
- .await
- .config
- .https
- .as_ref()
- .map(|h| h.bind.clone())
- else {
- return Ok(()); // https disabled
+async fn serve_https_h3(state: Arc<State>, listener: UdpSocket) -> Result<()> {
+ let Some(quic_config) = state.generation.read().await.quic_config.clone() else {
+ warn!("HTTPS (h3) listener is missing TLS configuration");
+ return Ok(());
};
-
- let listen_futs: Result<Vec<()>> = try_join_all(bind_addrs.iter().map(|e| async {
- let listener = TcpListener::bind(*e).await?;
- let listen_addr = listener.local_addr()?;
- info!(
- "HTTPS (h1+h2) listener bound to {}/tcp",
- listener.local_addr().unwrap()
- );
- loop {
- let (stream, addr) = listener.accept().await.context("accepting connection")?;
- let generation = state.generation.read().await.clone();
- let Some(tls_acceptor) = generation.tls_acceptor.clone() else {
- warn!("HTTPS (h1+h2) listener is missing TLS configuration");
- break;
- };
- let state = state.clone();
- tokio::task::spawn(async move {
- debug!("connection from {addr}");
- match tls_acceptor.accept(stream).await {
- Ok(stream) => {
- serve_stream(state, TokioIo::new(stream), addr, true, listen_addr).await
- }
- Err(e) => warn!("error accepting tls: {e}"),
- };
- });
- }
- info!(
- "HTTPS (h1+h2) listener for {} shutting down",
- listener.local_addr().unwrap()
- );
- Ok(())
- }))
- .await;
- listen_futs?;
+ let listen_addr = listener.local_addr()?;
+ let endpoint = quinn::Endpoint::new(
+ quinn::EndpointConfig::default(),
+ Some(quic_config),
+ listener.into_std().unwrap(),
+ quinn::default_runtime().unwrap(),
+ )?;
+ state.quic_endpoints.write().await.push(endpoint.clone());
+ while let Some(conn) = endpoint.accept().await {
+ let state = state.clone();
+ spawn(serve_stream_h3(conn, state, listen_addr));
+ }
Ok(())
}
@@ -327,43 +346,7 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r
}
}
-async fn serve_h3(state: Arc<State>) -> Result<()> {
- let generation = state.generation.read().await.clone();
- let https_config = match &generation.config.https {
- Some(n) => n,
- None => return Ok(()),
- };
- if https_config.disable_h3 {
- return Ok(());
- }
- let bind_addrs = https_config.bind.clone();
-
- try_join_all(bind_addrs.iter().map(|listen_addr| async {
- let Some(quic_config) = generation.quic_config.clone() else {
- warn!("HTTPS (h3) listener is missing TLS configuration");
- return Ok(());
- };
- let endpoint = quinn::Endpoint::server(quic_config, *listen_addr)?;
- state.quic_endpoints.write().await.push(endpoint.clone());
- let listen_addr = *listen_addr;
- info!("HTTPS (h3) listener bound to {listen_addr}/udp");
- while let Some(conn) = endpoint.accept().await {
- let state = state.clone();
- let generation = generation.clone();
- tokio::spawn(serve_h3_stream(conn, generation, state, listen_addr));
- }
- Ok::<_, anyhow::Error>(())
- }))
- .await?;
- Ok(())
-}
-
-async fn serve_h3_stream(
- conn: quinn::Incoming,
- generation: Arc<Generation>,
- state: Arc<State>,
- listen_addr: SocketAddr,
-) {
+async fn serve_stream_h3(conn: quinn::Incoming, state: Arc<State>, listen_addr: SocketAddr) {
let addr = conn.remote_address();
debug!("h3 connection attempt from {addr}");
let Ok(_sem) = state.l_incoming_h3.try_acquire() else {
@@ -383,6 +366,7 @@ async fn serve_h3_stream(
Err(e) => return warn!("h3 accept failed {e}"),
};
debug!("h3 stream from {addr}");
+ let generation = state.generation.read().await;
let max_par_requests = Semaphore::new(generation.config.limits.max_requests_per_connnection);
loop {
match conn.accept().await {
@@ -540,7 +524,7 @@ async fn service(
}
async fn watch_config<R: Future<Output = Result<()>>>(
- path: &Path,
+ path: PathBuf,
mut reload: impl FnMut() -> R,
) -> Result<()> {
let (tx, mut rx) = channel(1);