diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/modules/auth/openid.rs | 4 | ||||
-rw-r--r-- | src/modules/cache.rs | 103 | ||||
-rw-r--r-- | src/modules/cgi.rs | 22 | ||||
-rw-r--r-- | src/modules/debug.rs | 2 | ||||
-rw-r--r-- | src/modules/loadbalance.rs | 62 | ||||
-rw-r--r-- | src/modules/mod.rs | 4 |
6 files changed, 191 insertions, 6 deletions
diff --git a/src/modules/auth/openid.rs b/src/modules/auth/openid.rs index a8d9d6e..7253ba8 100644 --- a/src/modules/auth/openid.rs +++ b/src/modules/auth/openid.rs @@ -42,6 +42,7 @@ pub struct OpenIDAuth { client_id: String, authorize_endpoint: String, token_endpoint: String, + scope: String, next: DynNode, } @@ -153,13 +154,14 @@ id_token={id_token:?}"# let redirect_uri = redirect_uri(&request)?.to_string(); let uri = format!( - "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid magic", + "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid {}", self.authorize_endpoint, utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC), utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC), hex::encode(verif_cipher), utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), base64::engine::general_purpose::URL_SAFE.encode(chal).trim_end_matches("="), + utf8_percent_encode(&self.scope, NON_ALPHANUMERIC), ); info!("redirect {uri:?}"); let mut resp = diff --git a/src/modules/cache.rs b/src/modules/cache.rs new file mode 100644 index 0000000..b07bbbb --- /dev/null +++ b/src/modules/cache.rs @@ -0,0 +1,103 @@ +//! Response caching module +//! +//! Considerations: +//! - check cache header +//! - ignore responses that get too large +//! - ignore requests with body (or too large body) +//! - LRU cache pruning +//! - different backends: +//! - in memory (HashMap) +//! - on disk (redb? filesystem?) +//! - external db? +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use anyhow::Result; +use bytes::Bytes; +use http::Response; +use http_body_util::{BodyExt, Full}; +use serde::Deserialize; +use serde_yaml::Value; +use sha2::{Digest, Sha256}; +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; +use tokio::sync::RwLock; + +pub struct CacheKind; + +#[derive(Deserialize)] +struct CacheConfig { + next: DynNode, +} + +struct Cache { + entries: RwLock<HashMap<[u8; 32], Response<Bytes>>>, + config: CacheConfig, +} + +impl NodeKind for CacheKind { + fn name(&self) -> &'static str { + "cache" + } + fn instanciate(&self, config: Value) -> Result<Arc<dyn Node>> { + let config = serde_yaml::from_value::<CacheConfig>(config)?; + Ok(Arc::new(Cache { + config, + entries: HashMap::new().into(), + })) + } +} +impl Node for Cache { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + // not very fast + let mut hasher = Sha256::new(); + hasher.update(request.method().as_str().len().to_be_bytes()); + hasher.update(request.method().as_str()); + hasher.update(request.uri().path().len().to_be_bytes()); + hasher.update(request.uri().path()); + hasher.update(if request.uri().query().is_some() { + [1] + } else { + [0] + }); + if let Some(q) = request.uri().query() { + hasher.update(q.len().to_be_bytes()); + hasher.update(q); + } + hasher.update(request.headers().len().to_be_bytes()); + for (k, v) in request.headers() { + hasher.update(k.as_str().len().to_be_bytes()); + hasher.update(v.as_bytes().len().to_be_bytes()); + hasher.update(k); + hasher.update(v); + } + let key: [u8; 32] = hasher.finalize().try_into().unwrap(); + + if let Some(resp) = self.entries.read().await.get(&key) { + return Ok(resp + .to_owned() + .map(|b| Full::new(b).map_err(|e| match e {}).boxed())); + } + + let response = self.config.next.handle(context, request).await?; + + let h = response.headers().to_owned(); + let s = response.status().to_owned(); + let body = response.collect().await?.to_bytes(); + + let mut r1 = Response::new(Full::new(body.clone()).map_err(|e| match e {}).boxed()); + *r1.headers_mut() = h.clone(); + *r1.status_mut() = s.clone(); + + let mut r2 = Response::new(body); + *r2.headers_mut() = h; + *r2.status_mut() = s; + self.entries.write().await.insert(key, r2); + + Ok(r1) + }) + } +} diff --git a/src/modules/cgi.rs b/src/modules/cgi.rs index b6f1033..9d9372f 100644 --- a/src/modules/cgi.rs +++ b/src/modules/cgi.rs @@ -2,7 +2,7 @@ use super::{Node, NodeKind, NodeResponse}; use crate::error::ServiceError; use anyhow::{anyhow, Result}; use futures::TryStreamExt; -use http_body_util::{combinators::BoxBody, StreamBody}; +use http_body_util::{combinators::BoxBody, BodyExt, StreamBody}; use hyper::{ body::Frame, header::{HeaderName, HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}, @@ -10,12 +10,15 @@ use hyper::{ }; use serde::Deserialize; use serde_yaml::Value; -use std::{future::Future, path::PathBuf, pin::Pin, process::Stdio, str::FromStr, sync::Arc}; +use std::{ + future::Future, io::ErrorKind, path::PathBuf, pin::Pin, process::Stdio, str::FromStr, sync::Arc, +}; use tokio::{ - io::{AsyncBufReadExt, BufReader}, + io::{copy, AsyncBufReadExt, BufReader, BufWriter}, process::Command, + spawn, }; -use tokio_util::io::ReaderStream; +use tokio_util::io::{ReaderStream, StreamReader}; use users::get_user_by_name; pub struct CgiKind; @@ -107,6 +110,17 @@ impl Node for Cgi { let mut child = command.spawn()?; let mut stdout = BufReader::new(child.stdout.take().unwrap()); + let mut stdin = BufWriter::new(child.stdin.take().unwrap()); + + // TODO prevent abuse + let mut body = StreamReader::new( + request + .into_body() + .into_data_stream() + .map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "asd")), + ); + spawn(async move { copy(&mut body, &mut stdin).await }); + let mut line = String::new(); let mut response = Response::new(()); loop { diff --git a/src/modules/debug.rs b/src/modules/debug.rs index 3ab03ec..04f9806 100644 --- a/src/modules/debug.rs +++ b/src/modules/debug.rs @@ -29,7 +29,7 @@ impl Node for Debug { ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { Box::pin(async move { let s = format!( - "address: {:?}\nverion: {:?}\nuri: {:?}\nheaders: {:#?}", + "address: {:?}\nversion: {:?}\nuri: {:?}\nheaders: {:#?}", context.addr, request.version(), request.uri(), diff --git a/src/modules/loadbalance.rs b/src/modules/loadbalance.rs new file mode 100644 index 0000000..5358b03 --- /dev/null +++ b/src/modules/loadbalance.rs @@ -0,0 +1,62 @@ +//! Load balancing module +//! +//! Given a set of handlers, the handler that is the least busy will handle the next request. +//! Current implementation does not scale well for many handlers. +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use anyhow::Result; +use serde::Deserialize; +use serde_yaml::Value; +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +pub struct LoadBalanceKind; + +#[derive(Deserialize)] +struct LoadBalanceConfig(Vec<DynNode>); + +struct LoadBalance { + load: Vec<AtomicUsize>, + config: LoadBalanceConfig, +} + +impl NodeKind for LoadBalanceKind { + fn name(&self) -> &'static str { + "loadbalance" + } + fn instanciate(&self, config: Value) -> Result<Arc<dyn Node>> { + let config = serde_yaml::from_value::<LoadBalanceConfig>(config)?; + Ok(Arc::new(LoadBalance { + load: config.0.iter().map(|_| AtomicUsize::new(0)).collect(), + config, + })) + } +} +impl Node for LoadBalance { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let index = self + .load + .iter() + .enumerate() + .min_by_key(|(_, k)| k.load(Ordering::Relaxed)) + .map(|(i, _)| i) + .ok_or(ServiceError::CustomStatic("zero routes to balance load"))?; + + self.load[index].fetch_add(1, Ordering::Relaxed); + let resp = self.config.0[index].handle(context, request).await; + self.load[index].fetch_sub(1, Ordering::Relaxed); + resp + }) + } +} diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 051299b..fc4d603 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -9,6 +9,7 @@ use std::{net::SocketAddr, pin::Pin, sync::Arc}; pub mod accesslog; pub mod auth; +pub mod cache; pub mod cgi; pub mod debug; pub mod error; @@ -16,6 +17,7 @@ pub mod file; pub mod files; pub mod headers; pub mod hosts; +pub mod loadbalance; pub mod paths; pub mod proxy; pub mod redirect; @@ -40,6 +42,8 @@ pub static MODULES: &[&dyn NodeKind] = &[ &redirect::RedirectKind, &cgi::CgiKind, &debug::DebugKind, + &cache::CacheKind, + &loadbalance::LoadBalanceKind, ]; pub struct NodeContext { |