diff options
author | metamuffin <metamuffin@disroot.org> | 2023-04-06 17:51:42 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-04-06 17:51:42 +0200 |
commit | c3c3a07cae6a938534824c32573927dd7a5ece4b (patch) | |
tree | 3238fae7f8334a18b9fc6bb3342416fe184c4492 | |
parent | 6d06d9a0eaf8d7c3f03df1501b9acd0a71cb53ae (diff) | |
download | gnix-c3c3a07cae6a938534824c32573927dd7a5ece4b.tar gnix-c3c3a07cae6a938534824c32573927dd7a5ece4b.tar.bz2 gnix-c3c3a07cae6a938534824c32573927dd7a5ece4b.tar.zst |
configure max_incoming_connections
-rw-r--r-- | src/config.rs | 20 | ||||
-rw-r--r-- | src/main.rs | 88 |
2 files changed, 79 insertions, 29 deletions
diff --git a/src/config.rs b/src/config.rs index 58b885a..68479e4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,10 +6,21 @@ use std::{collections::HashMap, fs::read_to_string, net::SocketAddr, path::PathB pub struct Config { pub http: Option<HttpConfig>, pub https: Option<HttpsConfig>, + #[serde(default)] + pub limits: Limits, + #[serde(default)] pub hosts: HashMap<String, HostConfig>, } #[derive(Debug, Serialize, Deserialize)] +pub struct Limits { + #[serde(default)] + pub max_incoming_connections: usize, + #[serde(default)] + pub max_outgoing_connections: usize, +} + +#[derive(Debug, Serialize, Deserialize)] pub struct HttpConfig { pub bind: SocketAddr, } @@ -42,3 +53,12 @@ impl Config { Ok(config) } } + +impl Default for Limits { + fn default() -> Self { + Self { + max_incoming_connections: 1024, + max_outgoing_connections: usize::MAX, + } + } +} diff --git a/src/main.rs b/src/main.rs index ef62af7..f55771c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,7 +23,17 @@ use hyper::{ Request, Response, StatusCode, }; use log::{debug, error, info, warn}; -use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, sync::Arc}; +use std::{ + fs::File, + io::BufReader, + net::SocketAddr, + path::Path, + process::exit, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, +}; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpListener, @@ -33,6 +43,7 @@ use tokio_rustls::TlsAcceptor; pub struct State { pub config: Config, + pub total_connection: AtomicUsize, } #[tokio::main] @@ -42,15 +53,25 @@ async fn main() -> anyhow::Result<()> { let config_path = std::env::args().skip(1).next().ok_or(anyhow!( "first argument is expected to be the configuration file" ))?; - - let config = Config::load(&config_path)?; - let state = Arc::new(State { config }); + + let config = match Config::load(&config_path) { + Ok(c) => c, + Err(e) => { + eprintln!("error {e:?}"); + exit(1); + } + }; + let state = Arc::new(State { + config, + total_connection: AtomicUsize::new(0), + }); { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_http(state).await { - error!("{e}") + error!("{e}"); + exit(1) } }); } @@ -58,7 +79,8 @@ async fn main() -> anyhow::Result<()> { let state = state.clone(); tokio::spawn(async move { if let Err(e) = serve_https(state).await { - error!("{e}") + error!("{e}"); + exit(1) } }); } @@ -123,32 +145,40 @@ pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>( stream: T, addr: SocketAddr, ) { - let conn = http1::Builder::new() - .serve_connection( - stream, - service_fn(move |req| { - let state = state.clone(); - async move { - match service(state, req, addr).await { - Ok(r) => Ok(r), - Err(ServiceError::Hyper(e)) => Err(e), - Err(error) => Ok({ - let mut resp = - Response::new(format!("gnix encountered an issue: {error}")); - *resp.status_mut() = StatusCode::BAD_REQUEST; - resp.headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); - resp + let conns = state.total_connection.fetch_add(1, Relaxed); + + if conns >= state.config.limits.max_incoming_connections { + let conn = http1::Builder::new() + .serve_connection( + stream, + service_fn(|req| { + let state = state.clone(); + async move { + match service(state, req, addr).await { + Ok(r) => Ok(r), + Err(ServiceError::Hyper(e)) => Err(e), + Err(error) => Ok({ + let mut resp = + Response::new(format!("gnix encountered an issue: {error}")); + *resp.status_mut() = StatusCode::BAD_REQUEST; + resp.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + resp + } + .map(|b| b.map_err(|e| match e {}).boxed())), } - .map(|b| b.map_err(|e| match e {}).boxed())), } - } - }), - ) - .with_upgrades(); - if let Err(err) = conn.await { - warn!("error: {:?}", err); + }), + ) + .with_upgrades(); + if let Err(err) = conn.await { + warn!("error: {:?}", err); + } + } else { + warn!("connection dropped: too many incoming"); } + + state.total_connection.fetch_sub(1, Relaxed); } fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::Certificate>> { |