aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-04-06 20:23:00 +0200
committermetamuffin <metamuffin@disroot.org>2023-04-06 20:23:00 +0200
commit56fb681279b2f2221eef933617d521469c6e6d83 (patch)
tree36e0fc43872e2af4e1c51b72e989d01698df1fde
parentc3c3a07cae6a938534824c32573927dd7a5ece4b (diff)
downloadgnix-56fb681279b2f2221eef933617d521469c6e6d83.tar
gnix-56fb681279b2f2221eef933617d521469c6e6d83.tar.bz2
gnix-56fb681279b2f2221eef933617d521469c6e6d83.tar.zst
apply limits
-rw-r--r--src/config.rs3
-rw-r--r--src/error.rs2
-rw-r--r--src/limiter.rs37
-rw-r--r--src/main.rs28
-rw-r--r--src/proxy.rs91
5 files changed, 97 insertions, 64 deletions
diff --git a/src/config.rs b/src/config.rs
index 68479e4..9e542d7 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -13,10 +13,9 @@ pub struct Config {
}
#[derive(Debug, Serialize, Deserialize)]
+#[serde(default)]
pub struct Limits {
- #[serde(default)]
pub max_incoming_connections: usize,
- #[serde(default)]
pub max_outgoing_connections: usize,
}
diff --git a/src/error.rs b/src/error.rs
index 9a83aff..bf775d1 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,5 +1,7 @@
#[derive(Debug, thiserror::Error)]
pub enum ServiceError {
+ #[error("limit reached. try again")]
+ Limit,
#[error("hyper error")]
Hyper(hyper::Error),
#[error("unknown host")]
diff --git a/src/limiter.rs b/src/limiter.rs
new file mode 100644
index 0000000..97bdb5c
--- /dev/null
+++ b/src/limiter.rs
@@ -0,0 +1,37 @@
+use std::sync::{atomic::AtomicUsize, Arc};
+
+pub struct Limiter {
+ limit: usize,
+ counter: Arc<AtomicUsize>,
+}
+
+impl Limiter {
+ pub fn new(limit: usize) -> Self {
+ Limiter {
+ counter: Default::default(),
+ limit,
+ }
+ }
+ pub fn obtain(&self) -> Option<LimitLock> {
+ let c = self
+ .counter
+ .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+ if c < self.limit {
+ Some(LimitLock {
+ counter: self.counter.clone(),
+ })
+ } else {
+ None
+ }
+ }
+}
+
+pub struct LimitLock {
+ counter: Arc<AtomicUsize>,
+}
+impl Drop for LimitLock {
+ fn drop(&mut self) {
+ self.counter
+ .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index f55771c..f51cd67 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,6 +4,7 @@
pub mod config;
pub mod error;
pub mod files;
+pub mod limiter;
pub mod proxy;
use crate::{
@@ -22,18 +23,9 @@ use hyper::{
service::service_fn,
Request, Response, StatusCode,
};
+use limiter::Limiter;
use log::{debug, error, info, warn};
-use std::{
- fs::File,
- io::BufReader,
- net::SocketAddr,
- path::Path,
- process::exit,
- sync::{
- atomic::{AtomicUsize, Ordering::Relaxed},
- Arc,
- },
-};
+use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpListener,
@@ -43,7 +35,8 @@ use tokio_rustls::TlsAcceptor;
pub struct State {
pub config: Config,
- pub total_connection: AtomicUsize,
+ pub l_incoming: Limiter,
+ pub l_outgoing: Limiter,
}
#[tokio::main]
@@ -62,8 +55,9 @@ async fn main() -> anyhow::Result<()> {
}
};
let state = Arc::new(State {
+ l_incoming: Limiter::new(config.limits.max_incoming_connections),
+ l_outgoing: Limiter::new(config.limits.max_outgoing_connections),
config,
- total_connection: AtomicUsize::new(0),
});
{
@@ -145,9 +139,7 @@ pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
stream: T,
addr: SocketAddr,
) {
- let conns = state.total_connection.fetch_add(1, Relaxed);
-
- if conns >= state.config.limits.max_incoming_connections {
+ if let Some(_limit_guard) = state.l_incoming.obtain() {
let conn = http1::Builder::new()
.serve_connection(
stream,
@@ -177,8 +169,6 @@ pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
} else {
warn!("connection dropped: too many incoming");
}
-
- state.total_connection.fetch_sub(1, Relaxed);
}
fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::Certificate>> {
@@ -216,7 +206,7 @@ async fn service(
.ok_or(ServiceError::NoHost)?;
let mut resp = match route {
- HostConfig::Backend { backend } => proxy_request(req, addr, backend).await,
+ HostConfig::Backend { backend } => proxy_request(&state, req, addr, backend).await,
HostConfig::Files { files } => serve_files(req, files).await,
}?;
diff --git a/src/proxy.rs b/src/proxy.rs
index 036bbc1..7070500 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -1,4 +1,4 @@
-use crate::ServiceError;
+use crate::{ServiceError, State};
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{
body::Incoming,
@@ -11,10 +11,11 @@ use hyper::{
Request, Uri,
};
use log::{debug, error, warn};
-use std::net::SocketAddr;
+use std::{net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;
pub async fn proxy_request(
+ state: &Arc<State>,
mut req: Request<Incoming>,
addr: SocketAddr,
backend: &SocketAddr,
@@ -47,51 +48,55 @@ pub async fn proxy_request(
let do_upgrade = req.headers().contains_key(UPGRADE);
let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>();
- debug!("\tforwarding to {}", backend);
- let mut resp = {
- let client_stream = TcpStream::connect(backend)
- .await
- .map_err(|_| ServiceError::CantConnect)?;
+ if let Some(_limit_guard) = state.l_outgoing.obtain() {
+ debug!("\tforwarding to {}", backend);
+ let mut resp = {
+ let client_stream = TcpStream::connect(backend)
+ .await
+ .map_err(|_| ServiceError::CantConnect)?;
- let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
- .await
- .map_err(ServiceError::Hyper)?;
- tokio::task::spawn(async move {
- if let Err(err) = conn.await {
- warn!("connection failed: {:?}", err);
- }
- });
- sender
- .send_request(req)
- .await
- .map_err(ServiceError::Hyper)?
- };
+ let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
+ .await
+ .map_err(ServiceError::Hyper)?;
+ tokio::task::spawn(async move {
+ if let Err(err) = conn.await {
+ warn!("connection failed: {:?}", err);
+ }
+ });
+ sender
+ .send_request(req)
+ .await
+ .map_err(ServiceError::Hyper)?
+ };
- if do_upgrade {
- let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
- tokio::task::spawn(async move {
- debug!("about upgrading connection, sending empty response");
- match (
- on_upgrade_upstream.unwrap().await,
- on_upgrade_downstream.unwrap().await,
- ) {
- (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => {
- debug!("upgrade successful");
- match tokio::io::copy_bidirectional(
- &mut upgraded_downstream,
- &mut upgraded_upstream,
- )
- .await
- {
- Ok((from_client, from_server)) => {
- debug!("proxy socket terminated: {from_server} sent, {from_client} received")
+ if do_upgrade {
+ let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
+ tokio::task::spawn(async move {
+ debug!("about upgrading connection, sending empty response");
+ match (
+ on_upgrade_upstream.unwrap().await,
+ on_upgrade_downstream.unwrap().await,
+ ) {
+ (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => {
+ debug!("upgrade successful");
+ match tokio::io::copy_bidirectional(
+ &mut upgraded_downstream,
+ &mut upgraded_upstream,
+ )
+ .await
+ {
+ Ok((from_client, from_server)) => {
+ debug!("proxy socket terminated: {from_server} sent, {from_client} received")
+ }
+ Err(e) => warn!("proxy socket error: {e}"),
}
- Err(e) => warn!("proxy socket error: {e}"),
}
+ (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"),
}
- (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"),
- }
- });
+ });
+ }
+ Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed()))
+ } else {
+ Err(ServiceError::Limit)
}
- Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed()))
}