#![feature(try_trait_v2)] #![feature(exclusive_range_pattern)] #![feature(slice_split_once)] #![feature(iterator_try_collect)] pub mod config; pub mod error; pub mod modules; pub mod helper; use aes_gcm_siv::{aead::generic_array::GenericArray, Aes256GcmSiv, KeyInit}; use anyhow::{anyhow, Context, Result}; use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; use modules::{NodeContext, MODULES}; use futures::future::try_join_all; use helper::TokioIo; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, header::{CONTENT_TYPE, HOST, SERVER}, http::HeaderValue, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use log::{debug, error, info, warn, LevelFilter}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use std::{ collections::HashMap, io::BufReader, net::SocketAddr, path::{Path, PathBuf}, process::exit, str::FromStr, sync::Arc, }; use tokio::{ fs::File, io::BufWriter, net::TcpListener, signal::ctrl_c, sync::{RwLock, Semaphore}, }; use tokio_rustls::TlsAcceptor; pub struct State { pub crypto_key: Aes256GcmSiv, pub config: RwLock>, pub access_logs: RwLock>>, pub l_incoming: Semaphore, pub l_outgoing: Semaphore, } #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::Builder::new() .filter_level(LevelFilter::Info) .parse_env("LOG") .init(); NODE_KINDS .write() .unwrap() .extend(MODULES.iter().map(|m| (m.name().to_owned(), *m))); let Some(config_path) = std::env::args().skip(1).next() else { eprintln!("error: first argument is expected to be the configuration file"); exit(1) }; let config_path = PathBuf::from_str(&config_path) .unwrap() .canonicalize() .unwrap(); let config = match Config::load(&config_path) { Ok(c) => c, Err(e) => { eprintln!("error {e:?}"); exit(1); } }; let state = Arc::new(State { crypto_key: aes_gcm_siv::Aes256GcmSiv::new(GenericArray::from_slice(&config.private_key)), l_incoming: Semaphore::new(config.limits.max_incoming_connections), l_outgoing: Semaphore::new(config.limits.max_outgoing_connections), #[cfg(feature = "mond")] reporting: Reporting::new(&config), config: RwLock::new(Arc::new(config)), access_logs: Default::default(), }); if state.config.read().await.watch_config { setup_file_watch(config_path, state.clone()); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_http(state).await { error!("{e:?}"); exit(1) } }); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_https(state).await { error!("{e:?}"); exit(1) } }); } ctrl_c().await.unwrap(); Ok(()) } async fn serve_http(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let http_config = match &config.http { Some(n) => n, None => return Ok(()), }; let listen_futs: Result> = try_join_all(http_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; info!("HTTP listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; debug!("connection from {addr}"); let stream = TokioIo(stream); let state = state.clone(); tokio::spawn(async move { serve_stream(state, stream, addr).await }); } })) .await; info!("serving http"); listen_futs?; Ok(()) } async fn serve_https(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let https_config = match &config.https { Some(n) => n, None => return Ok(()), }; let tls_config = { let certs = load_certs(&https_config.tls_cert)?; let key = load_private_key(&https_config.tls_key)?; let mut cfg = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)?; cfg.alpn_protocols = vec![ // b"h2".to_vec(), b"http/1.1".to_vec(), ]; Arc::new(cfg) }; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); let listen_futs: Result> = try_join_all(https_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; info!("HTTPS listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; let state = state.clone(); let tls_acceptor = tls_acceptor.clone(); tokio::task::spawn(async move { debug!("connection from {addr}"); match tls_acceptor.accept(stream).await { Ok(stream) => serve_stream(state, TokioIo(stream), addr).await, Err(e) => warn!("error accepting tls: {e}"), }; }); } })) .await; listen_futs?; Ok(()) } pub async fn serve_stream( state: Arc, stream: T, addr: SocketAddr, ) { if let Ok(_semaphore) = state.l_incoming.try_acquire() { let conn = http1::Builder::new() .serve_connection( stream, service_fn(|req| { let state = state.clone(); async move { let config = state.config.read().await.clone(); match service(state, config, req, addr).await { Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok({ warn!("service error {addr} {error:?}"); let mut resp = Response::new(format!( "Sorry, we were unable to process your request: {error}" )); *resp.status_mut() = StatusCode::BAD_REQUEST; resp.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); resp.headers_mut() .insert(SERVER, HeaderValue::from_static("gnix")); resp } .map(|b| b.map_err(|e| match e {}).boxed())), } } }), ) .with_upgrades(); if let Err(err) = conn.await { warn!("error: {:?}", err); } } else { warn!("connection dropped: too many incoming"); } } fn load_certs(path: &Path) -> anyhow::Result>> { let mut reader = BufReader::new(std::fs::File::open(path).context("reading tls certs")?); let certs = rustls_pemfile::certs(&mut reader) .try_collect::>() .context("parsing tls certs")?; Ok(certs) } fn load_private_key(path: &Path) -> anyhow::Result> { let mut reader = BufReader::new(std::fs::File::open(path).context("reading tls private key")?); let keys = rustls_pemfile::private_key(&mut reader).context("parsing tls private key")?; Ok(keys.ok_or(anyhow!("no private key found"))?) } async fn service( state: Arc, config: Arc, request: Request, addr: SocketAddr, ) -> Result>, ServiceError> { debug!( "{addr} ~> {:?} {}", request.headers().get(HOST), request.uri() ); let mut context = NodeContext { addr, state }; let mut resp = config.handler.handle(&mut context, request).await?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( SERVER, HeaderValue::from_str(&if let Some(o) = server_header { format!("{} via gnix", o.to_str().ok().unwrap_or("invalid")) } else { format!("gnix") }) .unwrap(), ); return Ok(resp); }