diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 180 |
1 files changed, 154 insertions, 26 deletions
diff --git a/src/main.rs b/src/main.rs index d8e1407..de43a6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,38 @@ -#![feature(try_trait_v2)] -#![feature(slice_split_once)] -#![feature(iterator_try_collect)] -#![feature(path_add_extension)] +#![feature( + try_trait_v2, + slice_split_once, + iterator_try_collect, + path_add_extension, + never_type +)] pub mod certs; pub mod config; pub mod error; +pub mod h3_support; pub mod modules; use aes_gcm_siv::{aead::generic_array::GenericArray, Aes256GcmSiv, KeyInit}; use anyhow::{Context, Result}; +use bytes::Bytes; use certs::CertPool; use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; use futures::future::try_join_all; +use h3::error::ErrorLevel; +use h3_support::H3RequestBody; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, header::{CONTENT_TYPE, HOST, SERVER}, http::HeaderValue, service::service_fn, - Request, Response, StatusCode, Uri, + Request, Response, Uri, }; use hyper_util::rt::{TokioExecutor, TokioIo}; use log::{debug, error, info, warn, LevelFilter}; use modules::{NodeContext, MODULES}; +use quinn::crypto::rustls::QuicServerConfig; use std::{ collections::HashMap, net::SocketAddr, path::PathBuf, process::exit, str::FromStr, sync::Arc, }; @@ -107,6 +115,15 @@ async fn main() -> anyhow::Result<()> { } }); } + { + let state = state.clone(); + tokio::spawn(async move { + if let Err(e) = serve_h3(state).await { + error!("{e:?}"); + exit(1) + } + }); + } ctrl_c().await.unwrap(); Ok(()) @@ -180,6 +197,116 @@ async fn serve_https(state: Arc<State>) -> Result<()> { Ok(()) } +async fn serve_h3(state: Arc<State>) -> Result<()> { + let config = state.config.read().await.clone(); + let https_config = match &config.https { + Some(n) => n, + None => return Ok(()), + }; + let bind_addrs = https_config.bind.clone(); + let certs = CertPool::load(&https_config.cert_path, https_config.cert_fallback.clone())?; + let mut cfg = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(certs)); + cfg.alpn_protocols = vec![b"h3".to_vec()]; + let cfg = Arc::new(QuicServerConfig::try_from(cfg)?); + + try_join_all(bind_addrs.iter().map(|listen_addr| { + async { + let cfg = quinn::ServerConfig::with_crypto(cfg.clone()); + let endpoint = quinn::Endpoint::server(cfg, *listen_addr)?; + let listen_addr = *listen_addr; + info!("HTTPS (HTTP/3) listener bound to {listen_addr}"); + while let Some(conn) = endpoint.accept().await { + let state = state.clone(); + tokio::spawn(async move { + let addr = conn.remote_address(); // TODO wait for validatation (or not?) + let Ok(_sem) = state.l_incoming.try_acquire() else { + return conn.refuse(); + }; + let conn = match conn.await { + Ok(conn) => conn, + Err(e) => return warn!("quic connection failed: {e}"), + }; + let mut conn = match h3::server::Connection::<_, Bytes>::new( + h3_quinn::Connection::new(conn), + ) + .await + { + Ok(conn) => conn, + Err(e) => return warn!("h3 accept failed {e}"), + }; + loop { + match conn.accept().await { + Ok(Some((req, stream))) => { + let (mut send, recv) = stream.split(); + let req = req.map(|()| H3RequestBody(recv).boxed()); + + let resp = match service( + state.clone(), + req, + addr, + true, + listen_addr, + ) + .await + { + Ok(resp) => resp, + Err(error) => error_response(addr, error), + }; + + let (parts, mut body) = resp.into_parts(); + let resp = Response::from_parts(parts, ()); + + if let Err(e) = send.send_response(resp).await { + warn!("h3 response send error: {e}"); + return; + }; + while let Some(frame) = body.frame().await { + match frame { + Ok(frame) => { + if frame.is_data() { + let data = frame.into_data().unwrap(); + if let Err(e) = send.send_data(data).await { + warn!("h3 body send error: {e}"); + return; + } + } else if frame.is_trailers() { + let trailers = frame.into_trailers().unwrap(); + if let Err(e) = send.send_trailers(trailers).await { + warn!("h3 trailers send error: {e}"); + return; + } + } + } + Err(_) => todo!(), + } + } + if let Err(e) = send.finish().await { + warn!("h3 response finish error: {e}"); + return; + } + } + Ok(None) => break, + Err(e) => { + warn!("h3 error: {e}"); + match e.get_error_level() { + ErrorLevel::ConnectionError => break, + ErrorLevel::StreamError => continue, + } + } + } + } + drop(_sem); + }); + } + Ok::<_, anyhow::Error>(()) + } + })) + .await?; + Ok(()) +} + pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::rt::Write>( state: Arc<State>, stream: T, @@ -194,23 +321,11 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r service_fn(|req| { let state = state.clone(); async move { - let config = state.config.read().await.clone(); - match service(state, config, req, addr, secure, listen_addr).await { + let req = req.map(|body: Incoming| body.map_err(ServiceError::Hyper).boxed()); + match service(state, req, addr, secure, listen_addr).await { Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), - Err(error) => Ok({ - warn!("service error {addr} {error:?}"); - let mut resp = Response::new(format!( - "Sorry, we were unable to process your request: {error}" - )); - *resp.status_mut() = error.status_code(); - resp.headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); - resp.headers_mut() - .insert(SERVER, HeaderValue::from_static("gnix")); - resp - } - .map(|b| b.map_err(|e| match e {}).boxed())), + Err(error) => Ok(error_response(addr, error)), } } }), @@ -223,14 +338,30 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r } } +fn error_response(addr: SocketAddr, error: ServiceError) -> Response<BoxBody<Bytes, ServiceError>> { + { + warn!("service error {addr} {error:?}"); + let mut resp = Response::new(format!( + "Sorry, we were unable to process your request: {error}" + )); + *resp.status_mut() = error.status_code(); + resp.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + resp.headers_mut() + .insert(SERVER, HeaderValue::from_static("gnix")); + resp + } + .map(|b| b.map_err(|e| match e {}).boxed()) +} + async fn service( state: Arc<State>, - config: Arc<Config>, - mut request: Request<Incoming>, + mut request: Request<BoxBody<Bytes, ServiceError>>, addr: SocketAddr, secure: bool, listen_addr: SocketAddr, ) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> { + let config = state.config.read().await.clone(); // move uri authority used in HTTP/2 to Host header field { let uri = request.uri_mut(); @@ -257,10 +388,7 @@ async fn service( secure, listen_addr, }; - let mut resp = config - .handler - .handle(&mut context, request.map(|body| body.boxed())) - .await?; + let mut resp = config.handler.handle(&mut context, request).await?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( |