#![feature(try_trait_v2)] #![feature(exclusive_range_pattern)] #![feature(slice_split_once)] #![feature(iterator_try_collect)] pub mod config; pub mod error; pub mod filters; pub mod helper; #[cfg(feature = "mond")] pub mod reporting; use crate::{ config::{Config, RouteFilter}, filters::{files::serve_files, proxy::proxy_request}, }; use anyhow::{anyhow, Context, Result}; use bytes::Bytes; use config::setup_file_watch; use error::ServiceError; use futures::future::try_join_all; use helper::TokioIo; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, header::{CONTENT_TYPE, HOST, SERVER}, http::HeaderValue, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use log::{debug, error, info, warn}; #[cfg(feature = "mond")] use reporting::Reporting; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use std::{ collections::HashMap, io::BufReader, net::SocketAddr, ops::ControlFlow, path::{Path, PathBuf}, process::exit, str::FromStr, sync::Arc, }; use tokio::{ fs::File, io::BufWriter, net::TcpListener, signal::ctrl_c, sync::{RwLock, Semaphore}, }; use tokio_rustls::TlsAcceptor; pub struct State { pub config: RwLock>, pub access_logs: RwLock>>, pub l_incoming: Semaphore, pub l_outgoing: Semaphore, #[cfg(feature = "mond")] pub reporting: Reporting, } pub struct HostState {} pub type FilterRequest = Request; pub type FilterResponseOut = Option>>; pub type FilterResponse = Option>>; #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init_from_env("LOG"); let Some(config_path) = std::env::args().skip(1).next() else { eprintln!("error: first argument is expected to be the configuration file"); exit(1) }; let config_path = PathBuf::from_str(&config_path) .unwrap() .canonicalize() .unwrap(); let config = match Config::load(&config_path) { Ok(c) => c, Err(e) => { eprintln!("error {e:?}"); exit(1); } }; let state = Arc::new(State { l_incoming: Semaphore::new(config.limits.max_incoming_connections), l_outgoing: Semaphore::new(config.limits.max_outgoing_connections), #[cfg(feature = "mond")] reporting: Reporting::new(&config), config: RwLock::new(Arc::new(config)), access_logs: Default::default(), }); if state.config.read().await.watch_config { setup_file_watch(config_path, state.clone()); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_http(state).await { error!("{e:?}"); exit(1) } }); } { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_https(state).await { error!("{e:?}"); exit(1) } }); } ctrl_c().await.unwrap(); Ok(()) } async fn serve_http(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let http_config = match &config.http { Some(n) => n, None => return Ok(()), }; let listen_futs: Result> = try_join_all(http_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; info!("HTTP listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; debug!("connection from {addr}"); let stream = TokioIo(stream); let state = state.clone(); tokio::spawn(async move { serve_stream(state, stream, addr).await }); } })) .await; info!("serving http"); listen_futs?; Ok(()) } async fn serve_https(state: Arc) -> Result<()> { let config = state.config.read().await.clone(); let https_config = match &config.https { Some(n) => n, None => return Ok(()), }; let tls_config = { let certs = load_certs(&https_config.tls_cert)?; let key = load_private_key(&https_config.tls_key)?; let mut cfg = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key)?; cfg.alpn_protocols = vec![ // b"h2".to_vec(), b"http/1.1".to_vec(), ]; Arc::new(cfg) }; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); let listen_futs: Result> = try_join_all(https_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; info!("HTTPS listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; let state = state.clone(); let tls_acceptor = tls_acceptor.clone(); tokio::task::spawn(async move { debug!("connection from {addr}"); match tls_acceptor.accept(stream).await { Ok(stream) => serve_stream(state, TokioIo(stream), addr).await, Err(e) => warn!("error accepting tls: {e}"), }; }); } })) .await; listen_futs?; Ok(()) } pub async fn serve_stream( state: Arc, stream: T, addr: SocketAddr, ) { if let Ok(_semaphore) = state.l_incoming.try_acquire() { let conn = http1::Builder::new() .serve_connection( stream, service_fn(|req| { let state = state.clone(); async move { let config = state.config.read().await.clone(); match service(state, config, req, addr).await { Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok({ debug!("service error {error:?}"); let mut resp = Response::new(format!( "Sorry, we were unable to process your request: {error}" )); *resp.status_mut() = StatusCode::BAD_REQUEST; resp.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); resp.headers_mut() .insert(SERVER, HeaderValue::from_static("gnix")); resp } .map(|b| b.map_err(|e| match e {}).boxed())), } } }), ) .with_upgrades(); if let Err(err) = conn.await { warn!("error: {:?}", err); } } else { warn!("connection dropped: too many incoming"); } } fn load_certs(path: &Path) -> anyhow::Result>> { let mut reader = BufReader::new(std::fs::File::open(path).context("reading tls certs")?); let certs = rustls_pemfile::certs(&mut reader) .try_collect::>() .context("parsing tls certs")?; Ok(certs) } fn load_private_key(path: &Path) -> anyhow::Result> { let mut reader = BufReader::new(std::fs::File::open(path).context("reading tls private key")?); let keys = rustls_pemfile::private_key(&mut reader).context("parsing tls private key")?; Ok(keys.ok_or(anyhow!("no private key found"))?) } async fn service( state: Arc, config: Arc, req: Request, addr: SocketAddr, ) -> Result>, ServiceError> { debug!("{addr} ~> {:?} {}", req.headers().get(HOST), req.uri()); #[cfg(feature = "mond")] state.reporting.request_in.inc(); let host = req .headers() .get(HOST) .and_then(|e| e.to_str().ok()) .map(String::from) .unwrap_or(String::from("")); let host = remove_port(&host); let route = config.hosts.get(host).ok_or(ServiceError::NoHost)?; #[cfg(feature = "mond")] state.reporting.hosts.get(host).unwrap().requests_in.inc(); // TODO this code is horrible let mut req = Some(req); let mut resp = None; for filter in &route.0 { let cf = match filter { RouteFilter::Proxy { backend } => { resp = Some( proxy_request( &state, req.take().ok_or(ServiceError::RequestTaken)?, addr, backend, ) .await?, ); ControlFlow::Continue(()) } RouteFilter::Files { config } => { resp = Some( serve_files(req.as_ref().ok_or(ServiceError::RequestTaken)?, config).await?, ); ControlFlow::Continue(()) } RouteFilter::HttpBasicAuth { config } => filters::auth::http_basic( config, req.as_ref().ok_or(ServiceError::RequestTaken)?, &mut resp, )?, RouteFilter::AccessLog { config } => { filters::accesslog::access_log( &state, host, addr, config, req.as_ref().ok_or(ServiceError::RequestTaken)?, ) .await? } }; match cf { ControlFlow::Continue(_) => continue, ControlFlow::Break(_) => break, } } let mut resp = resp.ok_or(ServiceError::NoResponse)?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( SERVER, HeaderValue::from_str(&if let Some(o) = server_header { format!("{} via gnix", o.to_str().ok().unwrap_or("invalid")) } else { format!("gnix") }) .unwrap(), ); return Ok(resp); } pub fn remove_port(s: &str) -> &str { s.split_once(":").map(|(s, _)| s).unwrap_or(s) }