aboutsummaryrefslogtreecommitdiff
path: root/src/modules/limits.rs
blob: dcfc5085173038d6a0df8751004fef87a82be627 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*
    This file is part of gnix (https://codeberg.org/metamuffin/gnix)
    which is licensed under the GNU Affero General Public License (version 3); see /COPYING.
    Copyright (C) 2025 metamuffin <metamuffin.org>
*/
use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
use crate::{config::DynNode, error::ServiceError};
use anyhow::Result;
use bytes::Bytes;
use futures::{Future, StreamExt};
use http_body_util::{combinators::BoxBody, BodyExt, BodyStream, StreamBody};
use hyper::body::{Body, Frame};
use serde::Deserialize;
use std::{
    pin::Pin,
    sync::Arc,
    time::{Duration, Instant},
};
use tokio::time::sleep_until;

/// Node kind that constructs [`Limits`] nodes; its configuration name is `"limits"`.
pub struct LimitsKind;

/// Node that applies size/rate limits to the request body on the way in and to the
/// response body on the way out, delegating the actual handling to `next`.
#[derive(Deserialize)]
pub struct Limits {
    /// Limits applied to the request body before it is passed to `next`.
    /// When omitted from the config, `LimitParam::default()` is used (no limits set).
    #[serde(default)]
    request: LimitParam,
    /// Limits applied to the response body returned by `next`.
    /// When omitted from the config, `LimitParam::default()` is used (no limits set).
    #[serde(default)]
    response: LimitParam,
    /// Node that receives the (limited) request and produces the response.
    next: DynNode,
}

/// Limits for a single body. An unset `size` or `rate` means that limit is not applied.
#[derive(Clone, Debug, Default, Deserialize)]
struct LimitParam {
    /// Limit on the total number of body data bytes forwarded; see `State::update` for the
    /// exact point at which the body is cut off.
    size: Option<u64>,
    /// Throughput limit in bytes per second.
    rate: Option<u64>,
    /// Burst allowance for the rate limit in milliseconds (1000 when unset): after an idle
    /// period, at most this much accumulated time credit can be spent at once.
    rate_buffer: Option<u64>,
}

impl NodeKind for LimitsKind {
    /// Identifier used to select this node kind in the configuration.
    fn name(&self) -> &'static str {
        "limits"
    }

    /// Deserializes `config` into a [`Limits`] node.
    ///
    /// # Errors
    /// Fails if `config` does not match the shape of [`Limits`].
    fn instanciate(&self, config: serde_yml::Value) -> Result<Arc<dyn Node>> {
        let node: Limits = serde_yml::from_value(config)?;
        Ok(Arc::new(node))
    }
}

impl Node for Limits {
    /// Wraps the request body with the request limits, forwards the request to `next`, and
    /// wraps the body of the returned response with the response limits.
    ///
    /// Errors from `next` are propagated unchanged.
    fn handle<'a>(
        &'a self,
        context: &'a mut NodeContext,
        request: NodeRequest,
    ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
        Box::pin(async move {
            let request_limits = self.request.clone();
            let limited_request =
                request.map(move |body| limit_body(body, request_limits).boxed());

            let response = self.next.handle(context, limited_request).await?;

            let response_limits = self.response.clone();
            Ok(response.map(move |body| limit_body(body, response_limits).boxed()))
        })
    }
}

fn limit_body(
    inner: BoxBody<Bytes, ServiceError>,
    config: LimitParam,
) -> impl Body<Data = Bytes, Error = ServiceError> {
    let mut stream = inner.into_data_stream();
    let mut state = State::new(config);
    StreamBody::new(async_stream::stream! {
        while let Some(k) = stream.next().await {
            match k {
                Ok(data) => {
                    if state.update(data.len()).await {
                        yield Ok(Frame::data(data))
                    } else {
                        break;
                    }
                },
                Err(e) => yield Err(e)
            }
        }
    })
}

/// Per-body bookkeeping used to enforce a [`LimitParam`].
struct State {
    /// Running total of data bytes seen so far (only advanced when a size limit is set).
    size: u64,
    /// Release time of the most recently accounted chunk; serves as the rate-limit clock.
    last_read: Instant,
    /// The limits being enforced.
    config: LimitParam,
}

impl State {
    /// Creates limiter state with no bytes counted; the rate-limit clock starts now.
    pub fn new(config: LimitParam) -> Self {
        Self {
            config,
            last_read: Instant::now(),
            size: 0,
        }
    }

    /// Accounts for a data chunk of `len` bytes.
    ///
    /// Returns `false` if forwarding this chunk would push the running total above the
    /// configured `size`; the caller should then stop forwarding the body. No rate-limit delay
    /// is incurred in that case. Otherwise, if a `rate` is configured, sleeps until the chunk may
    /// be released according to that rate and returns `true`.
    pub async fn update(&mut self, len: usize) -> bool {
        let len = len as u64;

        if let Some(max_size) = self.config.size {
            // Saturate so the counter can never wrap back below the limit.
            self.size = self.size.saturating_add(len);
            // `size` is the maximum allowed, so a body of exactly `size` bytes is accepted.
            if self.size > max_size {
                return false;
            }
        }

        // A rate of 0 would divide by zero below (panicking inside the body stream);
        // it is treated as "no rate limit".
        if let Some(rate) = self.config.rate.filter(|&rate| rate > 0) {
            let buffer = Duration::from_millis(self.config.rate_buffer.unwrap_or(1000));
            let elapsed = self.last_read.elapsed();
            if elapsed > buffer {
                // Cap idle credit: move the clock forward so at most `buffer` of burst remains.
                self.last_read += elapsed - buffer;
            }
            // Time this chunk "costs" at `rate` bytes/second; u128 avoids overflow on huge chunks.
            let nanos = u128::from(len) * 1_000_000_000 / u128::from(rate);
            self.last_read += Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX));
            sleep_until(self.last_read.into()).await;
        }

        true
    }
}