diff options
Diffstat (limited to 'src/modules/ratelimit.rs')
-rw-r--r-- | src/modules/ratelimit.rs | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/src/modules/ratelimit.rs b/src/modules/ratelimit.rs new file mode 100644 index 0000000..9e5d583 --- /dev/null +++ b/src/modules/ratelimit.rs @@ -0,0 +1,176 @@ +/* + This file is part of gnix (https://codeberg.org/metamuffin/gnix) + which is licensed under the GNU Affero General Public License (version 3); see /COPYING. + Copyright (C) 2025 metamuffin <metamuffin.org> +*/ +use super::{cgi::set_cgi_variables, Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use anyhow::Result; +use futures::Future; +use http::{header::RETRY_AFTER, HeaderValue, Response, StatusCode}; +use http_body_util::{combinators::BoxBody, BodyExt}; +use log::{error, warn}; +use serde::Deserialize; +use std::{ + collections::HashMap, + hash::{DefaultHasher, Hash, Hasher}, + net::IpAddr, + pin::Pin, + process::Stdio, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::{process::Command, spawn, sync::Mutex, time::sleep_until}; + +pub struct RatelimitKind; + +#[derive(Deserialize)] +pub struct RatelimitConfig { + next: DynNode, + #[serde(default)] + identity: IdentityMode, + #[serde(default = "default_max_identities")] + max_identities: usize, + reference_duration: f32, + thresholds: Vec<(usize, LimitMode)>, +} +fn default_max_identities() -> usize { + 1 << 16 +} + +#[derive(Deserialize, Default)] +#[serde(rename_all = "snake_case")] +enum IdentityMode { + Global, + #[default] + SourceAddress, + SourceAddressTrunc { + v4: u8, + v6: u8, + }, + Path, + PathQuery, +} + +#[derive(Deserialize)] +#[serde(rename_all = "snake_case")] +enum LimitMode { + TooManyRequests, + Exec(String), +} + +pub struct Ratelimit { + state: Arc<Mutex<HashMap<u64, IdentityState>>>, + config: RatelimitConfig, +} + +struct IdentityState { + counter: usize, + frame_end: Instant, +} + +impl NodeKind for RatelimitKind { + fn name(&self) -> &'static str { + "ratelimit" + } + fn instanciate(&self, config: serde_yml::Value) -> Result<Arc<dyn Node>> { + Ok(Arc::new(Ratelimit { + state: Arc::new(Mutex::new(HashMap::new())), + config: serde_yml::from_value::<RatelimitConfig>(config)?, + })) + } +} + +impl Node for Ratelimit { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let identity_hash = match self.config.identity { + IdentityMode::Global => 0, + IdentityMode::SourceAddress => hash(context.addr.ip()), + IdentityMode::SourceAddressTrunc { v4, v6 } => match context.addr.ip() { + IpAddr::V4(a) => hash(a.to_bits() >> v4), + IpAddr::V6(a) => hash(a.to_bits() >> v6), + }, + IdentityMode::Path => hash(request.uri().path()), + IdentityMode::PathQuery => hash(request.uri().path_and_query()), + }; + + let now = Instant::now(); + + let (counter, frame_end) = { + let mut state = self.state.lock().await; + if state.len() > self.config.max_identities { + return Err(ServiceError::TooManyIdentities); + } + let istate = state.entry(identity_hash).or_insert_with(|| { + let frame_end = now + Duration::from_secs_f32(self.config.reference_duration); + let state = self.state.clone(); + //? How efficient is this? Does it scale better to have a central task? + spawn(async move { + sleep_until(frame_end.into()).await; + state.lock().await.remove(&identity_hash) + }); + IdentityState { + counter: 0, + frame_end, + } + }); + istate.counter += 1; + (istate.counter, istate.frame_end) + }; + + for (thres, l) in &self.config.thresholds { + match l { + LimitMode::TooManyRequests => { + if counter > *thres { + let mut r = Response::new(BoxBody::new( + http_body_util::Empty::new().map_err(|x| match x {}), + )); + *r.status_mut() = StatusCode::TOO_MANY_REQUESTS; + r.headers_mut().insert( + RETRY_AFTER, + HeaderValue::from_str(&(frame_end - now).as_secs().to_string()) + .unwrap(), + ); + return Ok(r); + } + } + LimitMode::Exec(path) => { + // Exact comparison so it can only trigger once per frame + if counter == *thres { + let mut command = Command::new(&path); + command.stdin(Stdio::null()); + set_cgi_variables(&mut command, &request, context); + spawn(async move { + let mut child = match command.spawn() { + Ok(c) => c, + Err(e) => return error!("exec limiter spawn failed: {e}"), + }; + match child.wait().await { + Ok(s) if s.success() => (), + Ok(s) => warn!( + "exec limiter failed with code {}", + s.code().unwrap_or_default() + ), + Err(e) => warn!("exec limiter failed: {e}"), + } + }); + } + } + } + } + + self.config.next.handle(context, request).await + }) + } +} + +fn hash(value: impl Hash) -> u64 { + let mut h = DefaultHasher::default(); + value.hash(&mut h); + h.finish() +} |