summaryrefslogtreecommitdiff
path: root/src/modules
diff options
context:
space:
mode:
Diffstat (limited to 'src/modules')
-rw-r--r--src/modules/auth/mod.rs1
-rw-r--r--src/modules/auth/openid.rs177
-rw-r--r--src/modules/mod.rs1
-rw-r--r--src/modules/proxy.rs9
4 files changed, 184 insertions, 4 deletions
diff --git a/src/modules/auth/mod.rs b/src/modules/auth/mod.rs
index 729ea42..80d02ad 100644
--- a/src/modules/auth/mod.rs
+++ b/src/modules/auth/mod.rs
@@ -12,6 +12,7 @@ use std::{collections::HashMap, fmt, fs::read_to_string};
pub mod basic;
pub mod cookie;
+pub mod openid;
struct Credentials {
wrong_user: PasswordHashString,
diff --git a/src/modules/auth/openid.rs b/src/modules/auth/openid.rs
new file mode 100644
index 0000000..979649e
--- /dev/null
+++ b/src/modules/auth/openid.rs
@@ -0,0 +1,177 @@
+use crate::{
+ config::DynNode,
+ error::ServiceError,
+ modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse},
+};
+use base64::Engine;
+use bytes::Buf;
+use futures::Future;
+use http::{
+ header::{CONTENT_TYPE, HOST},
+ uri::{Authority, Parts, PathAndQuery, Scheme},
+ HeaderValue, Method, Request, Uri,
+};
+use http_body_util::{combinators::BoxBody, BodyExt};
+use hyper::{Response, StatusCode};
+use hyper_util::rt::TokioIo;
+use log::info;
+use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC};
+use serde::Deserialize;
+use serde_yaml::Value;
+use sha2::{Digest, Sha256};
+use std::{io::Read, pin::Pin, str::FromStr, sync::Arc};
+use tokio::net::TcpStream;
+
+pub struct OpenIDAuthKind;
+impl NodeKind for OpenIDAuthKind {
+ fn name(&self) -> &'static str {
+ "openid_auth"
+ }
+ fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> {
+ Ok(Arc::new(serde_yaml::from_value::<OpenIDAuth>(config)?))
+ }
+}
+
+#[derive(Deserialize)]
+pub struct OpenIDAuth {
+ client_id: String,
+ provider: String,
+ next: DynNode,
+}
+
+impl Node for OpenIDAuth {
+ fn handle<'a>(
+ &'a self,
+ context: &'a mut NodeContext,
+ request: NodeRequest,
+ ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
+ Box::pin(async move {
+ if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" {
+ let mut state = None;
+ let mut code = None;
+ for entry in request.uri().query().unwrap_or_default().split("&") {
+ let (key, value) = entry.split_once("=").unwrap_or((entry, ""));
+ match key {
+ "state" => {
+ state = Some(
+ percent_decode(value.as_bytes())
+ .decode_utf8_lossy()
+ .to_string(),
+ )
+ }
+ "code" => code = Some(value.to_owned()),
+ _ => (),
+ }
+ }
+ let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?;
+ let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?;
+
+ let mut redirect_uri = Parts::default();
+ redirect_uri.scheme = Some(Scheme::HTTP);
+ redirect_uri.path_and_query = Some(PathAndQuery::from_str(&state).unwrap());
+ redirect_uri.authority = Authority::from_str(
+ request
+ .headers()
+ .get(HOST)
+ .ok_or(ServiceError::InvalidHeader)?
+ .to_str()?,
+ )
+ .ok();
+ let redirect_uri = Uri::from_parts(redirect_uri)
+ .map_err(|_| ServiceError::InvalidUri)?
+ .to_string();
+
+ token_request(&self.provider, &self.client_id, &redirect_uri, &code).await?;
+
+ let mut r = Response::new(BoxBody::<_, ServiceError>::new(
+ format!("state={state:?}\ncode={code:?}").map_err(|_| unreachable!()),
+ ));
+ r.headers_mut()
+ .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
+ return Ok(r);
+ } else {
+ let mut redirect_uri = Parts::default();
+ redirect_uri.scheme = Some(Scheme::HTTP);
+ redirect_uri.path_and_query =
+ Some(PathAndQuery::from_static("/_gnix_auth_callback"));
+ redirect_uri.authority = Authority::from_str(
+ request
+ .headers()
+ .get(HOST)
+ .ok_or(ServiceError::InvalidHeader)?
+ .to_str()?,
+ )
+ .ok();
+ let redirect_uri = Uri::from_parts(redirect_uri)
+ .map_err(|_| ServiceError::InvalidUri)?
+ .to_string();
+
+ let chal: Vec<u8> = {
+ let mut hasher = Sha256::new();
+ hasher.update(r"testvalue");
+ hasher.finalize().to_vec()
+ };
+
+ let uri = format!(
+ "{}/authorize?client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid profile email",
+ self.provider,
+ utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC),
+ utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC),
+ utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC),
+ base64::engine::general_purpose::URL_SAFE.encode(chal),
+ );
+ info!("redirect {uri:?}");
+ let mut resp =
+ Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed());
+ *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;
+ resp.headers_mut().insert(
+ "Location",
+ HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?,
+ );
+ Ok(resp)
+ }
+ })
+ }
+}
+
+async fn token_request(
+ provider: &str,
+ client_id: &str,
+ redirect_uri: &str,
+ code: &str,
+) -> Result<(), ServiceError> {
+ let url = Uri::from_str(&format!("{provider}/token")).unwrap();
+ let body = format!(
+ "client_id={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code",
+ utf8_percent_encode(client_id, NON_ALPHANUMERIC),
+ utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC),
+ utf8_percent_encode(code, NON_ALPHANUMERIC),
+ "testvalue"
+ );
+ let authority = url.authority().unwrap().clone();
+
+ let stream = TcpStream::connect(authority.as_str()).await.unwrap();
+ let io = TokioIo::new(stream);
+
+ let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
+ tokio::task::spawn(async move {
+ if let Err(err) = conn.await {
+ println!("Connection failed: {:?}", err);
+ }
+ });
+
+ let req = Request::builder()
+ .method(Method::POST)
+ .uri(url)
+ .header(HOST, authority.as_str())
+ .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
+ .body(body)
+ .unwrap();
+
+ let res = sender.send_request(req).await.unwrap();
+ let body = res.collect().await.unwrap().aggregate();
+ let mut buf = String::new();
+ body.reader().read_to_string(&mut buf).unwrap();
+ eprintln!("{buf:?}");
+ Ok(())
+}
diff --git a/src/modules/mod.rs b/src/modules/mod.rs
index 9840935..051299b 100644
--- a/src/modules/mod.rs
+++ b/src/modules/mod.rs
@@ -27,6 +27,7 @@ pub type NodeResponse = Response<BoxBody<Bytes, ServiceError>>;
pub static MODULES: &[&dyn NodeKind] = &[
&auth::basic::HttpBasicAuthKind,
&auth::cookie::CookieAuthKind,
+ &auth::openid::OpenIDAuthKind,
&proxy::ProxyKind,
&hosts::HostsKind,
&paths::PathsKind,
diff --git a/src/modules/proxy.rs b/src/modules/proxy.rs
index ce72f65..925e456 100644
--- a/src/modules/proxy.rs
+++ b/src/modules/proxy.rs
@@ -1,8 +1,9 @@
use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
-use crate::{helper::TokioIo, ServiceError};
+use crate::ServiceError;
use futures::Future;
use http_body_util::BodyExt;
use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode};
+use hyper_util::rt::TokioIo;
use log::{debug, warn};
use serde::Deserialize;
use serde_yaml::Value;
@@ -42,7 +43,7 @@ impl Node for Proxy {
let _limit_guard = context.state.l_outgoing.try_acquire()?;
debug!("\tforwarding to {}", self.backend);
let mut resp = {
- let client_stream = TokioIo(
+ let client_stream = TokioIo::new(
TcpStream::connect(self.backend)
.await
.map_err(|_| ServiceError::CantConnect)?,
@@ -75,8 +76,8 @@ impl Node for Proxy {
(Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
debug!("upgrade successful");
match tokio::io::copy_bidirectional(
- &mut TokioIo(upgraded_downstream),
- &mut TokioIo(upgraded_upstream),
+ &mut TokioIo::new(upgraded_downstream),
+ &mut TokioIo::new(upgraded_upstream),
)
.await
{