use crate::{ config::DynNode, error::ServiceError, modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}, }; use aes_gcm_siv::{ aead::{Aead, Payload}, Nonce, }; use base64::Engine; use bytes::Buf; use futures::Future; use http::{ header::{CONTENT_TYPE, HOST}, uri::{Authority, Parts, PathAndQuery, Scheme}, HeaderValue, Method, Request, Uri, }; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{Response, StatusCode}; use hyper_util::rt::TokioIo; use log::info; use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC}; use rand::random; use serde::Deserialize; use serde_yaml::Value; use sha2::{Digest, Sha256}; use std::{io::Read, pin::Pin, str::FromStr, sync::Arc}; use tokio::net::TcpStream; pub struct OpenIDAuthKind; impl NodeKind for OpenIDAuthKind { fn name(&self) -> &'static str { "openid_auth" } fn instanciate(&self, config: Value) -> anyhow::Result> { Ok(Arc::new(serde_yaml::from_value::(config)?)) } } #[derive(Deserialize)] pub struct OpenIDAuth { client_id: String, provider: String, next: DynNode, } impl Node for OpenIDAuth { fn handle<'a>( &'a self, context: &'a mut NodeContext, request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" { let mut state = None; let mut code = None; for entry in request.uri().query().unwrap_or_default().split("&") { let (key, value) = entry.split_once("=").unwrap_or((entry, "")); match key { "state" => { state = Some( percent_decode(value.as_bytes()) .decode_utf8_lossy() .to_string(), ) } "code" => code = Some(value.to_owned()), _ => (), } } let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?; let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?; let (verif_cipher, return_path) = state .split_once("_") .ok_or(ServiceError::CustomStatic("state malformed"))?; let verif_cipher = hex::decode(verif_cipher) .map_err(|_| ServiceError::CustomStatic("invalid hex in code verifier"))?; if verif_cipher.len() < 12 { return Err(ServiceError::BadAuth); } let verif_plain = { let (msg, nonce) = verif_cipher.split_at(verif_cipher.len() - 12); let verif_plain = context .state .crypto_key .decrypt(Nonce::from_slice(nonce), Payload { msg, aad: &[] }) .map_err(|_| ServiceError::CustomStatic("authentication invalid"))?; String::from_utf8(verif_plain)? }; let redirect_uri = redirect_uri(&request)?.to_string(); token_request( &self.provider, &self.client_id, &redirect_uri, &code, &verif_plain, ) .await?; let mut r = Response::new(BoxBody::<_, ServiceError>::new( format!("state={state:?}\ncode={code:?}\nreturn_path={return_path:?}") .map_err(|_| unreachable!()), )); r.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); return Ok(r); } if request.method() == Method::GET && request.uri().path() == "/favicon.ico" { let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::NO_CONTENT; Ok(resp) } else { let (chal, verif_cipher): (Vec, Vec) = { let r = [(); 16].map(|()| random::()); let r = base64::engine::general_purpose::URL_SAFE.encode(r); let r = r.as_bytes(); let nonce = [(); 12].map(|_| random::()); // gcm and siv are overkill but its fine let mut v = context .state .crypto_key .encrypt(Nonce::from_slice(&nonce), Payload { msg: &r, aad: &[] }) .unwrap(); v.extend(nonce); let mut hasher = Sha256::new(); hasher.update(r); (hasher.finalize().to_vec(), v) }; let redirect_uri = redirect_uri(&request)?.to_string(); let uri = format!( "{}/authorize?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid profile email", self.provider, utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC), utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC), hex::encode(verif_cipher), utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), base64::engine::general_purpose::URL_SAFE.encode(chal).trim_end_matches("="), ); info!("redirect {uri:?}"); let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; resp.headers_mut().insert( "Location", HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?, ); Ok(resp) } }) } } fn redirect_uri(request: &NodeRequest) -> Result { let mut redirect_uri = Parts::default(); redirect_uri.scheme = Some(Scheme::HTTP); redirect_uri.path_and_query = Some(PathAndQuery::from_static("/_gnix_auth_callback")); redirect_uri.authority = Authority::from_str( request .headers() .get(HOST) .ok_or(ServiceError::InvalidHeader)? .to_str()?, ) .ok(); Ok(Uri::from_parts(redirect_uri).map_err(|_| ServiceError::InvalidUri)?) } async fn token_request( provider: &str, client_id: &str, redirect_uri: &str, code: &str, verifier: &str, ) -> Result<(), ServiceError> { let url = Uri::from_str(&format!("{provider}/token")).unwrap(); let body = format!( "client_id={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code", utf8_percent_encode(client_id, NON_ALPHANUMERIC), utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC), utf8_percent_encode(code, NON_ALPHANUMERIC), utf8_percent_encode(verifier, NON_ALPHANUMERIC), ); info!("validate {url} {body:?}"); let authority = url.authority().unwrap().clone(); let stream = TcpStream::connect(authority.as_str()).await.unwrap(); let io = TokioIo::new(stream); let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); } }); let req = Request::builder() .method(Method::POST) .uri(url) .header(HOST, authority.as_str()) .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body(body) .unwrap(); let res = sender.send_request(req).await.unwrap(); let body = res.collect().await.unwrap().aggregate(); let mut buf = String::new(); body.reader().read_to_string(&mut buf).unwrap(); eprintln!("{buf:?}"); Ok(()) }