pub mod config; use crate::config::Config; use anyhow::{anyhow, bail, Context, Result}; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, header::{HOST, UPGRADE}, http::{uri::PathAndQuery, HeaderValue}, server::conn::http1, service::service_fn, upgrade::OnUpgrade, Request, Response, StatusCode, Uri, }; use log::{debug, error, info, warn}; use std::{fs::File, io::BufReader, path::Path, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncWrite}, net::{TcpListener, TcpStream}, }; use tokio_rustls::TlsAcceptor; #[derive(Debug, thiserror::Error)] enum ServiceError { #[error("hyper error")] Hyper(hyper::Error), #[error("unknown host")] NoHost, #[error("can't connect to the backend")] CantConnect, } #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init_from_env("LOG"); let config_path = std::env::args().skip(1).next().ok_or(anyhow!( "first argument is expected to be the configuration file" ))?; let config = Arc::new(Config::load(&config_path)?); tokio::select! { x = serve_http(config.clone()) => x.context("serving http")?, x = serve_https(config.clone()) => x.context("serving https")?, }; Ok(()) } async fn serve_http(config: Arc) -> Result<()> { let http_config = match &config.http { Some(n) => n, None => return Ok(()), }; let listener = TcpListener::bind(http_config.bind).await?; info!("serving http"); loop { let (stream, addr) = listener.accept().await.context("accepting connection")?; debug!("connection from {addr}"); serve_stream(config.clone(), stream) } } async fn serve_https(config: Arc) -> Result<()> { let https_config = match &config.https { Some(n) => n, None => return Ok(()), }; let tls_config = { let certs = load_certs(&https_config.tls_cert)?; let key = load_private_key(&https_config.tls_key)?; let mut cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(certs, key)?; cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; Arc::new(cfg) }; let listener = TcpListener::bind(https_config.bind).await?; let tls_acceptor = TlsAcceptor::from(tls_config); info!("serving https"); loop { let (stream, addr) = listener.accept().await.context("accepting connection")?; debug!("connection from {addr}"); match tls_acceptor.accept(stream).await { Ok(stream) => serve_stream(config.clone(), stream), Err(e) => warn!("error accepting tls: {e}"), }; } } pub fn serve_stream( config: Arc, stream: T, ) { tokio::task::spawn(async move { let conn = http1::Builder::new() .serve_connection( stream, service_fn(move |req| { let config = config.clone(); async move { match service(config, req).await { Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok({ let mut resp = Response::new(format!( "the reverse proxy encountered an issue: {error}" )); *resp.status_mut() = StatusCode::BAD_REQUEST; resp } .map(|b| b.map_err(|e| match e {}).boxed())), } } }), ) .with_upgrades(); if let Err(err) = conn.await { warn!("error: {:?}", err); } }); } fn load_certs(path: &Path) -> anyhow::Result> { let mut reader = BufReader::new(File::open(path).context("reading tls certs")?); let certs = rustls_pemfile::certs(&mut reader).context("parsing tls certs")?; Ok(certs.into_iter().map(rustls::Certificate).collect()) } fn load_private_key(path: &Path) -> anyhow::Result { let mut reader = BufReader::new(File::open(path).context("reading tls private key")?); let keys = rustls_pemfile::pkcs8_private_keys(&mut reader).context("parsing tls private key")?; if keys.len() != 1 { bail!("expected a single private key, found {}", keys.len()) } Ok(rustls::PrivateKey(keys[0].clone())) } async fn service( config: Arc, mut req: Request, ) -> Result>, ServiceError> { debug!("<- {:?} {}", req.headers().get(HOST), req.uri()); *req.uri_mut() = Uri::builder() .scheme("http") .authority("backend") .path_and_query( req.uri() .clone() .path_and_query() .cloned() .unwrap_or(PathAndQuery::from_static("/")), ) .build() .unwrap(); let route = config .hosts .get(remove_port( &req.headers() .get(HOST) .and_then(|e| e.to_str().ok()) .map(String::from) .unwrap_or(String::from("")), )) .ok_or(ServiceError::NoHost)?; let do_upgrade = req.headers().contains_key(UPGRADE); let on_upgrade_downstream = req.extensions_mut().remove::(); let mut resp = { let client_stream = TcpStream::connect(&route.backend) .await .map_err(|_| ServiceError::CantConnect)?; let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) .await .map_err(ServiceError::Hyper)?; tokio::task::spawn(async move { if let Err(err) = conn.await { warn!("connection failed: {:?}", err); } }); sender .send_request(req) .await .map_err(ServiceError::Hyper)? }; resp.headers_mut() .insert("server", HeaderValue::from_static("gnix")); if do_upgrade { let on_upgrade_upstream = resp.extensions_mut().remove::(); tokio::task::spawn(async move { debug!("about upgrading connection, sending empty response"); match ( on_upgrade_upstream.unwrap().await, on_upgrade_downstream.unwrap().await, ) { (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => { debug!("upgrade successful"); match tokio::io::copy_bidirectional( &mut upgraded_downstream, &mut upgraded_upstream, ) .await { Ok((from_client, from_server)) => { debug!("proxy socket terminated: {from_server} sent, {from_client} received") } Err(e) => warn!("proxy socket error: {e}"), } } (a, b) => eprintln!("upgrade error: upstream={a:?} downstream={b:?}"), } }); } Ok(resp.map(|b| b.boxed())) } pub fn remove_port(s: &str) -> &str { s.split_once(":").map(|(s, _)| s).unwrap_or(s) }