diff options
| author | metamuffin <metamuffin@disroot.org> | 2025-11-15 19:31:01 +0100 |
|---|---|---|
| committer | metamuffin <metamuffin@disroot.org> | 2025-11-15 19:31:01 +0100 |
| commit | 6524e4b26d869c9d48a4abedfb4efcc089b9588d (patch) | |
| tree | b9bd389dfe22d6e07845c8125a1936f8401c303d /src | |
| parent | cfe57a5e8f0cb9d76526776d3d274a638b138ea3 (diff) | |
| download | gnix-6524e4b26d869c9d48a4abedfb4efcc089b9588d.tar gnix-6524e4b26d869c9d48a4abedfb4efcc089b9588d.tar.bz2 gnix-6524e4b26d869c9d48a4abedfb4efcc089b9588d.tar.zst | |
Establish listeners before switching users
Diffstat (limited to 'src')
| -rw-r--r-- | src/control_socket.rs | 7 | ||||
| -rw-r--r-- | src/main.rs | 322 |
2 files changed, 157 insertions, 172 deletions
diff --git a/src/control_socket.rs b/src/control_socket.rs index d0fa3b9..da057ec 100644 --- a/src/control_socket.rs +++ b/src/control_socket.rs @@ -29,7 +29,7 @@ pub enum ControlSocketResponse { Error(String), } -pub async fn serve_control_socket(state: Arc<State>, path: &Path) -> Result<()> { +pub async fn create_unix_listener(path: &Path) -> Result<UnixListener> { if path.exists() { let meta = path.metadata().context("metadata of old socket")?; // TODO proper check @@ -39,9 +39,10 @@ pub async fn serve_control_socket(state: Arc<State>, path: &Path) -> Result<()> .context("removing old control socket")?; } } + Ok(UnixListener::bind(path).context("creating control socket")?) +} - let listener = UnixListener::bind(path).context("creating control socket")?; - +pub async fn serve_control_socket(state: Arc<State>, listener: UnixListener) -> Result<()> { loop { let Ok((conn, _addr)) = listener.accept().await else { continue; diff --git a/src/main.rs b/src/main.rs index 9a623d0..c37059d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,9 +21,10 @@ pub mod h3_support; pub mod modules; use crate::{ - config::ConfigPackage, + config::{Config, ConfigPackage}, control_socket::{ - cs_client_reload, handle_control_socket_request, serve_control_socket, ControlSocketRequest, + create_unix_listener, cs_client_reload, handle_control_socket_request, + serve_control_socket, ControlSocketRequest, }, generation::Generation, }; @@ -32,7 +33,6 @@ use anyhow::{anyhow, Context, Result}; use bytes::Bytes; use clap::Parser; use error::ServiceError; -use futures::future::try_join_all; use h3::server::RequestStream; use h3_quinn::SendStream; use h3_support::H3RequestBody; @@ -52,13 +52,13 @@ use notify::{RecursiveMode, Watcher}; use std::{ future::Future, net::{IpAddr, SocketAddr}, - path::{Path, PathBuf}, + path::PathBuf, process::exit, str::FromStr, sync::Arc, }; use tokio::{ - net::TcpListener, + net::{TcpListener, UdpSocket}, signal::ctrl_c, spawn, sync::{mpsc::channel, RwLock, Semaphore}, @@ -105,6 +105,7 @@ async fn main() -> anyhow::Result<()> { .install_default() .unwrap(); + info!("Reading configuration and certificates..."); let config_package = ConfigPackage::new(&args.config)?; if args.reload { @@ -112,7 +113,7 @@ async fn main() -> anyhow::Result<()> { .control_socket .ok_or(anyhow!("reload needs control socket path"))?; if args.watch { - watch_config(&args.config, || async { + watch_config(args.config.clone(), || async { let config_package = ConfigPackage::new(&args.config)?; cs_client_reload(&cs_path, config_package).await?; info!("Config updated"); @@ -133,11 +134,61 @@ async fn main() -> anyhow::Result<()> { } } + let mut http_listeners = Vec::new(); + let mut https_listeners = Vec::new(); + let mut https_listeners_h3 = Vec::new(); + let mut control_socket_listeners = Vec::new(); + + let init_config: Config = + serde_yml::from_str(&config_package.config).context("parsing config YAML")?; + if let Some(path) = args.control_socket { + control_socket_listeners.push( + create_unix_listener(&path) + .await + .context(anyhow!("creating Unix listener for {path:?}"))?, + ); + info!("Control Socket listener bound to {path:?}/unix"); + } + + if let Some(http_config) = &init_config.http { + for addr in &http_config.bind { + http_listeners.push( + TcpListener::bind(addr) + .await + .context(anyhow!("creating TCP listener for {addr}"))?, + ); + info!("HTTP listener bound to {addr}/tcp"); + } + } + if let Some(https_config) = &init_config.https { + for addr in &https_config.bind { + if !https_config.disable_h1 || !https_config.disable_h2 { + https_listeners.push( + TcpListener::bind(addr) + .await + .context(anyhow!("creating TCP listener for {addr}"))?, + ); + info!("HTTPS (h1+h2) listener bound to {addr}/tcp"); + } + if !https_config.disable_h3 { + https_listeners_h3.push( + UdpSocket::bind(addr) + .await + .context(anyhow!("creating UDP listener for {addr}"))?, + ); + info!("HTTPS (h3) listener bound to {addr}/udp"); + } + } + } + drop(init_config); + if let Some(username) = args.user { + info!("Switching user to {username:?}..."); let user = get_user_by_name(&username).ok_or(anyhow!("user for setuid not found"))?; set_current_uid(user.uid()).context("setuid")?; } + info!("Initializing request handlers..."); let generation = Generation::new(config_package)?; let state = Arc::new(State { crypto_key: aes_gcm_siv::Aes256GcmSiv::new(GenericArray::from_slice( @@ -152,147 +203,115 @@ async fn main() -> anyhow::Result<()> { if args.watch { let state = state.clone(); - spawn(async move { - if let Err(e) = watch_config(&args.config, || async { - let config_package = ConfigPackage::new(&args.config)?; - handle_control_socket_request( - state.clone(), - ControlSocketRequest::Config(config_package), - ) - .await?; - info!("Config updated"); - Ok(()) - }) - .await - { - error!("Config Watch: {e:#}"); - exit(1) - } - }); + let config = args.config.clone(); + spawn(exit_on_error( + "Config Watch", + watch_config(args.config.clone(), move || { + let config = config.clone(); + let state = state.clone(); + async move { + let config_package = ConfigPackage::new(&config)?; + handle_control_socket_request( + state.clone(), + ControlSocketRequest::Config(config_package), + ) + .await?; + info!("Config updated"); + Ok(()) + } + }), + )); } - if let Some(path) = args.control_socket { - let state = state.clone(); - tokio::spawn(async move { - if let Err(e) = serve_control_socket(state, &path).await { - error!("Control Socket: {e:#}"); - exit(1) - } - }); + for li in control_socket_listeners { + spawn(exit_on_error( + "Control Socket", + serve_control_socket(state.clone(), li), + )); } - { - let state = state.clone(); - tokio::spawn(async move { - if let Err(e) = serve_http(state).await { - error!("HTTP: {e:#}"); - exit(1) - } - }); + for li in http_listeners { + spawn(exit_on_error("HTTP", serve_http(state.clone(), li))); } - { - let state = state.clone(); - tokio::spawn(async move { - if let Err(e) = serve_https(state).await { - error!("HTTPS (h1+h2): {e:#}"); - exit(1) - } - }); + for li in https_listeners { + spawn(exit_on_error( + "HTTPS (h1+h2)", + serve_https(state.clone(), li), + )); } - { - let state = state.clone(); - tokio::spawn(async move { - if let Err(e) = serve_h3(state).await { - error!("HTTPS (h3): {e:#}"); - exit(1) - } - }); + for li in https_listeners_h3 { + spawn(exit_on_error( + "HTTPS (h3)", + serve_https_h3(state.clone(), li), + )); } + info!("Ready to accept connections"); ctrl_c().await.unwrap(); - Ok(()) + exit(0) } -async fn serve_http(state: Arc<State>) -> Result<()> { - let Some(bind_addrs) = state - .generation - .read() - .await - .config - .http - .as_ref() - .map(|h| h.bind.clone()) - else { - return Ok(()); // http disabled - }; +async fn exit_on_error(name: &str, fut: impl Future<Output = Result<()>>) { + if let Err(e) = fut.await { + error!("{name}: {e:#}"); + exit(1) + } +} - let listen_futs: Result<Vec<()>> = try_join_all(bind_addrs.iter().map(|e| async { - let listener = TcpListener::bind(*e).await?; - let listen_addr = listener.local_addr()?; - info!( - "HTTP listener bound to {}/tcp", - listener.local_addr().unwrap() - ); - loop { - let (stream, addr) = listener.accept().await.context("accepting connection")?; +async fn serve_http(state: Arc<State>, listener: TcpListener) -> Result<()> { + let listen_addr = listener.local_addr()?; + loop { + let (stream, addr) = listener.accept().await.context("accepting connection")?; + debug!("connection from {addr}"); + let stream = TokioIo::new(stream); + let state = state.clone(); + spawn(async move { serve_stream(state, stream, addr, false, listen_addr).await }); + } +} + +async fn serve_https(state: Arc<State>, listener: TcpListener) -> Result<()> { + let listen_addr = listener.local_addr()?; + loop { + let (stream, addr) = listener.accept().await.context("accepting connection")?; + let generation = state.generation.read().await.clone(); + let Some(tls_acceptor) = generation.tls_acceptor.clone() else { + warn!("HTTPS (h1+h2) listener is missing TLS configuration"); + break; + }; + let state = state.clone(); + spawn(async move { debug!("connection from {addr}"); - let stream = TokioIo::new(stream); - let state = state.clone(); - tokio::spawn( - async move { serve_stream(state, stream, addr, false, listen_addr).await }, - ); - } - })) - .await; - listen_futs?; + match tls_acceptor.accept(stream).await { + Ok(stream) => { + serve_stream(state, TokioIo::new(stream), addr, true, listen_addr).await + } + Err(e) => warn!("error accepting tls: {e}"), + }; + }); + } + info!( + "HTTPS (h1+h2) listener for {} shutting down", + listener.local_addr().unwrap() + ); Ok(()) } -async fn serve_https(state: Arc<State>) -> Result<()> { - let Some(bind_addrs) = state - .generation - .read() - .await - .config - .https - .as_ref() - .map(|h| h.bind.clone()) - else { - return Ok(()); // https disabled +async fn serve_https_h3(state: Arc<State>, listener: UdpSocket) -> Result<()> { + let Some(quic_config) = state.generation.read().await.quic_config.clone() else { + warn!("HTTPS (h3) listener is missing TLS configuration"); + return Ok(()); }; - - let listen_futs: Result<Vec<()>> = try_join_all(bind_addrs.iter().map(|e| async { - let listener = TcpListener::bind(*e).await?; - let listen_addr = listener.local_addr()?; - info!( - "HTTPS (h1+h2) listener bound to {}/tcp", - listener.local_addr().unwrap() - ); - loop { - let (stream, addr) = listener.accept().await.context("accepting connection")?; - let generation = state.generation.read().await.clone(); - let Some(tls_acceptor) = generation.tls_acceptor.clone() else { - warn!("HTTPS (h1+h2) listener is missing TLS configuration"); - break; - }; - let state = state.clone(); - tokio::task::spawn(async move { - debug!("connection from {addr}"); - match tls_acceptor.accept(stream).await { - Ok(stream) => { - serve_stream(state, TokioIo::new(stream), addr, true, listen_addr).await - } - Err(e) => warn!("error accepting tls: {e}"), - }; - }); - } - info!( - "HTTPS (h1+h2) listener for {} shutting down", - listener.local_addr().unwrap() - ); - Ok(()) - })) - .await; - listen_futs?; + let listen_addr = listener.local_addr()?; + let endpoint = quinn::Endpoint::new( + quinn::EndpointConfig::default(), + Some(quic_config), + listener.into_std().unwrap(), + quinn::default_runtime().unwrap(), + )?; + state.quic_endpoints.write().await.push(endpoint.clone()); + while let Some(conn) = endpoint.accept().await { + let state = state.clone(); + spawn(serve_stream_h3(conn, state, listen_addr)); + } Ok(()) } @@ -327,43 +346,7 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r } } -async fn serve_h3(state: Arc<State>) -> Result<()> { - let generation = state.generation.read().await.clone(); - let https_config = match &generation.config.https { - Some(n) => n, - None => return Ok(()), - }; - if https_config.disable_h3 { - return Ok(()); - } - let bind_addrs = https_config.bind.clone(); - - try_join_all(bind_addrs.iter().map(|listen_addr| async { - let Some(quic_config) = generation.quic_config.clone() else { - warn!("HTTPS (h3) listener is missing TLS configuration"); - return Ok(()); - }; - let endpoint = quinn::Endpoint::server(quic_config, *listen_addr)?; - state.quic_endpoints.write().await.push(endpoint.clone()); - let listen_addr = *listen_addr; - info!("HTTPS (h3) listener bound to {listen_addr}/udp"); - while let Some(conn) = endpoint.accept().await { - let state = state.clone(); - let generation = generation.clone(); - tokio::spawn(serve_h3_stream(conn, generation, state, listen_addr)); - } - Ok::<_, anyhow::Error>(()) - })) - .await?; - Ok(()) -} - -async fn serve_h3_stream( - conn: quinn::Incoming, - generation: Arc<Generation>, - state: Arc<State>, - listen_addr: SocketAddr, -) { +async fn serve_stream_h3(conn: quinn::Incoming, state: Arc<State>, listen_addr: SocketAddr) { let addr = conn.remote_address(); debug!("h3 connection attempt from {addr}"); let Ok(_sem) = state.l_incoming_h3.try_acquire() else { @@ -383,6 +366,7 @@ async fn serve_h3_stream( Err(e) => return warn!("h3 accept failed {e}"), }; debug!("h3 stream from {addr}"); + let generation = state.generation.read().await; let max_par_requests = Semaphore::new(generation.config.limits.max_requests_per_connnection); loop { match conn.accept().await { @@ -540,7 +524,7 @@ async fn service( } async fn watch_config<R: Future<Output = Result<()>>>( - path: &Path, + path: PathBuf, mut reload: impl FnMut() -> R, ) -> Result<()> { let (tx, mut rx) = channel(1); |