/* This file is part of gnix (https://codeberg.org/metamuffin/gnix) which is licensed under the GNU Affero General Public License (version 3); see /COPYING. Copyright (C) 2025 metamuffin */ #![feature( try_trait_v2, slice_split_once, iterator_try_collect, path_add_extension, never_type, string_from_utf8_lossy_owned )] pub mod certs; pub mod config; pub mod error; pub mod h3_support; pub mod modules; use aes_gcm_siv::{aead::generic_array::GenericArray, Aes256GcmSiv, KeyInit}; use anyhow::{Context, Result}; use bytes::Bytes; use certs::CertPool; use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; use futures::future::try_join_all; use h3::{error::ErrorLevel, server::RequestStream}; use h3_quinn::SendStream; use h3_support::H3RequestBody; use http::header::{CONTENT_LENGTH, TRANSFER_ENCODING}; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, header::{CONTENT_TYPE, HOST, SERVER}, http::HeaderValue, service::service_fn, Request, Response, Uri, }; use hyper_util::rt::{TokioExecutor, TokioIo}; use log::{debug, error, info, warn, LevelFilter}; use modules::{NodeContext, MODULES}; use quinn::crypto::rustls::QuicServerConfig; use std::{ collections::HashMap, net::{IpAddr, SocketAddr}, path::PathBuf, process::exit, str::FromStr, sync::Arc, }; use tokio::{ fs::File, io::BufWriter, net::TcpListener, signal::ctrl_c, spawn, sync::{RwLock, Semaphore}, }; use tokio_rustls::TlsAcceptor; pub struct State { pub crypto_key: Aes256GcmSiv, pub config: RwLock>, pub access_logs: RwLock>>, pub l_incoming: Semaphore, pub l_outgoing: Semaphore, pub l_incoming_h3: Semaphore, } #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::Builder::new() .filter_level(LevelFilter::Info) .parse_env("LOG") .init(); rustls::crypto::ring::default_provider() .install_default() .unwrap(); NODE_KINDS .write() .unwrap() .extend(MODULES.iter().map(|m| (m.name().to_owned(), *m))); let Some(config_path) = std::env::args().nth(1) else { eprintln!("error: first argument is expected to be the configuration file"); exit(1) }; let config_path = PathBuf::from_str(&config_path) .unwrap() .canonicalize() .unwrap(); let config = match Config::load(&config_path) { Ok(c) => c, Err(e) => { eprintln!("error {e:?}"); exit(1); } }; let state = Arc::new(State { crypto_key: aes_gcm_siv::Aes256GcmSiv::new(GenericArray::from_slice(&config.private_key)), l_incoming: Semaphore::new(config.limits.max_incoming_connections), l_incoming_h3: Semaphore::new(config.limits.max_incoming_connections_h3), l_outgoing: Semaphore::new(config.limits.max_outgoing_connections), config: RwLock::new(Arc::new(config)), access_logs: Default::default(), }); if state.config.read().await.watch_config { setup_file_watch(config_path, state.clone()); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_http(state).await { error!("{e:?}"); exit(1) } }); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_https(state).await { error!("{e:?}"); exit(1) } }); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_h3(state).await { error!("{e:?}"); exit(1) } }); } ctrl_c().await.unwrap(); Ok(()) } async fn serve_http(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let http_config = match &config.http { Some(n) => n, None => return Ok(()), }; let listen_futs: Result> = try_join_all(http_config.bind.iter().map(|e| async { let l = TcpListener::bind(*e).await?; let listen_addr = l.local_addr()?; info!("HTTP listener bound to {}/tcp", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; debug!("connection from {addr}"); let stream = TokioIo::new(stream); let state = state.clone(); tokio::spawn( async move { serve_stream(state, stream, addr, false, listen_addr).await }, ); } })) .await; listen_futs?; Ok(()) } async fn serve_https(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let https_config = match &config.https { Some(n) => n, None => return Ok(()), }; let tls_config = { let certs = CertPool::load(&https_config.cert_path, https_config.cert_fallback.clone())?; let mut cfg = rustls::ServerConfig::builder() .with_no_client_auth() .with_cert_resolver(Arc::new(certs)); if !https_config.disable_h1 { cfg.alpn_protocols.push(b"http/1.1".to_vec()); } if !https_config.disable_h2 { cfg.alpn_protocols.push(b"h2".to_vec()); } Arc::new(cfg) }; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); let listen_futs: Result> = try_join_all(https_config.bind.iter().map(|e| async { let l = TcpListener::bind(*e).await?; let listen_addr = l.local_addr()?; info!( "HTTPS (h1+h2) listener bound to {}/tcp", l.local_addr().unwrap() ); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; let state = state.clone(); let tls_acceptor = tls_acceptor.clone(); tokio::task::spawn(async move { debug!("connection from {addr}"); match tls_acceptor.accept(stream).await { Ok(stream) => { serve_stream(state, TokioIo::new(stream), addr, true, listen_addr).await } Err(e) => warn!("error accepting tls: {e}"), }; }); } })) .await; listen_futs?; Ok(()) } pub async fn serve_stream( state: Arc, stream: T, addr: SocketAddr, secure: bool, listen_addr: SocketAddr, ) { if let Ok(_semaphore) = state.l_incoming.try_acquire() { let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); let conn = builder.serve_connection_with_upgrades( stream, service_fn(|req| { let state = state.clone(); async move { let req = req.map(|body: Incoming| body.map_err(ServiceError::Hyper).boxed()); match service(state, req, addr, secure, listen_addr).await { Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok(error_response(addr, error)), } } }), ); if let Err(err) = conn.await { warn!("error: {:?}", err); } } else { warn!("connection dropped: too many incoming"); } } async fn serve_h3(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let https_config = match &config.https { Some(n) => n, None => return Ok(()), }; if https_config.disable_h3 { return Ok(()); } let bind_addrs = https_config.bind.clone(); let certs = CertPool::load(&https_config.cert_path, https_config.cert_fallback.clone())?; let mut cfg = rustls::ServerConfig::builder() .with_no_client_auth() .with_cert_resolver(Arc::new(certs)); cfg.alpn_protocols = vec![b"h3".to_vec()]; let cfg = Arc::new(QuicServerConfig::try_from(cfg)?); try_join_all(bind_addrs.iter().map(|listen_addr| async { let cfg = quinn::ServerConfig::with_crypto(cfg.clone()); let endpoint = quinn::Endpoint::server(cfg, *listen_addr)?; let listen_addr = *listen_addr; info!("HTTPS (h3) listener bound to {listen_addr}/udp"); while let Some(conn) = endpoint.accept().await { let state = state.clone(); let config = config.clone(); tokio::spawn(serve_h3_stream(conn, config, state, listen_addr)); } Ok::<_, anyhow::Error>(()) })) .await?; Ok(()) } async fn serve_h3_stream( conn: quinn::Incoming, config: Arc, state: Arc, listen_addr: SocketAddr, ) { let addr = conn.remote_address(); // TODO wait for validatation (or not?) debug!("h3 connection attempt from {addr}"); // TODO move outside of spawn? let Ok(_sem) = state.l_incoming_h3.try_acquire() else { return conn.refuse(); }; let conn = match conn.accept() { Ok(conn) => conn, Err(e) => return warn!("quic accep failed: {e}"), }; let conn = match conn.await { Ok(conn) => conn, Err(e) => return warn!("quic connection failed: {e}"), }; let mut conn = match h3::server::Connection::<_, Bytes>::new(h3_quinn::Connection::new(conn)).await { Ok(conn) => conn, Err(e) => return warn!("h3 accept failed {e}"), }; debug!("h3 stream from {addr}"); let max_par_requests = Semaphore::new(config.limits.max_requests_per_connnection); loop { match conn.accept().await { Ok(Some((req, stream))) => { let Ok(_sem_req) = max_par_requests.acquire().await else { warn!("h3 par request semasphore closed"); return; }; let state = state.clone(); spawn(async move { let (mut send, recv) = stream.split(); let req = req.map(|()| H3RequestBody(recv).boxed()); let resp = service(state.clone(), req, addr, true, listen_addr) .await .unwrap_or_else(|error| error_response(addr, error)); send_h3_response(resp, &mut send).await; }); drop(_sem_req) } Ok(None) => break, Err(e) => match e.get_error_level() { ErrorLevel::ConnectionError => break, ErrorLevel::StreamError => continue, }, } } drop(_sem); } async fn send_h3_response( resp: Response>, send: &mut RequestStream, Bytes>, ) { let (parts, mut body) = resp.into_parts(); let mut resp = Response::from_parts(parts, ()); resp.headers_mut().remove(TRANSFER_ENCODING); // TODO allow "trailers" options resp.headers_mut().remove(CONTENT_LENGTH); if let Err(e) = send.send_response(resp).await { debug!("h3 response send error: {e}"); return; }; while let Some(frame) = body.frame().await { match frame { Ok(frame) => { if frame.is_data() { let data = frame.into_data().unwrap(); if let Err(e) = send.send_data(data).await { debug!("h3 body send error: {e}"); return; } } else if frame.is_trailers() { let trailers = frame.into_trailers().unwrap(); if let Err(e) = send.send_trailers(trailers).await { debug!("h3 trailers send error: {e}"); return; } } } Err(_) => todo!(), } } if let Err(e) = send.finish().await { debug!("h3 response finish error: {e}"); return; } } fn error_response(addr: SocketAddr, error: ServiceError) -> Response> { { warn!("service error {addr} {error:?}"); let mut resp = Response::new(format!( "Sorry, we were unable to process your request: {error}" )); *resp.status_mut() = error.status_code(); resp.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); resp.headers_mut() .insert(SERVER, HeaderValue::from_static("gnix")); resp } .map(|b| b.map_err(|e| match e {}).boxed()) } async fn service( state: Arc, mut request: Request>, mut addr: SocketAddr, secure: bool, listen_addr: SocketAddr, ) -> Result>, ServiceError> { let config = state.config.read().await.clone(); // move uri authority used in HTTP/2 to Host header field { let uri = request.uri_mut(); if let Some(authority) = uri.authority() { let host = HeaderValue::from_str(authority.host()).map_err(|_| ServiceError::InvalidUri)?; let mut new_uri = http::uri::Parts::default(); new_uri.path_and_query = uri.path_and_query().cloned(); *uri = Uri::from_parts(new_uri).map_err(|_| ServiceError::InvalidUri)?; request.headers_mut().insert(HOST, host); } } if config.source_ip_from_header { if let Some(x) = request.headers_mut().remove("x-real-ip") { addr = SocketAddr::new( IpAddr::from_str(x.to_str()?).map_err(|_| ServiceError::InvalidHeader)?, 0, ); } else { return Err(ServiceError::XRealIPMissing); } } debug!( "{addr} ~> {:?} {}", request.headers().get(HOST), request.uri() ); let mut context = NodeContext { addr, state, secure, listen_addr, }; let mut resp = config.handler.handle(&mut context, request).await?; if !config.disable_server_header { let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( SERVER, if let Some(o) = server_header { HeaderValue::from_str(&format!( "{} via gnix", o.to_str().ok().unwrap_or("invalid") )) .unwrap() } else { HeaderValue::from_static("gnix") }, ); } Ok(resp) }