/* This file is part of gnix (https://codeberg.org/metamuffin/gnix) which is licensed under the GNU Affero General Public License (version 3); see /COPYING. Copyright (C) 2025 metamuffin */ use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; use crate::{config::DynNode, error::ServiceError}; use anyhow::Result; use bytes::Bytes; use futures::{Future, StreamExt}; use http_body_util::{combinators::BoxBody, BodyExt, StreamBody}; use hyper::body::{Body, Frame}; use serde::Deserialize; use std::{ pin::Pin, sync::Arc, time::{Duration, Instant}, }; use tokio::time::sleep_until; pub struct LimitsKind; #[derive(Deserialize)] pub struct Limits { #[serde(default)] request: LimitParam, #[serde(default)] response: LimitParam, next: DynNode, } #[derive(Debug, Clone, Deserialize, Default)] struct LimitParam { size: Option, rate: Option, rate_buffer: Option, } impl NodeKind for LimitsKind { fn name(&self) -> &'static str { "limits" } fn instanciate(&self, config: serde_yml::Value) -> Result> { Ok(Arc::new(serde_yml::from_value::(config)?)) } } impl Node for Limits { fn handle<'a>( &'a self, context: &'a mut NodeContext, request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { let request = request.map(|body| limit_body(body, self.request.clone()).boxed()); let response = self.next.handle(context, request).await?; Ok(response.map(|body| limit_body(body, self.response.clone()).boxed())) }) } } fn limit_body( inner: BoxBody, config: LimitParam, ) -> impl Body { let mut stream = inner.into_data_stream(); let mut state = State::new(config); StreamBody::new(async_stream::stream! { while let Some(k) = stream.next().await { match k { Ok(data) => { if state.update(data.len()).await { yield Ok(Frame::data(data)) } else { break; } }, Err(e) => yield Err(e) } } }) } struct State { config: LimitParam, last_read: Instant, size: u64, } impl State { pub fn new(config: LimitParam) -> Self { Self { config, last_read: Instant::now(), size: 0, } } pub async fn update(&mut self, len: usize) -> bool { if let Some(rate) = self.config.rate { let buffer = Duration::from_millis(self.config.rate_buffer.unwrap_or(1000)); let el = self.last_read.elapsed(); if el > buffer { self.last_read += el - buffer } self.last_read += Duration::from_nanos((len as u64 * 1_000_000_000) / rate); sleep_until(self.last_read.into()).await } if let Some(max_size) = self.config.size { self.size += len as u64; self.size < max_size } else { true } } }