use crate::{ config::DynNode, error::ServiceError, modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}, }; use aes_gcm_siv::{ aead::{Aead, Payload}, Nonce, }; use base64::Engine; use bytes::Buf; use futures::Future; use http::{ header::{CONTENT_TYPE, HOST}, uri::{Authority, Parts, PathAndQuery, Scheme}, HeaderValue, Method, Request, Uri, }; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{Response, StatusCode}; use hyper_util::rt::TokioIo; use log::info; use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC}; use rand::random; use serde::Deserialize; use serde_yml::Value; use sha2::{Digest, Sha256}; use std::{io::Read, pin::Pin, str::FromStr, sync::Arc}; use tokio::net::TcpStream; pub struct OpenIDAuthKind; impl NodeKind for OpenIDAuthKind { fn name(&self) -> &'static str { "openid_auth" } fn instanciate(&self, config: Value) -> anyhow::Result> { Ok(Arc::new(serde_yml::from_value::(config)?)) } } #[derive(Deserialize)] pub struct OpenIDAuth { client_id: String, authorize_endpoint: String, token_endpoint: String, scope: String, #[allow(unused)] next: DynNode, } impl Node for OpenIDAuth { fn handle<'a>( &'a self, context: &'a mut NodeContext, request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" { let mut state = None; let mut code = None; for entry in request.uri().query().unwrap_or_default().split("&") { let (key, value) = entry.split_once("=").unwrap_or((entry, "")); match key { "state" => { state = Some( percent_decode(value.as_bytes()) .decode_utf8_lossy() .to_string(), ) } "code" => code = Some(value.to_owned()), _ => (), } } let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?; let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?; let (verif_cipher, return_path) = state .split_once("_") .ok_or(ServiceError::CustomStatic("state malformed"))?; let verif_cipher = hex::decode(verif_cipher) .map_err(|_| ServiceError::CustomStatic("invalid hex in code verifier"))?; if verif_cipher.len() < 12 { return Err(ServiceError::BadAuth); } let verif_plain = { let (msg, nonce) = verif_cipher.split_at(verif_cipher.len() - 12); let verif_plain = context .state .crypto_key .decrypt(Nonce::from_slice(nonce), Payload { msg, aad: &[] }) .map_err(|_| ServiceError::CustomStatic("authentication invalid"))?; String::from_utf8(verif_plain)? }; let redirect_uri = redirect_uri(&request)?.to_string(); let OAuthTokenResponse { access_token, expires_in, token_type, id_token, } = token_request( &self.token_endpoint, &self.client_id, &redirect_uri, &code, &verif_plain, ) .await?; let mut r = Response::new(BoxBody::<_, ServiceError>::new( format!( r#"Response: state={state:?} code={code:?} return_path={return_path:?} access_token={access_token:?} token_type={token_type:?} expires_in={expires_in:?} id_token={id_token:?}"# ) .map_err(|_| unreachable!()), )); r.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); return Ok(r); } if request.method() == Method::GET && request.uri().path() == "/favicon.ico" { let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::NO_CONTENT; Ok(resp) } else { let (chal, verif_cipher): (Vec, Vec) = { let r = [(); 16].map(|()| random::()); let r = base64::engine::general_purpose::URL_SAFE.encode(r); let r = r.as_bytes(); let nonce = [(); 12].map(|_| random::()); // gcm and siv are overkill but its fine let mut v = context .state .crypto_key .encrypt(Nonce::from_slice(&nonce), Payload { msg: r, aad: &[] }) .unwrap(); v.extend(nonce); let mut hasher = Sha256::new(); hasher.update(r); (hasher.finalize().to_vec(), v) }; let redirect_uri = redirect_uri(&request)?.to_string(); let uri = format!( "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid {}", self.authorize_endpoint, utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC), utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC), hex::encode(verif_cipher), utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), base64::engine::general_purpose::URL_SAFE.encode(chal).trim_end_matches("="), utf8_percent_encode(&self.scope, NON_ALPHANUMERIC), ); info!("redirect {uri:?}"); let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; resp.headers_mut().insert( "Location", HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?, ); Ok(resp) } }) } } fn redirect_uri(request: &NodeRequest) -> Result { let mut redirect_uri = Parts::default(); redirect_uri.scheme = Some(Scheme::HTTP); redirect_uri.path_and_query = Some(PathAndQuery::from_static("/_gnix_auth_callback")); redirect_uri.authority = Authority::from_str( request .headers() .get(HOST) .ok_or(ServiceError::InvalidHeader)? .to_str()?, ) .ok(); Uri::from_parts(redirect_uri).map_err(|_| ServiceError::InvalidUri) } async fn token_request( provider: &str, client_id: &str, redirect_uri: &str, code: &str, verifier: &str, ) -> Result { let url = Uri::from_str(provider).unwrap(); let body = format!( "client_id={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code", utf8_percent_encode(client_id, NON_ALPHANUMERIC), utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC), utf8_percent_encode(code, NON_ALPHANUMERIC), utf8_percent_encode(verifier, NON_ALPHANUMERIC), ); info!("validate {url} {body:?}"); let authority = url.authority().unwrap().clone(); let stream = TcpStream::connect(authority.as_str()).await.unwrap(); let io = TokioIo::new(stream); let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); } }); let req = Request::builder() .method(Method::POST) .uri(url) .header(HOST, authority.as_str()) .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body(body) .unwrap(); let res = sender.send_request(req).await.unwrap(); let body = res.collect().await.unwrap().aggregate(); let mut buf = String::new(); body.reader().read_to_string(&mut buf).unwrap(); eprintln!("{buf}"); serde_json::from_str(&buf) .map_err(|_| ServiceError::CustomStatic("invalid token response")) } #[derive(Debug, Deserialize)] struct OAuthTokenResponse { access_token: String, expires_in: i64, token_type: String, id_token: String, }