aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/modules/auth/openid.rs4
-rw-r--r--src/modules/cache.rs103
-rw-r--r--src/modules/cgi.rs22
-rw-r--r--src/modules/debug.rs2
-rw-r--r--src/modules/loadbalance.rs62
-rw-r--r--src/modules/mod.rs4
6 files changed, 191 insertions, 6 deletions
diff --git a/src/modules/auth/openid.rs b/src/modules/auth/openid.rs
index a8d9d6e..7253ba8 100644
--- a/src/modules/auth/openid.rs
+++ b/src/modules/auth/openid.rs
@@ -42,6 +42,7 @@ pub struct OpenIDAuth {
client_id: String,
authorize_endpoint: String,
token_endpoint: String,
+ scope: String,
next: DynNode,
}
@@ -153,13 +154,14 @@ id_token={id_token:?}"#
let redirect_uri = redirect_uri(&request)?.to_string();
let uri = format!(
- "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid magic",
+ "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid {}",
self.authorize_endpoint,
utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC),
utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC),
hex::encode(verif_cipher),
utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC),
base64::engine::general_purpose::URL_SAFE.encode(chal).trim_end_matches("="),
+ utf8_percent_encode(&self.scope, NON_ALPHANUMERIC),
);
info!("redirect {uri:?}");
let mut resp =
diff --git a/src/modules/cache.rs b/src/modules/cache.rs
new file mode 100644
index 0000000..b07bbbb
--- /dev/null
+++ b/src/modules/cache.rs
@@ -0,0 +1,103 @@
+//! Response caching module
+//!
+//! Considerations:
+//! - check cache header
+//! - ignore responses that get too large
+//! - ignore requests with body (or too large body)
+//! - LRU cache pruning
+//! - different backends:
+//! - in memory (HashMap)
+//! - on disk (redb? filesystem?)
+//! - external db?
+use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
+use crate::{config::DynNode, error::ServiceError};
+use anyhow::Result;
+use bytes::Bytes;
+use http::Response;
+use http_body_util::{BodyExt, Full};
+use serde::Deserialize;
+use serde_yaml::Value;
+use sha2::{Digest, Sha256};
+use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
+use tokio::sync::RwLock;
+
+pub struct CacheKind;
+
+#[derive(Deserialize)]
+struct CacheConfig {
+ next: DynNode,
+}
+
+struct Cache {
+ entries: RwLock<HashMap<[u8; 32], Response<Bytes>>>,
+ config: CacheConfig,
+}
+
+impl NodeKind for CacheKind {
+ fn name(&self) -> &'static str {
+ "cache"
+ }
+ fn instanciate(&self, config: Value) -> Result<Arc<dyn Node>> {
+ let config = serde_yaml::from_value::<CacheConfig>(config)?;
+ Ok(Arc::new(Cache {
+ config,
+ entries: HashMap::new().into(),
+ }))
+ }
+}
+impl Node for Cache {
+ fn handle<'a>(
+ &'a self,
+ context: &'a mut NodeContext,
+ request: NodeRequest,
+ ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
+ Box::pin(async move {
+ // not very fast
+ let mut hasher = Sha256::new();
+ hasher.update(request.method().as_str().len().to_be_bytes());
+ hasher.update(request.method().as_str());
+ hasher.update(request.uri().path().len().to_be_bytes());
+ hasher.update(request.uri().path());
+ hasher.update(if request.uri().query().is_some() {
+ [1]
+ } else {
+ [0]
+ });
+ if let Some(q) = request.uri().query() {
+ hasher.update(q.len().to_be_bytes());
+ hasher.update(q);
+ }
+ hasher.update(request.headers().len().to_be_bytes());
+ for (k, v) in request.headers() {
+ hasher.update(k.as_str().len().to_be_bytes());
+ hasher.update(v.as_bytes().len().to_be_bytes());
+ hasher.update(k);
+ hasher.update(v);
+ }
+ let key: [u8; 32] = hasher.finalize().try_into().unwrap();
+
+ if let Some(resp) = self.entries.read().await.get(&key) {
+ return Ok(resp
+ .to_owned()
+ .map(|b| Full::new(b).map_err(|e| match e {}).boxed()));
+ }
+
+ let response = self.config.next.handle(context, request).await?;
+
+ let h = response.headers().to_owned();
+ let s = response.status().to_owned();
+ let body = response.collect().await?.to_bytes();
+
+ let mut r1 = Response::new(Full::new(body.clone()).map_err(|e| match e {}).boxed());
+ *r1.headers_mut() = h.clone();
+ *r1.status_mut() = s.clone();
+
+ let mut r2 = Response::new(body);
+ *r2.headers_mut() = h;
+ *r2.status_mut() = s;
+ self.entries.write().await.insert(key, r2);
+
+ Ok(r1)
+ })
+ }
+}
diff --git a/src/modules/cgi.rs b/src/modules/cgi.rs
index b6f1033..9d9372f 100644
--- a/src/modules/cgi.rs
+++ b/src/modules/cgi.rs
@@ -2,7 +2,7 @@ use super::{Node, NodeKind, NodeResponse};
use crate::error::ServiceError;
use anyhow::{anyhow, Result};
use futures::TryStreamExt;
-use http_body_util::{combinators::BoxBody, StreamBody};
+use http_body_util::{combinators::BoxBody, BodyExt, StreamBody};
use hyper::{
body::Frame,
header::{HeaderName, HeaderValue, CONTENT_LENGTH, CONTENT_TYPE},
@@ -10,12 +10,15 @@ use hyper::{
};
use serde::Deserialize;
use serde_yaml::Value;
-use std::{future::Future, path::PathBuf, pin::Pin, process::Stdio, str::FromStr, sync::Arc};
+use std::{
+ future::Future, io::ErrorKind, path::PathBuf, pin::Pin, process::Stdio, str::FromStr, sync::Arc,
+};
use tokio::{
- io::{AsyncBufReadExt, BufReader},
+ io::{copy, AsyncBufReadExt, BufReader, BufWriter},
process::Command,
+ spawn,
};
-use tokio_util::io::ReaderStream;
+use tokio_util::io::{ReaderStream, StreamReader};
use users::get_user_by_name;
pub struct CgiKind;
@@ -107,6 +110,17 @@ impl Node for Cgi {
let mut child = command.spawn()?;
let mut stdout = BufReader::new(child.stdout.take().unwrap());
+ let mut stdin = BufWriter::new(child.stdin.take().unwrap());
+
+ // TODO prevent abuse
+ let mut body = StreamReader::new(
+ request
+ .into_body()
+ .into_data_stream()
+ .map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "asd")),
+ );
+ spawn(async move { copy(&mut body, &mut stdin).await });
+
let mut line = String::new();
let mut response = Response::new(());
loop {
diff --git a/src/modules/debug.rs b/src/modules/debug.rs
index 3ab03ec..04f9806 100644
--- a/src/modules/debug.rs
+++ b/src/modules/debug.rs
@@ -29,7 +29,7 @@ impl Node for Debug {
) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
Box::pin(async move {
let s = format!(
- "address: {:?}\nverion: {:?}\nuri: {:?}\nheaders: {:#?}",
+ "address: {:?}\nversion: {:?}\nuri: {:?}\nheaders: {:#?}",
context.addr,
request.version(),
request.uri(),
diff --git a/src/modules/loadbalance.rs b/src/modules/loadbalance.rs
new file mode 100644
index 0000000..5358b03
--- /dev/null
+++ b/src/modules/loadbalance.rs
@@ -0,0 +1,62 @@
+//! Load balancing module
+//!
+//! Given a set of handlers, the handler that is the least busy will handle the next request.
+//! Current implementation does not scale well for many handlers.
+use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
+use crate::{config::DynNode, error::ServiceError};
+use anyhow::Result;
+use serde::Deserialize;
+use serde_yaml::Value;
+use std::{
+ future::Future,
+ pin::Pin,
+ sync::{
+ atomic::{AtomicUsize, Ordering},
+ Arc,
+ },
+};
+
+pub struct LoadBalanceKind;
+
+#[derive(Deserialize)]
+struct LoadBalanceConfig(Vec<DynNode>);
+
+struct LoadBalance {
+ load: Vec<AtomicUsize>,
+ config: LoadBalanceConfig,
+}
+
+impl NodeKind for LoadBalanceKind {
+ fn name(&self) -> &'static str {
+ "loadbalance"
+ }
+ fn instanciate(&self, config: Value) -> Result<Arc<dyn Node>> {
+ let config = serde_yaml::from_value::<LoadBalanceConfig>(config)?;
+ Ok(Arc::new(LoadBalance {
+ load: config.0.iter().map(|_| AtomicUsize::new(0)).collect(),
+ config,
+ }))
+ }
+}
+impl Node for LoadBalance {
+ fn handle<'a>(
+ &'a self,
+ context: &'a mut NodeContext,
+ request: NodeRequest,
+ ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
+ Box::pin(async move {
+ let index = self
+ .load
+ .iter()
+ .enumerate()
+ .min_by_key(|(_, k)| k.load(Ordering::Relaxed))
+ .map(|(i, _)| i)
+ .ok_or(ServiceError::CustomStatic("zero routes to balance load"))?;
+
+ self.load[index].fetch_add(1, Ordering::Relaxed);
+ let resp = self.config.0[index].handle(context, request).await;
+ self.load[index].fetch_sub(1, Ordering::Relaxed);
+ resp
+ })
+ }
+}
diff --git a/src/modules/mod.rs b/src/modules/mod.rs
index 051299b..fc4d603 100644
--- a/src/modules/mod.rs
+++ b/src/modules/mod.rs
@@ -9,6 +9,7 @@ use std::{net::SocketAddr, pin::Pin, sync::Arc};
pub mod accesslog;
pub mod auth;
+pub mod cache;
pub mod cgi;
pub mod debug;
pub mod error;
@@ -16,6 +17,7 @@ pub mod file;
pub mod files;
pub mod headers;
pub mod hosts;
+pub mod loadbalance;
pub mod paths;
pub mod proxy;
pub mod redirect;
@@ -40,6 +42,8 @@ pub static MODULES: &[&dyn NodeKind] = &[
&redirect::RedirectKind,
&cgi::CgiKind,
&debug::DebugKind,
+ &cache::CacheKind,
+ &loadbalance::LoadBalanceKind,
];
pub struct NodeContext {