diff options
Diffstat (limited to 'src/modules/cache.rs')
-rw-r--r-- | src/modules/cache.rs | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/src/modules/cache.rs b/src/modules/cache.rs new file mode 100644 index 0000000..b07bbbb --- /dev/null +++ b/src/modules/cache.rs @@ -0,0 +1,103 @@ +//! Response caching module +//! +//! Considerations: +//! - check cache header +//! - ignore responses that get too large +//! - ignore requests with body (or too large body) +//! - LRU cache pruning +//! - different backends: +//! - in memory (HashMap) +//! - on disk (redb? filesystem?) +//! - external db? +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use anyhow::Result; +use bytes::Bytes; +use http::Response; +use http_body_util::{BodyExt, Full}; +use serde::Deserialize; +use serde_yaml::Value; +use sha2::{Digest, Sha256}; +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; +use tokio::sync::RwLock; + +pub struct CacheKind; + +#[derive(Deserialize)] +struct CacheConfig { + next: DynNode, +} + +struct Cache { + entries: RwLock<HashMap<[u8; 32], Response<Bytes>>>, + config: CacheConfig, +} + +impl NodeKind for CacheKind { + fn name(&self) -> &'static str { + "cache" + } + fn instanciate(&self, config: Value) -> Result<Arc<dyn Node>> { + let config = serde_yaml::from_value::<CacheConfig>(config)?; + Ok(Arc::new(Cache { + config, + entries: HashMap::new().into(), + })) + } +} +impl Node for Cache { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + // not very fast + let mut hasher = Sha256::new(); + hasher.update(request.method().as_str().len().to_be_bytes()); + hasher.update(request.method().as_str()); + hasher.update(request.uri().path().len().to_be_bytes()); + hasher.update(request.uri().path()); + hasher.update(if request.uri().query().is_some() { + [1] + } else { + [0] + }); + if let Some(q) = request.uri().query() { + hasher.update(q.len().to_be_bytes()); + hasher.update(q); + } + hasher.update(request.headers().len().to_be_bytes()); + for (k, v) in request.headers() { + hasher.update(k.as_str().len().to_be_bytes()); + hasher.update(v.as_bytes().len().to_be_bytes()); + hasher.update(k); + hasher.update(v); + } + let key: [u8; 32] = hasher.finalize().try_into().unwrap(); + + if let Some(resp) = self.entries.read().await.get(&key) { + return Ok(resp + .to_owned() + .map(|b| Full::new(b).map_err(|e| match e {}).boxed())); + } + + let response = self.config.next.handle(context, request).await?; + + let h = response.headers().to_owned(); + let s = response.status().to_owned(); + let body = response.collect().await?.to_bytes(); + + let mut r1 = Response::new(Full::new(body.clone()).map_err(|e| match e {}).boxed()); + *r1.headers_mut() = h.clone(); + *r1.status_mut() = s.clone(); + + let mut r2 = Response::new(body); + *r2.headers_mut() = h; + *r2.status_mut() = s; + self.entries.write().await.insert(key, r2); + + Ok(r1) + }) + } +} |