diff options
author | metamuffin <metamuffin@disroot.org> | 2023-04-06 20:23:00 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-04-06 20:23:00 +0200 |
commit | 56fb681279b2f2221eef933617d521469c6e6d83 (patch) | |
tree | 36e0fc43872e2af4e1c51b72e989d01698df1fde /src | |
parent | c3c3a07cae6a938534824c32573927dd7a5ece4b (diff) | |
download | gnix-56fb681279b2f2221eef933617d521469c6e6d83.tar gnix-56fb681279b2f2221eef933617d521469c6e6d83.tar.bz2 gnix-56fb681279b2f2221eef933617d521469c6e6d83.tar.zst |
apply limits
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 3 | ||||
-rw-r--r-- | src/error.rs | 2 | ||||
-rw-r--r-- | src/limiter.rs | 37 | ||||
-rw-r--r-- | src/main.rs | 28 | ||||
-rw-r--r-- | src/proxy.rs | 91 |
5 files changed, 97 insertions, 64 deletions
diff --git a/src/config.rs b/src/config.rs index 68479e4..9e542d7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,10 +13,9 @@ pub struct Config { } #[derive(Debug, Serialize, Deserialize)] +#[serde(default)] pub struct Limits { - #[serde(default)] pub max_incoming_connections: usize, - #[serde(default)] pub max_outgoing_connections: usize, } diff --git a/src/error.rs b/src/error.rs index 9a83aff..bf775d1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,7 @@ #[derive(Debug, thiserror::Error)] pub enum ServiceError { + #[error("limit reached. try again")] + Limit, #[error("hyper error")] Hyper(hyper::Error), #[error("unknown host")] diff --git a/src/limiter.rs b/src/limiter.rs new file mode 100644 index 0000000..97bdb5c --- /dev/null +++ b/src/limiter.rs @@ -0,0 +1,37 @@ +use std::sync::{atomic::AtomicUsize, Arc}; + +pub struct Limiter { + limit: usize, + counter: Arc<AtomicUsize>, +} + +impl Limiter { + pub fn new(limit: usize) -> Self { + Limiter { + counter: Default::default(), + limit, + } + } + pub fn obtain(&self) -> Option<LimitLock> { + let c = self + .counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if c < self.limit { + Some(LimitLock { + counter: self.counter.clone(), + }) + } else { + None + } + } +} + +pub struct LimitLock { + counter: Arc<AtomicUsize>, +} +impl Drop for LimitLock { + fn drop(&mut self) { + self.counter + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + } +} diff --git a/src/main.rs b/src/main.rs index f55771c..f51cd67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ pub mod config; pub mod error; pub mod files; +pub mod limiter; pub mod proxy; use crate::{ @@ -22,18 +23,9 @@ use hyper::{ service::service_fn, Request, Response, StatusCode, }; +use limiter::Limiter; use log::{debug, error, info, warn}; -use std::{ - fs::File, - io::BufReader, - net::SocketAddr, - path::Path, - process::exit, - sync::{ - atomic::{AtomicUsize, Ordering::Relaxed}, - Arc, - }, -}; +use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpListener, @@ -43,7 +35,8 @@ use tokio_rustls::TlsAcceptor; pub struct State { pub config: Config, - pub total_connection: AtomicUsize, + pub l_incoming: Limiter, + pub l_outgoing: Limiter, } #[tokio::main] @@ -62,8 +55,9 @@ async fn main() -> anyhow::Result<()> { } }; let state = Arc::new(State { + l_incoming: Limiter::new(config.limits.max_incoming_connections), + l_outgoing: Limiter::new(config.limits.max_outgoing_connections), config, - total_connection: AtomicUsize::new(0), }); { @@ -145,9 +139,7 @@ pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>( stream: T, addr: SocketAddr, ) { - let conns = state.total_connection.fetch_add(1, Relaxed); - - if conns >= state.config.limits.max_incoming_connections { + if let Some(_limit_guard) = state.l_incoming.obtain() { let conn = http1::Builder::new() .serve_connection( stream, @@ -177,8 +169,6 @@ pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>( } else { warn!("connection dropped: too many incoming"); } - - state.total_connection.fetch_sub(1, Relaxed); } fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::Certificate>> { @@ -216,7 +206,7 @@ async fn service( .ok_or(ServiceError::NoHost)?; let mut resp = match route { - HostConfig::Backend { backend } => proxy_request(req, addr, backend).await, + HostConfig::Backend { backend } => proxy_request(&state, req, addr, backend).await, HostConfig::Files { files } => serve_files(req, files).await, }?; diff --git a/src/proxy.rs b/src/proxy.rs index 036bbc1..7070500 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,4 +1,4 @@ -use crate::ServiceError; +use crate::{ServiceError, State}; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, @@ -11,10 +11,11 @@ use hyper::{ Request, Uri, }; use log::{debug, error, warn}; -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpStream; pub async fn proxy_request( + state: &Arc<State>, mut req: Request<Incoming>, addr: SocketAddr, backend: &SocketAddr, @@ -47,51 +48,55 @@ pub async fn proxy_request( let do_upgrade = req.headers().contains_key(UPGRADE); let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>(); - debug!("\tforwarding to {}", backend); - let mut resp = { - let client_stream = TcpStream::connect(backend) - .await - .map_err(|_| ServiceError::CantConnect)?; + if let Some(_limit_guard) = state.l_outgoing.obtain() { + debug!("\tforwarding to {}", backend); + let mut resp = { + let client_stream = TcpStream::connect(backend) + .await + .map_err(|_| ServiceError::CantConnect)?; - let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) - .await - .map_err(ServiceError::Hyper)?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - warn!("connection failed: {:?}", err); - } - }); - sender - .send_request(req) - .await - .map_err(ServiceError::Hyper)? - }; + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) + .await + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(req) + .await + .map_err(ServiceError::Hyper)? + }; - if do_upgrade { - let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>(); - tokio::task::spawn(async move { - debug!("about upgrading connection, sending empty response"); - match ( - on_upgrade_upstream.unwrap().await, - on_upgrade_downstream.unwrap().await, - ) { - (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => { - debug!("upgrade successful"); - match tokio::io::copy_bidirectional( - &mut upgraded_downstream, - &mut upgraded_upstream, - ) - .await - { - Ok((from_client, from_server)) => { - debug!("proxy socket terminated: {from_server} sent, {from_client} received") + if do_upgrade { + let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>(); + tokio::task::spawn(async move { + debug!("about upgrading connection, sending empty response"); + match ( + on_upgrade_upstream.unwrap().await, + on_upgrade_downstream.unwrap().await, + ) { + (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => { + debug!("upgrade successful"); + match tokio::io::copy_bidirectional( + &mut upgraded_downstream, + &mut upgraded_upstream, + ) + .await + { + Ok((from_client, from_server)) => { + debug!("proxy socket terminated: {from_server} sent, {from_client} received") + } + Err(e) => warn!("proxy socket error: {e}"), } - Err(e) => warn!("proxy socket error: {e}"), } + (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"), } - (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"), - } - }); + }); + } + Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) + } else { + Err(ServiceError::Limit) } - Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) } |