use crate::{ config::DynNode, error::ServiceError, modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}, }; use base64::Engine; use bytes::Buf; use futures::Future; use http::{ header::{CONTENT_TYPE, HOST}, uri::{Authority, Parts, PathAndQuery, Scheme}, HeaderValue, Method, Request, Uri, }; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{Response, StatusCode}; use hyper_util::rt::TokioIo; use log::info; use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC}; use serde::Deserialize; use serde_yaml::Value; use sha2::{Digest, Sha256}; use std::{io::Read, pin::Pin, str::FromStr, sync::Arc}; use tokio::net::TcpStream; pub struct OpenIDAuthKind; impl NodeKind for OpenIDAuthKind { fn name(&self) -> &'static str { "openid_auth" } fn instanciate(&self, config: Value) -> anyhow::Result> { Ok(Arc::new(serde_yaml::from_value::(config)?)) } } #[derive(Deserialize)] pub struct OpenIDAuth { client_id: String, provider: String, next: DynNode, } impl Node for OpenIDAuth { fn handle<'a>( &'a self, context: &'a mut NodeContext, request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" { let mut state = None; let mut code = None; for entry in request.uri().query().unwrap_or_default().split("&") { let (key, value) = entry.split_once("=").unwrap_or((entry, "")); match key { "state" => { state = Some( percent_decode(value.as_bytes()) .decode_utf8_lossy() .to_string(), ) } "code" => code = Some(value.to_owned()), _ => (), } } let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?; let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?; let mut redirect_uri = Parts::default(); redirect_uri.scheme = Some(Scheme::HTTP); redirect_uri.path_and_query = Some(PathAndQuery::from_str(&state).unwrap()); redirect_uri.authority = Authority::from_str( request .headers() .get(HOST) .ok_or(ServiceError::InvalidHeader)? .to_str()?, ) .ok(); let redirect_uri = Uri::from_parts(redirect_uri) .map_err(|_| ServiceError::InvalidUri)? .to_string(); token_request(&self.provider, &self.client_id, &redirect_uri, &code).await?; let mut r = Response::new(BoxBody::<_, ServiceError>::new( format!("state={state:?}\ncode={code:?}").map_err(|_| unreachable!()), )); r.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); return Ok(r); } else { let mut redirect_uri = Parts::default(); redirect_uri.scheme = Some(Scheme::HTTP); redirect_uri.path_and_query = Some(PathAndQuery::from_static("/_gnix_auth_callback")); redirect_uri.authority = Authority::from_str( request .headers() .get(HOST) .ok_or(ServiceError::InvalidHeader)? .to_str()?, ) .ok(); let redirect_uri = Uri::from_parts(redirect_uri) .map_err(|_| ServiceError::InvalidUri)? .to_string(); let chal: Vec = { let mut hasher = Sha256::new(); hasher.update(r"testvalue"); hasher.finalize().to_vec() }; let uri = format!( "{}/authorize?client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid profile email", self.provider, utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC), utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC), utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), base64::engine::general_purpose::URL_SAFE.encode(chal), ); info!("redirect {uri:?}"); let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; resp.headers_mut().insert( "Location", HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?, ); Ok(resp) } }) } } async fn token_request( provider: &str, client_id: &str, redirect_uri: &str, code: &str, ) -> Result<(), ServiceError> { let url = Uri::from_str(&format!("{provider}/token")).unwrap(); let body = format!( "client_id={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code", utf8_percent_encode(client_id, NON_ALPHANUMERIC), utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC), utf8_percent_encode(code, NON_ALPHANUMERIC), "testvalue" ); let authority = url.authority().unwrap().clone(); let stream = TcpStream::connect(authority.as_str()).await.unwrap(); let io = TokioIo::new(stream); let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); } }); let req = Request::builder() .method(Method::POST) .uri(url) .header(HOST, authority.as_str()) .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body(body) .unwrap(); let res = sender.send_request(req).await.unwrap(); let body = res.collect().await.unwrap().aggregate(); let mut buf = String::new(); body.reader().read_to_string(&mut buf).unwrap(); eprintln!("{buf:?}"); Ok(()) }