Diffstat (limited to 'src/modules/cache.rs')
-rw-r--r--  src/modules/cache.rs  103
1 file changed, 103 insertions, 0 deletions
diff --git a/src/modules/cache.rs b/src/modules/cache.rs
new file mode 100644
index 0000000..b07bbbb
--- /dev/null
+++ b/src/modules/cache.rs
@@ -0,0 +1,103 @@
+//! Response caching module
+//!
+//! Considerations:
+//! - check cache header
+//! - ignore responses that get too large
+//! - ignore requests that have a body (or whose body is too large)
+//! - LRU cache pruning
+//! - different backends:
+//!   - in memory (HashMap)
+//!   - on disk (redb? filesystem?)
+//!   - external db?
+use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
+use crate::{config::DynNode, error::ServiceError};
+use anyhow::Result;
+use bytes::Bytes;
+use http::Response;
+use http_body_util::{BodyExt, Full};
+use serde::Deserialize;
+use serde_yaml::Value;
+use sha2::{Digest, Sha256};
+use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
+use tokio::sync::RwLock;
+
+pub struct CacheKind;
+
+#[derive(Deserialize)]
+struct CacheConfig {
+    next: DynNode,
+}
+
+struct Cache {
+    /// Cached responses, keyed by a SHA-256 hash of the request.
+    entries: RwLock<HashMap<[u8; 32], Response<Bytes>>>,
+    config: CacheConfig,
+}
+
+impl NodeKind for CacheKind {
+    fn name(&self) -> &'static str {
+        "cache"
+    }
+    fn instanciate(&self, config: Value) -> Result<Arc<dyn Node>> {
+        let config = serde_yaml::from_value::<CacheConfig>(config)?;
+        Ok(Arc::new(Cache {
+            config,
+            entries: HashMap::new().into(),
+        }))
+    }
+}
+impl Node for Cache {
+    fn handle<'a>(
+        &'a self,
+        context: &'a mut NodeContext,
+        request: NodeRequest,
+    ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
+        Box::pin(async move {
+            // Derive the cache key from the request line and headers. Every field
+            // is length-prefixed so that adjacent fields cannot run into each other
+            // and collide. Hashing each request has a cost, but keeps the key a
+            // fixed 32 bytes.
+            let mut hasher = Sha256::new();
+            hasher.update(request.method().as_str().len().to_be_bytes());
+            hasher.update(request.method().as_str());
+            hasher.update(request.uri().path().len().to_be_bytes());
+            hasher.update(request.uri().path());
+            // Distinguish "no query string" from "empty query string".
+            hasher.update(if request.uri().query().is_some() {
+                [1]
+            } else {
+                [0]
+            });
+            if let Some(q) = request.uri().query() {
+                hasher.update(q.len().to_be_bytes());
+                hasher.update(q);
+            }
+            hasher.update(request.headers().len().to_be_bytes());
+            for (k, v) in request.headers() {
+                hasher.update(k.as_str().len().to_be_bytes());
+                hasher.update(v.as_bytes().len().to_be_bytes());
+                hasher.update(k);
+                hasher.update(v);
+            }
+            let key: [u8; 32] = hasher.finalize().into();
+
+            if let Some(resp) = self.entries.read().await.get(&key) {
+                // Cache hit: rebuild a response around the stored body. Cloning
+                // `Bytes` is cheap, it only bumps a reference count.
+                let mut cached =
+                    Response::new(Full::new(resp.body().clone()).map_err(|e| match e {}).boxed());
+                *cached.status_mut() = resp.status();
+                *cached.headers_mut() = resp.headers().clone();
+                return Ok(cached);
+            }
+
+            // Cache miss: forward the request, buffer the body, store a copy of the
+            // response, and return it.
+            let response = self.config.next.handle(context, request).await?;
+
+            let headers = response.headers().to_owned();
+            let status = response.status();
+            let body = response.collect().await?.to_bytes();
+
+            // The response handed back to the caller.
+            let mut reply =
+                Response::new(Full::new(body.clone()).map_err(|e| match e {}).boxed());
+            *reply.status_mut() = status;
+            *reply.headers_mut() = headers.clone();
+
+            // The copy kept in the cache, with the body stored as plain `Bytes`.
+            let mut stored = Response::new(body);
+            *stored.status_mut() = status;
+            *stored.headers_mut() = headers;
+            self.entries.write().await.insert(key, stored);
+
+            Ok(reply)
+        })
+    }
+}
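
The module's doc comment lists "check cache header" as an open consideration; as committed, every response is cached unconditionally. A minimal sketch of such a check, where the response_is_cacheable helper and the exact directives it rejects are illustrative assumptions rather than part of this commit:

    use http::{header::CACHE_CONTROL, HeaderMap};

    /// Returns false when the upstream response explicitly opts out of caching.
    fn response_is_cacheable(headers: &HeaderMap) -> bool {
        match headers.get(CACHE_CONTROL).and_then(|v| v.to_str().ok()) {
            Some(v) => !v.contains("no-store") && !v.contains("private"),
            None => true,
        }
    }

The handler could call this on the buffered response headers and skip the entries.write().await.insert(...) step when it returns false.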
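
LRU pruning and the size limits are likewise still on the to-do list; the HashMap grows without bound. One way the in-memory backend could bound itself is sketched below, with assumed MAX_ENTRIES and MAX_BODY_BYTES caps and a simplified Vec<u8> body type (the real map lives behind a tokio RwLock and stores Response<Bytes> values):

    use std::collections::HashMap;

    const MAX_BODY_BYTES: usize = 1 << 20; // assumed cap: skip bodies larger than 1 MiB
    const MAX_ENTRIES: usize = 1024; // assumed cap on the number of cached responses

    struct BoundedCache {
        entries: HashMap<[u8; 32], Vec<u8>>, // bodies simplified to raw bytes for the sketch
        order: Vec<[u8; 32]>,                // front = least recently used key
    }

    impl BoundedCache {
        fn get(&mut self, key: &[u8; 32]) -> Option<&Vec<u8>> {
            if let Some(pos) = self.order.iter().position(|k| k == key) {
                // Mark the key as most recently used by moving it to the back.
                let k = self.order.remove(pos);
                self.order.push(k);
            }
            self.entries.get(key)
        }

        fn insert(&mut self, key: [u8; 32], body: Vec<u8>) {
            if body.len() > MAX_BODY_BYTES {
                return; // "ignore responses that get too large"
            }
            while self.entries.len() >= MAX_ENTRIES {
                // Evict the least recently used entry to make room.
                let oldest = self.order.remove(0);
                self.entries.remove(&oldest);
            }
            self.order.retain(|k| k != &key);
            self.order.push(key);
            self.entries.insert(key, body);
        }
    }

The linear scans over order keep the sketch short; a production version would more likely use an ordered map or an existing LRU structure to get O(1) bookkeeping.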