diff options
-rw-r--r-- | Cargo.lock | 21 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | readme.md | 9 | ||||
-rw-r--r-- | src/error.rs | 4 | ||||
-rw-r--r-- | src/helper.rs | 138 | ||||
-rw-r--r-- | src/main.rs | 8 | ||||
-rw-r--r-- | src/modules/auth/mod.rs | 1 | ||||
-rw-r--r-- | src/modules/auth/openid.rs | 177 | ||||
-rw-r--r-- | src/modules/mod.rs | 1 | ||||
-rw-r--r-- | src/modules/proxy.rs | 9 |
10 files changed, 222 insertions, 147 deletions
@@ -332,6 +332,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] +name = "const-oid" +version = "0.10.0-rc.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9adcf94f05e094fca3005698822ec791cb4433ced416afda1c5ca3b8dfc05a2f" + +[[package]] name = "cpufeatures" version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -357,7 +363,9 @@ version = "0.2.0-rc.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0b8ce8218c97789f16356e7896b3714f26c2ee1079b79c0b7ae7064bb9089fa" dependencies = [ + "getrandom", "hybrid-array", + "rand_core 0.6.4", ] [[package]] @@ -386,6 +394,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf2e3d6615d99707295a9673e889bf363a04b2a466bd320c65a72536f7577379" dependencies = [ "block-buffer 0.11.0-rc.1", + "const-oid", "crypto-common 0.2.0-rc.1", "subtle", ] @@ -609,6 +618,7 @@ dependencies = [ "rustls-webpki", "serde", "serde_yaml", + "sha2", "thiserror", "tokio", "tokio-rustls", @@ -1371,6 +1381,17 @@ dependencies = [ ] [[package]] +name = "sha2" +version = "0.11.0-pre.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "540c0893cce56cdbcfebcec191ec8e0f470dd1889b6e7a0b503e310a94a168f5" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.11.0-pre.9", +] + +[[package]] name = "shlex" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -52,6 +52,7 @@ mime_guess = "2.0.5" # Crypto for authentificating clients aes-gcm-siv = "0.11.1" argon2 = "0.6.0-pre.1" +sha2 = "0.11.0-pre.4" rand = "0.9.0-alpha.2" # Other helpers and stuff @@ -212,6 +212,15 @@ Login credentials for `cookie_auth` and `http_basic_auth` are supplied as either an object mapping usernames to PHC strings or a file path pointing to a file that contains that map in YALM format. Currently only `argon2id` is supported. +### Additional Notes + +Internally gnix processes requests as they would be sent in HTTP/1.1. HTTP/2 is +translated on arrival. + +Paths matching `/_gnix*` might be used internally in gnix for purposes like +OpenID callback or login action endpoints. I hope your application doesn't rely +on using them for itself. + ## License AGPL-3.0-only; see [COPYING](./COPYING) diff --git a/src/error.rs b/src/error.rs index 636f226..c7290af 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,6 +26,8 @@ pub enum ServiceError { BadUtf8(#[from] std::str::Utf8Error), #[error("bad utf8")] BadUtf82(#[from] std::string::FromUtf8Error), + #[error("bad utf8")] + BadUtf83(#[from] http::header::ToStrError), #[error("bad path")] BadPath, #[error("bad auth")] @@ -36,6 +38,8 @@ pub enum ServiceError { UpgradeFailed, #[error("{0}")] Custom(String), + #[error("{0}")] + CustomStatic(&'static str), #[error("parse int error: {0}")] ParseIntError(#[from] std::num::ParseIntError), #[error("invalid header")] diff --git a/src/helper.rs b/src/helper.rs deleted file mode 100644 index 0daa3d5..0000000 --- a/src/helper.rs +++ /dev/null @@ -1,138 +0,0 @@ -// From https://github.com/hyperium/hyper/blob/master/benches/support/tokiort.rs -use pin_project::pin_project; -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -#[pin_project] -#[derive(Debug)] -pub struct TokioIo<T>(#[pin] pub T); - -impl<T> hyper::rt::Read for TokioIo<T> -where - T: tokio::io::AsyncRead, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll<Result<(), std::io::Error>> { - let n = unsafe { - let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); - match tokio::io::AsyncRead::poll_read(self.project().0, cx, &mut tbuf) { - Poll::Ready(Ok(())) => tbuf.filled().len(), - other => return other, - } - }; - - unsafe { - buf.advance(n); - } - Poll::Ready(Ok(())) - } -} - -impl<T> hyper::rt::Write for TokioIo<T> -where - T: tokio::io::AsyncWrite, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll<Result<usize, std::io::Error>> { - tokio::io::AsyncWrite::poll_write(self.project().0, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { - tokio::io::AsyncWrite::poll_flush(self.project().0, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Result<(), std::io::Error>> { - tokio::io::AsyncWrite::poll_shutdown(self.project().0, cx) - } - - fn is_write_vectored(&self) -> bool { - tokio::io::AsyncWrite::is_write_vectored(&self.0) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll<Result<usize, std::io::Error>> { - tokio::io::AsyncWrite::poll_write_vectored(self.project().0, cx, bufs) - } -} - -impl<T> tokio::io::AsyncRead for TokioIo<T> -where - T: hyper::rt::Read, -{ - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - tbuf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll<Result<(), std::io::Error>> { - //let init = tbuf.initialized().len(); - let filled = tbuf.filled().len(); - let sub_filled = unsafe { - let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); - - match hyper::rt::Read::poll_read(self.project().0, cx, buf.unfilled()) { - Poll::Ready(Ok(())) => buf.filled().len(), - other => return other, - } - }; - - let n_filled = filled + sub_filled; - // At least sub_filled bytes had to have been initialized. - let n_init = sub_filled; - unsafe { - tbuf.assume_init(n_init); - tbuf.set_filled(n_filled); - } - - Poll::Ready(Ok(())) - } -} - -impl<T> tokio::io::AsyncWrite for TokioIo<T> -where - T: hyper::rt::Write, -{ - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll<Result<usize, std::io::Error>> { - hyper::rt::Write::poll_write(self.project().0, cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { - hyper::rt::Write::poll_flush(self.project().0, cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Result<(), std::io::Error>> { - hyper::rt::Write::poll_shutdown(self.project().0, cx) - } - - fn is_write_vectored(&self) -> bool { - hyper::rt::Write::is_write_vectored(&self.0) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll<Result<usize, std::io::Error>> { - hyper::rt::Write::poll_write_vectored(self.project().0, cx, bufs) - } -} diff --git a/src/main.rs b/src/main.rs index a13d171..0dcf01b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,6 @@ pub mod certs; pub mod config; pub mod error; -pub mod helper; pub mod modules; use aes_gcm_siv::{aead::generic_array::GenericArray, Aes256GcmSiv, KeyInit}; @@ -14,7 +13,6 @@ use certs::CertPool; use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; use futures::future::try_join_all; -use helper::TokioIo; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, @@ -23,7 +21,7 @@ use hyper::{ service::service_fn, Request, Response, StatusCode, Uri, }; -use hyper_util::rt::TokioExecutor; +use hyper_util::rt::{TokioExecutor, TokioIo}; use log::{debug, error, info, warn, LevelFilter}; use modules::{NodeContext, MODULES}; use std::{ @@ -126,7 +124,7 @@ async fn serve_http(state: Arc<State>) -> Result<()> { loop { let (stream, addr) = l.accept().await.context("accepting connection")?; debug!("connection from {addr}"); - let stream = TokioIo(stream); + let stream = TokioIo::new(stream); let state = state.clone(); tokio::spawn(async move { serve_stream(state, stream, addr).await }); } @@ -164,7 +162,7 @@ async fn serve_https(state: Arc<State>) -> Result<()> { tokio::task::spawn(async move { debug!("connection from {addr}"); match tls_acceptor.accept(stream).await { - Ok(stream) => serve_stream(state, TokioIo(stream), addr).await, + Ok(stream) => serve_stream(state, TokioIo::new(stream), addr).await, Err(e) => warn!("error accepting tls: {e}"), }; }); diff --git a/src/modules/auth/mod.rs b/src/modules/auth/mod.rs index 729ea42..80d02ad 100644 --- a/src/modules/auth/mod.rs +++ b/src/modules/auth/mod.rs @@ -12,6 +12,7 @@ use std::{collections::HashMap, fmt, fs::read_to_string}; pub mod basic; pub mod cookie; +pub mod openid; struct Credentials { wrong_user: PasswordHashString, diff --git a/src/modules/auth/openid.rs b/src/modules/auth/openid.rs new file mode 100644 index 0000000..979649e --- /dev/null +++ b/src/modules/auth/openid.rs @@ -0,0 +1,177 @@ +use crate::{ + config::DynNode, + error::ServiceError, + modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}, +}; +use base64::Engine; +use bytes::Buf; +use futures::Future; +use http::{ + header::{CONTENT_TYPE, HOST}, + uri::{Authority, Parts, PathAndQuery, Scheme}, + HeaderValue, Method, Request, Uri, +}; +use http_body_util::{combinators::BoxBody, BodyExt}; +use hyper::{Response, StatusCode}; +use hyper_util::rt::TokioIo; +use log::info; +use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC}; +use serde::Deserialize; +use serde_yaml::Value; +use sha2::{Digest, Sha256}; +use std::{io::Read, pin::Pin, str::FromStr, sync::Arc}; +use tokio::net::TcpStream; + +pub struct OpenIDAuthKind; +impl NodeKind for OpenIDAuthKind { + fn name(&self) -> &'static str { + "openid_auth" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<OpenIDAuth>(config)?)) + } +} + +#[derive(Deserialize)] +pub struct OpenIDAuth { + client_id: String, + provider: String, + next: DynNode, +} + +impl Node for OpenIDAuth { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" { + let mut state = None; + let mut code = None; + for entry in request.uri().query().unwrap_or_default().split("&") { + let (key, value) = entry.split_once("=").unwrap_or((entry, "")); + match key { + "state" => { + state = Some( + percent_decode(value.as_bytes()) + .decode_utf8_lossy() + .to_string(), + ) + } + "code" => code = Some(value.to_owned()), + _ => (), + } + } + let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?; + let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?; + + let mut redirect_uri = Parts::default(); + redirect_uri.scheme = Some(Scheme::HTTP); + redirect_uri.path_and_query = Some(PathAndQuery::from_str(&state).unwrap()); + redirect_uri.authority = Authority::from_str( + request + .headers() + .get(HOST) + .ok_or(ServiceError::InvalidHeader)? + .to_str()?, + ) + .ok(); + let redirect_uri = Uri::from_parts(redirect_uri) + .map_err(|_| ServiceError::InvalidUri)? + .to_string(); + + token_request(&self.provider, &self.client_id, &redirect_uri, &code).await?; + + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + format!("state={state:?}\ncode={code:?}").map_err(|_| unreachable!()), + )); + r.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + return Ok(r); + } else { + let mut redirect_uri = Parts::default(); + redirect_uri.scheme = Some(Scheme::HTTP); + redirect_uri.path_and_query = + Some(PathAndQuery::from_static("/_gnix_auth_callback")); + redirect_uri.authority = Authority::from_str( + request + .headers() + .get(HOST) + .ok_or(ServiceError::InvalidHeader)? + .to_str()?, + ) + .ok(); + let redirect_uri = Uri::from_parts(redirect_uri) + .map_err(|_| ServiceError::InvalidUri)? + .to_string(); + + let chal: Vec<u8> = { + let mut hasher = Sha256::new(); + hasher.update(r"testvalue"); + hasher.finalize().to_vec() + }; + + let uri = format!( + "{}/authorize?client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid profile email", + self.provider, + utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC), + utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC), + utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), + base64::engine::general_purpose::URL_SAFE.encode(chal), + ); + info!("redirect {uri:?}"); + let mut resp = + Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); + *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; + resp.headers_mut().insert( + "Location", + HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?, + ); + Ok(resp) + } + }) + } +} + +async fn token_request( + provider: &str, + client_id: &str, + redirect_uri: &str, + code: &str, +) -> Result<(), ServiceError> { + let url = Uri::from_str(&format!("{provider}/token")).unwrap(); + let body = format!( + "client_id={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code", + utf8_percent_encode(client_id, NON_ALPHANUMERIC), + utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC), + utf8_percent_encode(code, NON_ALPHANUMERIC), + "testvalue" + ); + let authority = url.authority().unwrap().clone(); + + let stream = TcpStream::connect(authority.as_str()).await.unwrap(); + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let req = Request::builder() + .method(Method::POST) + .uri(url) + .header(HOST, authority.as_str()) + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(body) + .unwrap(); + + let res = sender.send_request(req).await.unwrap(); + let body = res.collect().await.unwrap().aggregate(); + let mut buf = String::new(); + body.reader().read_to_string(&mut buf).unwrap(); + eprintln!("{buf:?}"); + Ok(()) +} diff --git a/src/modules/mod.rs b/src/modules/mod.rs index 9840935..051299b 100644 --- a/src/modules/mod.rs +++ b/src/modules/mod.rs @@ -27,6 +27,7 @@ pub type NodeResponse = Response<BoxBody<Bytes, ServiceError>>; pub static MODULES: &[&dyn NodeKind] = &[ &auth::basic::HttpBasicAuthKind, &auth::cookie::CookieAuthKind, + &auth::openid::OpenIDAuthKind, &proxy::ProxyKind, &hosts::HostsKind, &paths::PathsKind, diff --git a/src/modules/proxy.rs b/src/modules/proxy.rs index ce72f65..925e456 100644 --- a/src/modules/proxy.rs +++ b/src/modules/proxy.rs @@ -1,8 +1,9 @@ use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; -use crate::{helper::TokioIo, ServiceError}; +use crate::ServiceError; use futures::Future; use http_body_util::BodyExt; use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode}; +use hyper_util::rt::TokioIo; use log::{debug, warn}; use serde::Deserialize; use serde_yaml::Value; @@ -42,7 +43,7 @@ impl Node for Proxy { let _limit_guard = context.state.l_outgoing.try_acquire()?; debug!("\tforwarding to {}", self.backend); let mut resp = { - let client_stream = TokioIo( + let client_stream = TokioIo::new( TcpStream::connect(self.backend) .await .map_err(|_| ServiceError::CantConnect)?, @@ -75,8 +76,8 @@ impl Node for Proxy { (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { debug!("upgrade successful"); match tokio::io::copy_bidirectional( - &mut TokioIo(upgraded_downstream), - &mut TokioIo(upgraded_upstream), + &mut TokioIo::new(upgraded_downstream), + &mut TokioIo::new(upgraded_upstream), ) .await { |