diff options
Diffstat (limited to 'src/modules/auth/openid.rs')
-rw-r--r-- | src/modules/auth/openid.rs | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/src/modules/auth/openid.rs b/src/modules/auth/openid.rs new file mode 100644 index 0000000..979649e --- /dev/null +++ b/src/modules/auth/openid.rs @@ -0,0 +1,177 @@ +use crate::{ + config::DynNode, + error::ServiceError, + modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}, +}; +use base64::Engine; +use bytes::Buf; +use futures::Future; +use http::{ + header::{CONTENT_TYPE, HOST}, + uri::{Authority, Parts, PathAndQuery, Scheme}, + HeaderValue, Method, Request, Uri, +}; +use http_body_util::{combinators::BoxBody, BodyExt}; +use hyper::{Response, StatusCode}; +use hyper_util::rt::TokioIo; +use log::info; +use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC}; +use serde::Deserialize; +use serde_yaml::Value; +use sha2::{Digest, Sha256}; +use std::{io::Read, pin::Pin, str::FromStr, sync::Arc}; +use tokio::net::TcpStream; + +pub struct OpenIDAuthKind; +impl NodeKind for OpenIDAuthKind { + fn name(&self) -> &'static str { + "openid_auth" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<OpenIDAuth>(config)?)) + } +} + +#[derive(Deserialize)] +pub struct OpenIDAuth { + client_id: String, + provider: String, + next: DynNode, +} + +impl Node for OpenIDAuth { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" { + let mut state = None; + let mut code = None; + for entry in request.uri().query().unwrap_or_default().split("&") { + let (key, value) = entry.split_once("=").unwrap_or((entry, "")); + match key { + "state" => { + state = Some( + percent_decode(value.as_bytes()) + .decode_utf8_lossy() + .to_string(), + ) + } + "code" => code = Some(value.to_owned()), + _ => (), + } + } + let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?; + let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?; + + let mut redirect_uri = Parts::default(); + redirect_uri.scheme = Some(Scheme::HTTP); + redirect_uri.path_and_query = Some(PathAndQuery::from_str(&state).unwrap()); + redirect_uri.authority = Authority::from_str( + request + .headers() + .get(HOST) + .ok_or(ServiceError::InvalidHeader)? + .to_str()?, + ) + .ok(); + let redirect_uri = Uri::from_parts(redirect_uri) + .map_err(|_| ServiceError::InvalidUri)? + .to_string(); + + token_request(&self.provider, &self.client_id, &redirect_uri, &code).await?; + + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + format!("state={state:?}\ncode={code:?}").map_err(|_| unreachable!()), + )); + r.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + return Ok(r); + } else { + let mut redirect_uri = Parts::default(); + redirect_uri.scheme = Some(Scheme::HTTP); + redirect_uri.path_and_query = + Some(PathAndQuery::from_static("/_gnix_auth_callback")); + redirect_uri.authority = Authority::from_str( + request + .headers() + .get(HOST) + .ok_or(ServiceError::InvalidHeader)? + .to_str()?, + ) + .ok(); + let redirect_uri = Uri::from_parts(redirect_uri) + .map_err(|_| ServiceError::InvalidUri)? + .to_string(); + + let chal: Vec<u8> = { + let mut hasher = Sha256::new(); + hasher.update(r"testvalue"); + hasher.finalize().to_vec() + }; + + let uri = format!( + "{}/authorize?client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid profile email", + self.provider, + utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC), + utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC), + utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), + base64::engine::general_purpose::URL_SAFE.encode(chal), + ); + info!("redirect {uri:?}"); + let mut resp = + Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); + *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; + resp.headers_mut().insert( + "Location", + HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?, + ); + Ok(resp) + } + }) + } +} + +async fn token_request( + provider: &str, + client_id: &str, + redirect_uri: &str, + code: &str, +) -> Result<(), ServiceError> { + let url = Uri::from_str(&format!("{provider}/token")).unwrap(); + let body = format!( + "client_id={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code", + utf8_percent_encode(client_id, NON_ALPHANUMERIC), + utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC), + utf8_percent_encode(code, NON_ALPHANUMERIC), + "testvalue" + ); + let authority = url.authority().unwrap().clone(); + + let stream = TcpStream::connect(authority.as_str()).await.unwrap(); + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let req = Request::builder() + .method(Method::POST) + .uri(url) + .header(HOST, authority.as_str()) + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(body) + .unwrap(); + + let res = sender.send_request(req).await.unwrap(); + let body = res.collect().await.unwrap().aggregate(); + let mut buf = String::new(); + body.reader().read_to_string(&mut buf).unwrap(); + eprintln!("{buf:?}"); + Ok(()) +} |