summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2024-08-21 23:30:19 +0200
committermetamuffin <metamuffin@disroot.org>2024-08-21 23:30:19 +0200
commit4555531d2cb4856d6216907a22aac6797d097ad2 (patch)
tree9011bcdb37b8a314d4772bae905c4c920c88e565
parent0cd6ddac0833c7fc6d2fb8511073132006072148 (diff)
downloadgnix-4555531d2cb4856d6216907a22aac6797d097ad2.tar
gnix-4555531d2cb4856d6216907a22aac6797d097ad2.tar.bz2
gnix-4555531d2cb4856d6216907a22aac6797d097ad2.tar.zst
first steps torwards openid auth
-rw-r--r--Cargo.lock21
-rw-r--r--Cargo.toml1
-rw-r--r--readme.md9
-rw-r--r--src/error.rs4
-rw-r--r--src/helper.rs138
-rw-r--r--src/main.rs8
-rw-r--r--src/modules/auth/mod.rs1
-rw-r--r--src/modules/auth/openid.rs177
-rw-r--r--src/modules/mod.rs1
-rw-r--r--src/modules/proxy.rs9
10 files changed, 222 insertions, 147 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 98a98e9..a297a03 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -332,6 +332,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0"
[[package]]
+name = "const-oid"
+version = "0.10.0-rc.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9adcf94f05e094fca3005698822ec791cb4433ced416afda1c5ca3b8dfc05a2f"
+
+[[package]]
name = "cpufeatures"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -357,7 +363,9 @@ version = "0.2.0-rc.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0b8ce8218c97789f16356e7896b3714f26c2ee1079b79c0b7ae7064bb9089fa"
dependencies = [
+ "getrandom",
"hybrid-array",
+ "rand_core 0.6.4",
]
[[package]]
@@ -386,6 +394,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf2e3d6615d99707295a9673e889bf363a04b2a466bd320c65a72536f7577379"
dependencies = [
"block-buffer 0.11.0-rc.1",
+ "const-oid",
"crypto-common 0.2.0-rc.1",
"subtle",
]
@@ -609,6 +618,7 @@ dependencies = [
"rustls-webpki",
"serde",
"serde_yaml",
+ "sha2",
"thiserror",
"tokio",
"tokio-rustls",
@@ -1371,6 +1381,17 @@ dependencies = [
]
[[package]]
+name = "sha2"
+version = "0.11.0-pre.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "540c0893cce56cdbcfebcec191ec8e0f470dd1889b6e7a0b503e310a94a168f5"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest 0.11.0-pre.9",
+]
+
+[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 20efb47..6c32a9d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -52,6 +52,7 @@ mime_guess = "2.0.5"
# Crypto for authentificating clients
aes-gcm-siv = "0.11.1"
argon2 = "0.6.0-pre.1"
+sha2 = "0.11.0-pre.4"
rand = "0.9.0-alpha.2"
# Other helpers and stuff
diff --git a/readme.md b/readme.md
index 5739199..1403df0 100644
--- a/readme.md
+++ b/readme.md
@@ -212,6 +212,15 @@ Login credentials for `cookie_auth` and `http_basic_auth` are supplied as either
an object mapping usernames to PHC strings or a file path pointing to a file
that contains that map in YALM format. Currently only `argon2id` is supported.
+### Additional Notes
+
+Internally gnix processes requests as they would be sent in HTTP/1.1. HTTP/2 is
+translated on arrival.
+
+Paths matching `/_gnix*` might be used internally in gnix for purposes like
+OpenID callback or login action endpoints. I hope your application doesn't rely
+on using them for itself.
+
## License
AGPL-3.0-only; see [COPYING](./COPYING)
diff --git a/src/error.rs b/src/error.rs
index 636f226..c7290af 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -26,6 +26,8 @@ pub enum ServiceError {
BadUtf8(#[from] std::str::Utf8Error),
#[error("bad utf8")]
BadUtf82(#[from] std::string::FromUtf8Error),
+ #[error("bad utf8")]
+ BadUtf83(#[from] http::header::ToStrError),
#[error("bad path")]
BadPath,
#[error("bad auth")]
@@ -36,6 +38,8 @@ pub enum ServiceError {
UpgradeFailed,
#[error("{0}")]
Custom(String),
+ #[error("{0}")]
+ CustomStatic(&'static str),
#[error("parse int error: {0}")]
ParseIntError(#[from] std::num::ParseIntError),
#[error("invalid header")]
diff --git a/src/helper.rs b/src/helper.rs
deleted file mode 100644
index 0daa3d5..0000000
--- a/src/helper.rs
+++ /dev/null
@@ -1,138 +0,0 @@
-// From https://github.com/hyperium/hyper/blob/master/benches/support/tokiort.rs
-use pin_project::pin_project;
-use std::{
- pin::Pin,
- task::{Context, Poll},
-};
-
-#[pin_project]
-#[derive(Debug)]
-pub struct TokioIo<T>(#[pin] pub T);
-
-impl<T> hyper::rt::Read for TokioIo<T>
-where
- T: tokio::io::AsyncRead,
-{
- fn poll_read(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- mut buf: hyper::rt::ReadBufCursor<'_>,
- ) -> Poll<Result<(), std::io::Error>> {
- let n = unsafe {
- let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
- match tokio::io::AsyncRead::poll_read(self.project().0, cx, &mut tbuf) {
- Poll::Ready(Ok(())) => tbuf.filled().len(),
- other => return other,
- }
- };
-
- unsafe {
- buf.advance(n);
- }
- Poll::Ready(Ok(()))
- }
-}
-
-impl<T> hyper::rt::Write for TokioIo<T>
-where
- T: tokio::io::AsyncWrite,
-{
- fn poll_write(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<Result<usize, std::io::Error>> {
- tokio::io::AsyncWrite::poll_write(self.project().0, cx, buf)
- }
-
- fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
- tokio::io::AsyncWrite::poll_flush(self.project().0, cx)
- }
-
- fn poll_shutdown(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Result<(), std::io::Error>> {
- tokio::io::AsyncWrite::poll_shutdown(self.project().0, cx)
- }
-
- fn is_write_vectored(&self) -> bool {
- tokio::io::AsyncWrite::is_write_vectored(&self.0)
- }
-
- fn poll_write_vectored(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- bufs: &[std::io::IoSlice<'_>],
- ) -> Poll<Result<usize, std::io::Error>> {
- tokio::io::AsyncWrite::poll_write_vectored(self.project().0, cx, bufs)
- }
-}
-
-impl<T> tokio::io::AsyncRead for TokioIo<T>
-where
- T: hyper::rt::Read,
-{
- fn poll_read(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- tbuf: &mut tokio::io::ReadBuf<'_>,
- ) -> Poll<Result<(), std::io::Error>> {
- //let init = tbuf.initialized().len();
- let filled = tbuf.filled().len();
- let sub_filled = unsafe {
- let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
-
- match hyper::rt::Read::poll_read(self.project().0, cx, buf.unfilled()) {
- Poll::Ready(Ok(())) => buf.filled().len(),
- other => return other,
- }
- };
-
- let n_filled = filled + sub_filled;
- // At least sub_filled bytes had to have been initialized.
- let n_init = sub_filled;
- unsafe {
- tbuf.assume_init(n_init);
- tbuf.set_filled(n_filled);
- }
-
- Poll::Ready(Ok(()))
- }
-}
-
-impl<T> tokio::io::AsyncWrite for TokioIo<T>
-where
- T: hyper::rt::Write,
-{
- fn poll_write(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<Result<usize, std::io::Error>> {
- hyper::rt::Write::poll_write(self.project().0, cx, buf)
- }
-
- fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
- hyper::rt::Write::poll_flush(self.project().0, cx)
- }
-
- fn poll_shutdown(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Result<(), std::io::Error>> {
- hyper::rt::Write::poll_shutdown(self.project().0, cx)
- }
-
- fn is_write_vectored(&self) -> bool {
- hyper::rt::Write::is_write_vectored(&self.0)
- }
-
- fn poll_write_vectored(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- bufs: &[std::io::IoSlice<'_>],
- ) -> Poll<Result<usize, std::io::Error>> {
- hyper::rt::Write::poll_write_vectored(self.project().0, cx, bufs)
- }
-}
diff --git a/src/main.rs b/src/main.rs
index a13d171..0dcf01b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,7 +5,6 @@
pub mod certs;
pub mod config;
pub mod error;
-pub mod helper;
pub mod modules;
use aes_gcm_siv::{aead::generic_array::GenericArray, Aes256GcmSiv, KeyInit};
@@ -14,7 +13,6 @@ use certs::CertPool;
use config::{setup_file_watch, Config, NODE_KINDS};
use error::ServiceError;
use futures::future::try_join_all;
-use helper::TokioIo;
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{
body::Incoming,
@@ -23,7 +21,7 @@ use hyper::{
service::service_fn,
Request, Response, StatusCode, Uri,
};
-use hyper_util::rt::TokioExecutor;
+use hyper_util::rt::{TokioExecutor, TokioIo};
use log::{debug, error, info, warn, LevelFilter};
use modules::{NodeContext, MODULES};
use std::{
@@ -126,7 +124,7 @@ async fn serve_http(state: Arc<State>) -> Result<()> {
loop {
let (stream, addr) = l.accept().await.context("accepting connection")?;
debug!("connection from {addr}");
- let stream = TokioIo(stream);
+ let stream = TokioIo::new(stream);
let state = state.clone();
tokio::spawn(async move { serve_stream(state, stream, addr).await });
}
@@ -164,7 +162,7 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
tokio::task::spawn(async move {
debug!("connection from {addr}");
match tls_acceptor.accept(stream).await {
- Ok(stream) => serve_stream(state, TokioIo(stream), addr).await,
+ Ok(stream) => serve_stream(state, TokioIo::new(stream), addr).await,
Err(e) => warn!("error accepting tls: {e}"),
};
});
diff --git a/src/modules/auth/mod.rs b/src/modules/auth/mod.rs
index 729ea42..80d02ad 100644
--- a/src/modules/auth/mod.rs
+++ b/src/modules/auth/mod.rs
@@ -12,6 +12,7 @@ use std::{collections::HashMap, fmt, fs::read_to_string};
pub mod basic;
pub mod cookie;
+pub mod openid;
struct Credentials {
wrong_user: PasswordHashString,
diff --git a/src/modules/auth/openid.rs b/src/modules/auth/openid.rs
new file mode 100644
index 0000000..979649e
--- /dev/null
+++ b/src/modules/auth/openid.rs
@@ -0,0 +1,177 @@
+use crate::{
+ config::DynNode,
+ error::ServiceError,
+ modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse},
+};
+use base64::Engine;
+use bytes::Buf;
+use futures::Future;
+use http::{
+ header::{CONTENT_TYPE, HOST},
+ uri::{Authority, Parts, PathAndQuery, Scheme},
+ HeaderValue, Method, Request, Uri,
+};
+use http_body_util::{combinators::BoxBody, BodyExt};
+use hyper::{Response, StatusCode};
+use hyper_util::rt::TokioIo;
+use log::info;
+use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC};
+use serde::Deserialize;
+use serde_yaml::Value;
+use sha2::{Digest, Sha256};
+use std::{io::Read, pin::Pin, str::FromStr, sync::Arc};
+use tokio::net::TcpStream;
+
+pub struct OpenIDAuthKind;
+impl NodeKind for OpenIDAuthKind {
+ fn name(&self) -> &'static str {
+ "openid_auth"
+ }
+ fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> {
+ Ok(Arc::new(serde_yaml::from_value::<OpenIDAuth>(config)?))
+ }
+}
+
+#[derive(Deserialize)]
+pub struct OpenIDAuth {
+ client_id: String,
+ provider: String,
+ next: DynNode,
+}
+
+impl Node for OpenIDAuth {
+ fn handle<'a>(
+ &'a self,
+ context: &'a mut NodeContext,
+ request: NodeRequest,
+ ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
+ Box::pin(async move {
+ if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" {
+ let mut state = None;
+ let mut code = None;
+ for entry in request.uri().query().unwrap_or_default().split("&") {
+ let (key, value) = entry.split_once("=").unwrap_or((entry, ""));
+ match key {
+ "state" => {
+ state = Some(
+ percent_decode(value.as_bytes())
+ .decode_utf8_lossy()
+ .to_string(),
+ )
+ }
+ "code" => code = Some(value.to_owned()),
+ _ => (),
+ }
+ }
+ let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?;
+ let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?;
+
+ let mut redirect_uri = Parts::default();
+ redirect_uri.scheme = Some(Scheme::HTTP);
+ redirect_uri.path_and_query = Some(PathAndQuery::from_str(&state).unwrap());
+ redirect_uri.authority = Authority::from_str(
+ request
+ .headers()
+ .get(HOST)
+ .ok_or(ServiceError::InvalidHeader)?
+ .to_str()?,
+ )
+ .ok();
+ let redirect_uri = Uri::from_parts(redirect_uri)
+ .map_err(|_| ServiceError::InvalidUri)?
+ .to_string();
+
+ token_request(&self.provider, &self.client_id, &redirect_uri, &code).await?;
+
+ let mut r = Response::new(BoxBody::<_, ServiceError>::new(
+ format!("state={state:?}\ncode={code:?}").map_err(|_| unreachable!()),
+ ));
+ r.headers_mut()
+ .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
+ return Ok(r);
+ } else {
+ let mut redirect_uri = Parts::default();
+ redirect_uri.scheme = Some(Scheme::HTTP);
+ redirect_uri.path_and_query =
+ Some(PathAndQuery::from_static("/_gnix_auth_callback"));
+ redirect_uri.authority = Authority::from_str(
+ request
+ .headers()
+ .get(HOST)
+ .ok_or(ServiceError::InvalidHeader)?
+ .to_str()?,
+ )
+ .ok();
+ let redirect_uri = Uri::from_parts(redirect_uri)
+ .map_err(|_| ServiceError::InvalidUri)?
+ .to_string();
+
+ let chal: Vec<u8> = {
+ let mut hasher = Sha256::new();
+ hasher.update(r"testvalue");
+ hasher.finalize().to_vec()
+ };
+
+ let uri = format!(
+ "{}/authorize?client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid profile email",
+ self.provider,
+ utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC),
+ utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC),
+ utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC),
+ base64::engine::general_purpose::URL_SAFE.encode(chal),
+ );
+ info!("redirect {uri:?}");
+ let mut resp =
+ Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed());
+ *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;
+ resp.headers_mut().insert(
+ "Location",
+ HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?,
+ );
+ Ok(resp)
+ }
+ })
+ }
+}
+
+async fn token_request(
+ provider: &str,
+ client_id: &str,
+ redirect_uri: &str,
+ code: &str,
+) -> Result<(), ServiceError> {
+ let url = Uri::from_str(&format!("{provider}/token")).unwrap();
+ let body = format!(
+ "client_id={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code",
+ utf8_percent_encode(client_id, NON_ALPHANUMERIC),
+ utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC),
+ utf8_percent_encode(code, NON_ALPHANUMERIC),
+ "testvalue"
+ );
+ let authority = url.authority().unwrap().clone();
+
+ let stream = TcpStream::connect(authority.as_str()).await.unwrap();
+ let io = TokioIo::new(stream);
+
+ let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
+ tokio::task::spawn(async move {
+ if let Err(err) = conn.await {
+ println!("Connection failed: {:?}", err);
+ }
+ });
+
+ let req = Request::builder()
+ .method(Method::POST)
+ .uri(url)
+ .header(HOST, authority.as_str())
+ .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
+ .body(body)
+ .unwrap();
+
+ let res = sender.send_request(req).await.unwrap();
+ let body = res.collect().await.unwrap().aggregate();
+ let mut buf = String::new();
+ body.reader().read_to_string(&mut buf).unwrap();
+ eprintln!("{buf:?}");
+ Ok(())
+}
diff --git a/src/modules/mod.rs b/src/modules/mod.rs
index 9840935..051299b 100644
--- a/src/modules/mod.rs
+++ b/src/modules/mod.rs
@@ -27,6 +27,7 @@ pub type NodeResponse = Response<BoxBody<Bytes, ServiceError>>;
pub static MODULES: &[&dyn NodeKind] = &[
&auth::basic::HttpBasicAuthKind,
&auth::cookie::CookieAuthKind,
+ &auth::openid::OpenIDAuthKind,
&proxy::ProxyKind,
&hosts::HostsKind,
&paths::PathsKind,
diff --git a/src/modules/proxy.rs b/src/modules/proxy.rs
index ce72f65..925e456 100644
--- a/src/modules/proxy.rs
+++ b/src/modules/proxy.rs
@@ -1,8 +1,9 @@
use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
-use crate::{helper::TokioIo, ServiceError};
+use crate::ServiceError;
use futures::Future;
use http_body_util::BodyExt;
use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode};
+use hyper_util::rt::TokioIo;
use log::{debug, warn};
use serde::Deserialize;
use serde_yaml::Value;
@@ -42,7 +43,7 @@ impl Node for Proxy {
let _limit_guard = context.state.l_outgoing.try_acquire()?;
debug!("\tforwarding to {}", self.backend);
let mut resp = {
- let client_stream = TokioIo(
+ let client_stream = TokioIo::new(
TcpStream::connect(self.backend)
.await
.map_err(|_| ServiceError::CantConnect)?,
@@ -75,8 +76,8 @@ impl Node for Proxy {
(Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
debug!("upgrade successful");
match tokio::io::copy_bidirectional(
- &mut TokioIo(upgraded_downstream),
- &mut TokioIo(upgraded_upstream),
+ &mut TokioIo::new(upgraded_downstream),
+ &mut TokioIo::new(upgraded_upstream),
)
.await
{