diff options
author | metamuffin <metamuffin@disroot.org> | 2025-04-13 14:11:22 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2025-04-13 14:11:22 +0200 |
commit | 6e62b64a32bad5a28d3a352bc2fba6a495d65d71 (patch) | |
tree | f5dd30a1493072b86e17100ea3da5f4e53a12538 | |
parent | 5a2251390b0dc89f6383a054f6f596a416c91e3e (diff) | |
download | gnix-6e62b64a32bad5a28d3a352bc2fba6a495d65d71.tar gnix-6e62b64a32bad5a28d3a352bc2fba6a495d65d71.tar.bz2 gnix-6e62b64a32bad5a28d3a352bc2fba6a495d65d71.tar.zst |
fix parallel h3 request processing
-rw-r--r-- | src/config.rs | 2 | ||||
-rw-r--r-- | src/main.rs | 190 |
2 files changed, 108 insertions, 84 deletions
diff --git a/src/config.rs b/src/config.rs index 7b46944..112e86f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -57,6 +57,7 @@ pub struct Limits { pub max_incoming_connections: usize, pub max_outgoing_connections: usize, pub max_incoming_connections_h3: usize, + pub max_requests_per_connnection: usize, } #[derive(Debug, Serialize, Deserialize)] @@ -206,6 +207,7 @@ impl Default for Limits { Self { max_incoming_connections: 512, max_incoming_connections_h3: 4096, + max_requests_per_connnection: 16, max_outgoing_connections: 256, } } diff --git a/src/main.rs b/src/main.rs index f40f716..147a329 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,8 @@ use certs::CertPool; use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; use futures::future::try_join_all; -use h3::error::ErrorLevel; +use h3::{error::ErrorLevel, server::RequestStream}; +use h3_quinn::SendStream; use h3_support::H3RequestBody; use http::header::{CONTENT_LENGTH, TRANSFER_ENCODING}; use http_body_util::{combinators::BoxBody, BodyExt}; @@ -53,6 +54,7 @@ use tokio::{ io::BufWriter, net::TcpListener, signal::ctrl_c, + spawn, sync::{RwLock, Semaphore}, }; use tokio_rustls::TlsAcceptor; @@ -264,97 +266,117 @@ async fn serve_h3(state: Arc<State>) -> Result<()> { cfg.alpn_protocols = vec![b"h3".to_vec()]; let cfg = Arc::new(QuicServerConfig::try_from(cfg)?); - try_join_all(bind_addrs.iter().map(|listen_addr| { - async { - let cfg = quinn::ServerConfig::with_crypto(cfg.clone()); - let endpoint = quinn::Endpoint::server(cfg, *listen_addr)?; - let listen_addr = *listen_addr; - info!("HTTPS (h3) listener bound to {listen_addr}/udp"); - while let Some(conn) = endpoint.accept().await { + try_join_all(bind_addrs.iter().map(|listen_addr| async { + let cfg = quinn::ServerConfig::with_crypto(cfg.clone()); + let endpoint = quinn::Endpoint::server(cfg, *listen_addr)?; + let listen_addr = *listen_addr; + info!("HTTPS (h3) listener bound to {listen_addr}/udp"); + while let Some(conn) = endpoint.accept().await { + let state = state.clone(); + let config = config.clone(); + tokio::spawn(serve_h3_stream(conn, config, state, listen_addr)); + } + Ok::<_, anyhow::Error>(()) + })) + .await?; + Ok(()) +} + +async fn serve_h3_stream( + conn: quinn::Incoming, + config: Arc<Config>, + state: Arc<State>, + listen_addr: SocketAddr, +) { + let addr = conn.remote_address(); // TODO wait for validatation (or not?) + debug!("h3 connection attempt from {addr}"); + // TODO move outside of spawn? + let Ok(_sem) = state.l_incoming_h3.try_acquire() else { + return conn.refuse(); + }; + let conn = match conn.accept() { + Ok(conn) => conn, + Err(e) => return warn!("quic accep failed: {e}"), + }; + let conn = match conn.await { + Ok(conn) => conn, + Err(e) => return warn!("quic connection failed: {e}"), + }; + let mut conn = + match h3::server::Connection::<_, Bytes>::new(h3_quinn::Connection::new(conn)).await { + Ok(conn) => conn, + Err(e) => return warn!("h3 accept failed {e}"), + }; + debug!("h3 stream from {addr}"); + let max_par_requests = Semaphore::new(config.limits.max_requests_per_connnection); + loop { + match conn.accept().await { + Ok(Some((req, stream))) => { + let Ok(_sem_req) = max_par_requests.acquire().await else { + warn!("h3 par request semasphore closed"); + return; + }; let state = state.clone(); - tokio::spawn(async move { - let addr = conn.remote_address(); // TODO wait for validatation (or not?) - debug!("h3 connection attempt from {addr}"); - let Ok(_sem) = state.l_incoming_h3.try_acquire() else { - return conn.refuse(); - }; - let conn = match conn.accept() { - Ok(conn) => conn, - Err(e) => return warn!("quic accep failed: {e}"), - }; - let conn = match conn.await { - Ok(conn) => conn, - Err(e) => return warn!("quic connection failed: {e}"), - }; - let mut conn = match h3::server::Connection::<_, Bytes>::new( - h3_quinn::Connection::new(conn), - ) - .await - { - Ok(conn) => conn, - Err(e) => return warn!("h3 accept failed {e}"), - }; - debug!("h3 stream from {addr}"); - loop { - match conn.accept().await { - Ok(Some((req, stream))) => { - let (mut send, recv) = stream.split(); - let req = req.map(|()| H3RequestBody(recv).boxed()); + spawn(async move { + let (mut send, recv) = stream.split(); + let req = req.map(|()| H3RequestBody(recv).boxed()); - let resp = service(state.clone(), req, addr, true, listen_addr) - .await - .unwrap_or_else(|error| error_response(addr, error)); + let resp = service(state.clone(), req, addr, true, listen_addr) + .await + .unwrap_or_else(|error| error_response(addr, error)); - let (parts, mut body) = resp.into_parts(); - let mut resp = Response::from_parts(parts, ()); + send_h3_response(resp, &mut send).await; + }); + drop(_sem_req) + } + Ok(None) => break, + Err(e) => match e.get_error_level() { + ErrorLevel::ConnectionError => break, + ErrorLevel::StreamError => continue, + }, + } + } + drop(_sem); +} - resp.headers_mut().remove(TRANSFER_ENCODING); // TODO allow "trailers" options - resp.headers_mut().remove(CONTENT_LENGTH); +async fn send_h3_response( + resp: Response<BoxBody<Bytes, ServiceError>>, + send: &mut RequestStream<SendStream<Bytes>, Bytes>, +) { + let (parts, mut body) = resp.into_parts(); + let mut resp = Response::from_parts(parts, ()); + + resp.headers_mut().remove(TRANSFER_ENCODING); // TODO allow "trailers" options + resp.headers_mut().remove(CONTENT_LENGTH); - if let Err(e) = send.send_response(resp).await { - debug!("h3 response send error: {e}"); - return; - }; - while let Some(frame) = body.frame().await { - match frame { - Ok(frame) => { - if frame.is_data() { - let data = frame.into_data().unwrap(); - if let Err(e) = send.send_data(data).await { - debug!("h3 body send error: {e}"); - return; - } - } else if frame.is_trailers() { - let trailers = frame.into_trailers().unwrap(); - if let Err(e) = send.send_trailers(trailers).await { - debug!("h3 trailers send error: {e}"); - return; - } - } - } - Err(_) => todo!(), - } - } - if let Err(e) = send.finish().await { - debug!("h3 response finish error: {e}"); - return; - } - } - Ok(None) => break, - Err(e) => match e.get_error_level() { - ErrorLevel::ConnectionError => break, - ErrorLevel::StreamError => continue, - }, - } + if let Err(e) = send.send_response(resp).await { + debug!("h3 response send error: {e}"); + return; + }; + while let Some(frame) = body.frame().await { + match frame { + Ok(frame) => { + if frame.is_data() { + let data = frame.into_data().unwrap(); + if let Err(e) = send.send_data(data).await { + debug!("h3 body send error: {e}"); + return; } - drop(_sem); - }); + } else if frame.is_trailers() { + let trailers = frame.into_trailers().unwrap(); + if let Err(e) = send.send_trailers(trailers).await { + debug!("h3 trailers send error: {e}"); + return; + } + } } - Ok::<_, anyhow::Error>(()) + Err(_) => todo!(), } - })) - .await?; - Ok(()) + } + if let Err(e) = send.finish().await { + debug!("h3 response finish error: {e}"); + return; + } } fn error_response(addr: SocketAddr, error: ServiceError) -> Response<BoxBody<Bytes, ServiceError>> { |