#![feature(try_trait_v2)] #![feature(slice_split_once)] #![feature(iterator_try_collect)] pub mod certs; pub mod config; pub mod error; pub mod helper; pub mod modules; use aes_gcm_siv::{aead::generic_array::GenericArray, Aes256GcmSiv, KeyInit}; use anyhow::{Context, Result}; use certs::CertPool; use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; use futures::future::try_join_all; use helper::TokioIo; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, header::{CONTENT_TYPE, HOST, SERVER}, http::HeaderValue, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioExecutor; use log::{debug, error, info, warn, LevelFilter}; use modules::{NodeContext, MODULES}; use std::{ collections::HashMap, net::SocketAddr, path::PathBuf, process::exit, str::FromStr, sync::Arc, }; use tokio::{ fs::File, io::BufWriter, net::TcpListener, signal::ctrl_c, sync::{RwLock, Semaphore}, }; use tokio_rustls::TlsAcceptor; pub struct State { pub crypto_key: Aes256GcmSiv, pub config: RwLock>, pub access_logs: RwLock>>, pub l_incoming: Semaphore, pub l_outgoing: Semaphore, } #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::Builder::new() .filter_level(LevelFilter::Info) .parse_env("LOG") .init(); rustls::crypto::ring::default_provider() .install_default() .unwrap(); NODE_KINDS .write() .unwrap() .extend(MODULES.iter().map(|m| (m.name().to_owned(), *m))); let Some(config_path) = std::env::args().nth(1) else { eprintln!("error: first argument is expected to be the configuration file"); exit(1) }; let config_path = PathBuf::from_str(&config_path) .unwrap() .canonicalize() .unwrap(); let config = match Config::load(&config_path) { Ok(c) => c, Err(e) => { eprintln!("error {e:?}"); exit(1); } }; let state = Arc::new(State { crypto_key: aes_gcm_siv::Aes256GcmSiv::new(GenericArray::from_slice(&config.private_key)), l_incoming: Semaphore::new(config.limits.max_incoming_connections), l_outgoing: Semaphore::new(config.limits.max_outgoing_connections), config: RwLock::new(Arc::new(config)), access_logs: Default::default(), }); if state.config.read().await.watch_config { setup_file_watch(config_path, state.clone()); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_http(state).await { error!("{e:?}"); exit(1) } }); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_https(state).await { error!("{e:?}"); exit(1) } }); } ctrl_c().await.unwrap(); Ok(()) } async fn serve_http(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let http_config = match &config.http { Some(n) => n, None => return Ok(()), }; let listen_futs: Result> = try_join_all(http_config.bind.iter().map(|e| async { let l = TcpListener::bind(*e).await?; info!("HTTP listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; debug!("connection from {addr}"); let stream = TokioIo(stream); let state = state.clone(); tokio::spawn(async move { serve_stream(state, stream, addr).await }); } })) .await; info!("serving http"); listen_futs?; Ok(()) } async fn serve_https(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let https_config = match &config.https { Some(n) => n, None => return Ok(()), }; let tls_config = { let certs = CertPool::load(&https_config.cert_path, https_config.cert_fallback.clone())?; let mut cfg = rustls::ServerConfig::builder() .with_no_client_auth() .with_cert_resolver(Arc::new(certs)); cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; Arc::new(cfg) }; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); let listen_futs: Result> = try_join_all(https_config.bind.iter().map(|e| async { let l = TcpListener::bind(*e).await?; info!("HTTPS listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; let state = state.clone(); let tls_acceptor = tls_acceptor.clone(); tokio::task::spawn(async move { debug!("connection from {addr}"); match tls_acceptor.accept(stream).await { Ok(stream) => serve_stream(state, TokioIo(stream), addr).await, Err(e) => warn!("error accepting tls: {e}"), }; }); } })) .await; listen_futs?; Ok(()) } pub async fn serve_stream( state: Arc, stream: T, addr: SocketAddr, ) { if let Ok(_semaphore) = state.l_incoming.try_acquire() { let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); let conn = builder.serve_connection_with_upgrades( stream, service_fn(|req| { let state = state.clone(); async move { let config = state.config.read().await.clone(); match service(state, config, req, addr).await { Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok({ warn!("service error {addr} {error:?}"); let mut resp = Response::new(format!( "Sorry, we were unable to process your request: {error}" )); *resp.status_mut() = StatusCode::BAD_REQUEST; resp.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); resp.headers_mut() .insert(SERVER, HeaderValue::from_static("gnix")); resp } .map(|b| b.map_err(|e| match e {}).boxed())), } } }), ); if let Err(err) = conn.await { warn!("error: {:?}", err); } } else { warn!("connection dropped: too many incoming"); } } async fn service( state: Arc, config: Arc, mut request: Request, addr: SocketAddr, ) -> Result>, ServiceError> { // copy uri authority used in HTTP/2 to Host header field if let Some(host) = request.uri().authority().map(|a| a.host()) { let host = HeaderValue::from_str(host).map_err(|_| ServiceError::InvalidHeader)?; request.headers_mut().insert(HOST, host); } debug!( "{addr} ~> {:?} {}", request.headers().get(HOST), request.uri() ); let mut context = NodeContext { addr, state }; let mut resp = config.handler.handle(&mut context, request).await?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( SERVER, if let Some(o) = server_header { HeaderValue::from_str(&format!( "{} via gnix", o.to_str().ok().unwrap_or("invalid") )) .unwrap() } else { HeaderValue::from_static("gnix") }, ); Ok(resp) }