aboutsummaryrefslogtreecommitdiff
path: root/src/modules/ratelimit.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/modules/ratelimit.rs')
-rw-r--r--src/modules/ratelimit.rs176
1 files changed, 176 insertions, 0 deletions
diff --git a/src/modules/ratelimit.rs b/src/modules/ratelimit.rs
new file mode 100644
index 0000000..9e5d583
--- /dev/null
+++ b/src/modules/ratelimit.rs
@@ -0,0 +1,176 @@
+/*
+ This file is part of gnix (https://codeberg.org/metamuffin/gnix)
+ which is licensed under the GNU Affero General Public License (version 3); see /COPYING.
+ Copyright (C) 2025 metamuffin <metamuffin.org>
+*/
+use super::{cgi::set_cgi_variables, Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
+use crate::{config::DynNode, error::ServiceError};
+use anyhow::Result;
+use futures::Future;
+use http::{header::RETRY_AFTER, HeaderValue, Response, StatusCode};
+use http_body_util::{combinators::BoxBody, BodyExt};
+use log::{error, warn};
+use serde::Deserialize;
+use std::{
+ collections::HashMap,
+ hash::{DefaultHasher, Hash, Hasher},
+ net::IpAddr,
+ pin::Pin,
+ process::Stdio,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+use tokio::{process::Command, spawn, sync::Mutex, time::sleep_until};
+
+pub struct RatelimitKind;
+
+#[derive(Deserialize)]
+pub struct RatelimitConfig {
+ next: DynNode,
+ #[serde(default)]
+ identity: IdentityMode,
+ #[serde(default = "default_max_identities")]
+ max_identities: usize,
+ reference_duration: f32,
+ thresholds: Vec<(usize, LimitMode)>,
+}
+fn default_max_identities() -> usize {
+ 1 << 16
+}
+
+#[derive(Deserialize, Default)]
+#[serde(rename_all = "snake_case")]
+enum IdentityMode {
+ Global,
+ #[default]
+ SourceAddress,
+ SourceAddressTrunc {
+ v4: u8,
+ v6: u8,
+ },
+ Path,
+ PathQuery,
+}
+
+#[derive(Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum LimitMode {
+ TooManyRequests,
+ Exec(String),
+}
+
+pub struct Ratelimit {
+ state: Arc<Mutex<HashMap<u64, IdentityState>>>,
+ config: RatelimitConfig,
+}
+
+struct IdentityState {
+ counter: usize,
+ frame_end: Instant,
+}
+
+impl NodeKind for RatelimitKind {
+ fn name(&self) -> &'static str {
+ "ratelimit"
+ }
+ fn instanciate(&self, config: serde_yml::Value) -> Result<Arc<dyn Node>> {
+ Ok(Arc::new(Ratelimit {
+ state: Arc::new(Mutex::new(HashMap::new())),
+ config: serde_yml::from_value::<RatelimitConfig>(config)?,
+ }))
+ }
+}
+
+impl Node for Ratelimit {
+ fn handle<'a>(
+ &'a self,
+ context: &'a mut NodeContext,
+ request: NodeRequest,
+ ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
+ Box::pin(async move {
+ let identity_hash = match self.config.identity {
+ IdentityMode::Global => 0,
+ IdentityMode::SourceAddress => hash(context.addr.ip()),
+ IdentityMode::SourceAddressTrunc { v4, v6 } => match context.addr.ip() {
+ IpAddr::V4(a) => hash(a.to_bits() >> v4),
+ IpAddr::V6(a) => hash(a.to_bits() >> v6),
+ },
+ IdentityMode::Path => hash(request.uri().path()),
+ IdentityMode::PathQuery => hash(request.uri().path_and_query()),
+ };
+
+ let now = Instant::now();
+
+ let (counter, frame_end) = {
+ let mut state = self.state.lock().await;
+ if state.len() > self.config.max_identities {
+ return Err(ServiceError::TooManyIdentities);
+ }
+ let istate = state.entry(identity_hash).or_insert_with(|| {
+ let frame_end = now + Duration::from_secs_f32(self.config.reference_duration);
+ let state = self.state.clone();
+ //? How efficient is this? Does it scale better to have a central task?
+ spawn(async move {
+ sleep_until(frame_end.into()).await;
+ state.lock().await.remove(&identity_hash)
+ });
+ IdentityState {
+ counter: 0,
+ frame_end,
+ }
+ });
+ istate.counter += 1;
+ (istate.counter, istate.frame_end)
+ };
+
+ for (thres, l) in &self.config.thresholds {
+ match l {
+ LimitMode::TooManyRequests => {
+ if counter > *thres {
+ let mut r = Response::new(BoxBody::new(
+ http_body_util::Empty::new().map_err(|x| match x {}),
+ ));
+ *r.status_mut() = StatusCode::TOO_MANY_REQUESTS;
+ r.headers_mut().insert(
+ RETRY_AFTER,
+ HeaderValue::from_str(&(frame_end - now).as_secs().to_string())
+ .unwrap(),
+ );
+ return Ok(r);
+ }
+ }
+ LimitMode::Exec(path) => {
+ // Exact comparison so it can only trigger once per frame
+ if counter == *thres {
+ let mut command = Command::new(&path);
+ command.stdin(Stdio::null());
+ set_cgi_variables(&mut command, &request, context);
+ spawn(async move {
+ let mut child = match command.spawn() {
+ Ok(c) => c,
+ Err(e) => return error!("exec limiter spawn failed: {e}"),
+ };
+ match child.wait().await {
+ Ok(s) if s.success() => (),
+ Ok(s) => warn!(
+ "exec limiter failed with code {}",
+ s.code().unwrap_or_default()
+ ),
+ Err(e) => warn!("exec limiter failed: {e}"),
+ }
+ });
+ }
+ }
+ }
+ }
+
+ self.config.next.handle(context, request).await
+ })
+ }
+}
+
+fn hash(value: impl Hash) -> u64 {
+ let mut h = DefaultHasher::default();
+ value.hash(&mut h);
+ h.finish()
+}