diff options
author | metamuffin <metamuffin@disroot.org> | 2023-08-28 15:02:14 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-08-28 15:02:14 +0200 |
commit | 186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5 (patch) | |
tree | 384ed6e8faaacd77b1a5f4f11a251ee228f1e927 | |
parent | 2bc557bbddb01b535dd8512fe3aadb0d4091a42d (diff) | |
download | gnix-186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5.tar gnix-186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5.tar.bz2 gnix-186bf476aeab0ff0838d1ae26a9dbcb2e40a8eb5.tar.zst |
what i invented here already existed: semaphore
-rw-r--r-- | src/config.rs | 20 | ||||
-rw-r--r-- | src/error.rs | 4 | ||||
-rw-r--r-- | src/limiter.rs | 37 | ||||
-rw-r--r-- | src/main.rs | 14 | ||||
-rw-r--r-- | src/proxy.rs | 91 |
5 files changed, 67 insertions, 99 deletions
diff --git a/src/config.rs b/src/config.rs index 89a2b54..c250b80 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,8 @@ use anyhow::Context; -use serde::{Deserialize, Serialize, Deserializer, de::{Visitor, Error, SeqAccess, value}}; +use serde::{ + de::{value, Error, SeqAccess, Visitor}, + Deserialize, Deserializer, Serialize, +}; use std::{collections::HashMap, fmt, fs::read_to_string, net::SocketAddr, path::PathBuf}; #[derive(Debug, Serialize, Deserialize)] @@ -27,7 +30,6 @@ pub struct HttpConfig { #[derive(Debug, Serialize, Deserialize)] pub struct HttpsConfig { - #[serde(deserialize_with = "string_or_seq")] pub bind: Vec<SocketAddr>, pub tls_cert: PathBuf, pub tls_key: PathBuf, @@ -49,7 +51,9 @@ pub struct FileserverConfig { // fall back to expecting a single string and putting that in a 1-length vector fn string_or_seq<'de, D>(des: D) -> Result<Vec<SocketAddr>, D::Error> -where D: Deserializer<'de> { +where + D: Deserializer<'de>, +{ struct StringOrList; impl<'de> Visitor<'de> for StringOrList { type Value = Vec<SocketAddr>; @@ -59,13 +63,17 @@ where D: Deserializer<'de> { } fn visit_str<E>(self, val: &str) -> Result<Vec<SocketAddr>, E> - where E: Error { + where + E: Error, + { let addr = SocketAddr::deserialize(value::StrDeserializer::new(val))?; Ok(vec![addr]) } fn visit_seq<A>(self, val: A) -> Result<Vec<SocketAddr>, A::Error> - where A: SeqAccess<'de> { + where + A: SeqAccess<'de>, + { Vec::<SocketAddr>::deserialize(value::SeqAccessDeserializer::new(val)) } } @@ -85,7 +93,7 @@ impl Default for Limits { fn default() -> Self { Self { max_incoming_connections: 1024, - max_outgoing_connections: usize::MAX, + max_outgoing_connections: 2048, } } } diff --git a/src/error.rs b/src/error.rs index bf775d1..5aae6d8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,9 @@ +use tokio::sync::TryAcquireError; + #[derive(Debug, thiserror::Error)] pub enum ServiceError { #[error("limit reached. try again")] - Limit, + Limit(#[from] TryAcquireError), #[error("hyper error")] Hyper(hyper::Error), #[error("unknown host")] diff --git a/src/limiter.rs b/src/limiter.rs deleted file mode 100644 index 97bdb5c..0000000 --- a/src/limiter.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::sync::{atomic::AtomicUsize, Arc}; - -pub struct Limiter { - limit: usize, - counter: Arc<AtomicUsize>, -} - -impl Limiter { - pub fn new(limit: usize) -> Self { - Limiter { - counter: Default::default(), - limit, - } - } - pub fn obtain(&self) -> Option<LimitLock> { - let c = self - .counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if c < self.limit { - Some(LimitLock { - counter: self.counter.clone(), - }) - } else { - None - } - } -} - -pub struct LimitLock { - counter: Arc<AtomicUsize>, -} -impl Drop for LimitLock { - fn drop(&mut self) { - self.counter - .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); - } -} diff --git a/src/main.rs b/src/main.rs index 5230434..4d14bec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,6 @@ pub mod config; pub mod error; pub mod files; pub mod helper; -pub mod limiter; pub mod proxy; use crate::{ @@ -26,16 +25,15 @@ use hyper::{ service::service_fn, Request, Response, StatusCode, }; -use limiter::Limiter; use log::{debug, error, info, warn}; use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc}; -use tokio::{net::TcpListener, signal::ctrl_c}; +use tokio::{net::TcpListener, signal::ctrl_c, sync::Semaphore}; use tokio_rustls::TlsAcceptor; pub struct State { pub config: Config, - pub l_incoming: Limiter, - pub l_outgoing: Limiter, + pub l_incoming: Semaphore, + pub l_outgoing: Semaphore, } #[tokio::main] @@ -54,8 +52,8 @@ async fn main() -> anyhow::Result<()> { } }; let state = Arc::new(State { - l_incoming: Limiter::new(config.limits.max_incoming_connections), - l_outgoing: Limiter::new(config.limits.max_outgoing_connections), + l_incoming: Semaphore::new(config.limits.max_incoming_connections), + l_outgoing: Semaphore::new(config.limits.max_outgoing_connections), config, }); @@ -153,7 +151,7 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r stream: T, addr: SocketAddr, ) { - if let Some(_limit_guard) = state.l_incoming.obtain() { + if let Ok(_semaphore) = state.l_incoming.try_acquire() { let conn = http1::Builder::new() .serve_connection( stream, diff --git a/src/proxy.rs b/src/proxy.rs index 10d7e3d..d38de4d 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -48,57 +48,54 @@ pub async fn proxy_request( let do_upgrade = req.headers().contains_key(UPGRADE); let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>(); - if let Some(_limit_guard) = state.l_outgoing.obtain() { - debug!("\tforwarding to {}", backend); - let mut resp = { - let client_stream = TokioIo( - TcpStream::connect(backend) - .await - .map_err(|_| ServiceError::CantConnect)?, - ); - - let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) - .await - .map_err(ServiceError::Hyper)?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - warn!("connection failed: {:?}", err); - } - }); - sender - .send_request(req) + let _limit_guard = state.l_outgoing.try_acquire()?; + debug!("\tforwarding to {}", backend); + let mut resp = { + let client_stream = TokioIo( + TcpStream::connect(backend) .await - .map_err(ServiceError::Hyper)? - }; + .map_err(|_| ServiceError::CantConnect)?, + ); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) + .await + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(req) + .await + .map_err(ServiceError::Hyper)? + }; - if do_upgrade { - let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>(); - tokio::task::spawn(async move { - debug!("about upgrading connection, sending empty response"); - match ( - on_upgrade_upstream.unwrap().await, - on_upgrade_downstream.unwrap().await, - ) { - (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { - debug!("upgrade successful"); - match tokio::io::copy_bidirectional( - &mut TokioIo(upgraded_downstream), - &mut TokioIo(upgraded_upstream), - ) - .await - { - Ok((from_client, from_server)) => { - debug!("proxy socket terminated: {from_server} sent, {from_client} received") - } - Err(e) => warn!("proxy socket error: {e}"), + if do_upgrade { + let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>(); + tokio::task::spawn(async move { + debug!("about upgrading connection, sending empty response"); + match ( + on_upgrade_upstream.unwrap().await, + on_upgrade_downstream.unwrap().await, + ) { + (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { + debug!("upgrade successful"); + match tokio::io::copy_bidirectional( + &mut TokioIo(upgraded_downstream), + &mut TokioIo(upgraded_upstream), + ) + .await + { + Ok((from_client, from_server)) => { + debug!("proxy socket terminated: {from_server} sent, {from_client} received") } + Err(e) => warn!("proxy socket error: {e}"), } - (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"), } - }); - } - Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) - } else { - Err(ServiceError::Limit) + (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"), + } + }); } + Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) } |