summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-08-28 15:02:14 +0200
committermetamuffin <metamuffin@disroot.org>2023-08-28 15:02:14 +0200
commit186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5 (patch)
tree384ed6e8faaacd77b1a5f4f11a251ee228f1e927
parent2bc557bbddb01b535dd8512fe3aadb0d4091a42d (diff)
downloadgnix-186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5.tar
gnix-186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5.tar.bz2
gnix-186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5.tar.zst
what i invented here already existed: semaphore
-rw-r--r--src/config.rs20
-rw-r--r--src/error.rs4
-rw-r--r--src/limiter.rs37
-rw-r--r--src/main.rs14
-rw-r--r--src/proxy.rs91
5 files changed, 67 insertions, 99 deletions
diff --git a/src/config.rs b/src/config.rs
index 89a2b54..c250b80 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,5 +1,8 @@
use anyhow::Context;
-use serde::{Deserialize, Serialize, Deserializer, de::{Visitor, Error, SeqAccess, value}};
+use serde::{
+ de::{value, Error, SeqAccess, Visitor},
+ Deserialize, Deserializer, Serialize,
+};
use std::{collections::HashMap, fmt, fs::read_to_string, net::SocketAddr, path::PathBuf};
#[derive(Debug, Serialize, Deserialize)]
@@ -27,7 +30,6 @@ pub struct HttpConfig {
#[derive(Debug, Serialize, Deserialize)]
pub struct HttpsConfig {
- #[serde(deserialize_with = "string_or_seq")]
pub bind: Vec<SocketAddr>,
pub tls_cert: PathBuf,
pub tls_key: PathBuf,
@@ -49,7 +51,9 @@ pub struct FileserverConfig {
// fall back to expecting a single string and putting that in a 1-length vector
fn string_or_seq<'de, D>(des: D) -> Result<Vec<SocketAddr>, D::Error>
-where D: Deserializer<'de> {
+where
+ D: Deserializer<'de>,
+{
struct StringOrList;
impl<'de> Visitor<'de> for StringOrList {
type Value = Vec<SocketAddr>;
@@ -59,13 +63,17 @@ where D: Deserializer<'de> {
}
fn visit_str<E>(self, val: &str) -> Result<Vec<SocketAddr>, E>
- where E: Error {
+ where
+ E: Error,
+ {
let addr = SocketAddr::deserialize(value::StrDeserializer::new(val))?;
Ok(vec![addr])
}
fn visit_seq<A>(self, val: A) -> Result<Vec<SocketAddr>, A::Error>
- where A: SeqAccess<'de> {
+ where
+ A: SeqAccess<'de>,
+ {
Vec::<SocketAddr>::deserialize(value::SeqAccessDeserializer::new(val))
}
}
@@ -85,7 +93,7 @@ impl Default for Limits {
fn default() -> Self {
Self {
max_incoming_connections: 1024,
- max_outgoing_connections: usize::MAX,
+ max_outgoing_connections: 2048,
}
}
}
diff --git a/src/error.rs b/src/error.rs
index bf775d1..5aae6d8 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,7 +1,9 @@
+use tokio::sync::TryAcquireError;
+
#[derive(Debug, thiserror::Error)]
pub enum ServiceError {
#[error("limit reached. try again")]
- Limit,
+ Limit(#[from] TryAcquireError),
#[error("hyper error")]
Hyper(hyper::Error),
#[error("unknown host")]
diff --git a/src/limiter.rs b/src/limiter.rs
deleted file mode 100644
index 97bdb5c..0000000
--- a/src/limiter.rs
+++ /dev/null
@@ -1,37 +0,0 @@
-use std::sync::{atomic::AtomicUsize, Arc};
-
-pub struct Limiter {
- limit: usize,
- counter: Arc<AtomicUsize>,
-}
-
-impl Limiter {
- pub fn new(limit: usize) -> Self {
- Limiter {
- counter: Default::default(),
- limit,
- }
- }
- pub fn obtain(&self) -> Option<LimitLock> {
- let c = self
- .counter
- .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
- if c < self.limit {
- Some(LimitLock {
- counter: self.counter.clone(),
- })
- } else {
- None
- }
- }
-}
-
-pub struct LimitLock {
- counter: Arc<AtomicUsize>,
-}
-impl Drop for LimitLock {
- fn drop(&mut self) {
- self.counter
- .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
- }
-}
diff --git a/src/main.rs b/src/main.rs
index 5230434..4d14bec 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,7 +5,6 @@ pub mod config;
pub mod error;
pub mod files;
pub mod helper;
-pub mod limiter;
pub mod proxy;
use crate::{
@@ -26,16 +25,15 @@ use hyper::{
service::service_fn,
Request, Response, StatusCode,
};
-use limiter::Limiter;
use log::{debug, error, info, warn};
use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc};
-use tokio::{net::TcpListener, signal::ctrl_c};
+use tokio::{net::TcpListener, signal::ctrl_c, sync::Semaphore};
use tokio_rustls::TlsAcceptor;
pub struct State {
pub config: Config,
- pub l_incoming: Limiter,
- pub l_outgoing: Limiter,
+ pub l_incoming: Semaphore,
+ pub l_outgoing: Semaphore,
}
#[tokio::main]
@@ -54,8 +52,8 @@ async fn main() -> anyhow::Result<()> {
}
};
let state = Arc::new(State {
- l_incoming: Limiter::new(config.limits.max_incoming_connections),
- l_outgoing: Limiter::new(config.limits.max_outgoing_connections),
+ l_incoming: Semaphore::new(config.limits.max_incoming_connections),
+ l_outgoing: Semaphore::new(config.limits.max_outgoing_connections),
config,
});
@@ -153,7 +151,7 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r
stream: T,
addr: SocketAddr,
) {
- if let Some(_limit_guard) = state.l_incoming.obtain() {
+ if let Ok(_semaphore) = state.l_incoming.try_acquire() {
let conn = http1::Builder::new()
.serve_connection(
stream,
diff --git a/src/proxy.rs b/src/proxy.rs
index 10d7e3d..d38de4d 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -48,57 +48,54 @@ pub async fn proxy_request(
let do_upgrade = req.headers().contains_key(UPGRADE);
let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>();
- if let Some(_limit_guard) = state.l_outgoing.obtain() {
- debug!("\tforwarding to {}", backend);
- let mut resp = {
- let client_stream = TokioIo(
- TcpStream::connect(backend)
- .await
- .map_err(|_| ServiceError::CantConnect)?,
- );
-
- let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
- .await
- .map_err(ServiceError::Hyper)?;
- tokio::task::spawn(async move {
- if let Err(err) = conn.await {
- warn!("connection failed: {:?}", err);
- }
- });
- sender
- .send_request(req)
+ let _limit_guard = state.l_outgoing.try_acquire()?;
+ debug!("\tforwarding to {}", backend);
+ let mut resp = {
+ let client_stream = TokioIo(
+ TcpStream::connect(backend)
.await
- .map_err(ServiceError::Hyper)?
- };
+ .map_err(|_| ServiceError::CantConnect)?,
+ );
+
+ let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
+ .await
+ .map_err(ServiceError::Hyper)?;
+ tokio::task::spawn(async move {
+ if let Err(err) = conn.await {
+ warn!("connection failed: {:?}", err);
+ }
+ });
+ sender
+ .send_request(req)
+ .await
+ .map_err(ServiceError::Hyper)?
+ };
- if do_upgrade {
- let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
- tokio::task::spawn(async move {
- debug!("about upgrading connection, sending empty response");
- match (
- on_upgrade_upstream.unwrap().await,
- on_upgrade_downstream.unwrap().await,
- ) {
- (Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
- debug!("upgrade successful");
- match tokio::io::copy_bidirectional(
- &mut TokioIo(upgraded_downstream),
- &mut TokioIo(upgraded_upstream),
- )
- .await
- {
- Ok((from_client, from_server)) => {
- debug!("proxy socket terminated: {from_server} sent, {from_client} received")
- }
- Err(e) => warn!("proxy socket error: {e}"),
+ if do_upgrade {
+ let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
+ tokio::task::spawn(async move {
+ debug!("about upgrading connection, sending empty response");
+ match (
+ on_upgrade_upstream.unwrap().await,
+ on_upgrade_downstream.unwrap().await,
+ ) {
+ (Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
+ debug!("upgrade successful");
+ match tokio::io::copy_bidirectional(
+ &mut TokioIo(upgraded_downstream),
+ &mut TokioIo(upgraded_upstream),
+ )
+ .await
+ {
+ Ok((from_client, from_server)) => {
+ debug!("proxy socket terminated: {from_server} sent, {from_client} received")
}
+ Err(e) => warn!("proxy socket error: {e}"),
}
- (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"),
}
- });
- }
- Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed()))
- } else {
- Err(ServiceError::Limit)
+ (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"),
+ }
+ });
}
+ Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed()))
}