/* This file is part of gnix (https://codeberg.org/metamuffin/gnix) which is licensed under the GNU Affero General Public License (version 3); see /COPYING. Copyright (C) 2025 metamuffin */ use crate::{ config::DynNode, error::ServiceError, modules::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}, }; use aes_gcm_siv::{ aead::{Aead, Payload}, Nonce, }; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; use bytes::Buf; use futures::Future; use headers::{Cookie, HeaderMapExt}; use http::{ header::{CONTENT_TYPE, HOST, LOCATION, SET_COOKIE}, uri::{Authority, Parts, PathAndQuery, Scheme}, HeaderValue, Method, Request, Uri, }; use http_body_util::BodyExt; use hyper::{Response, StatusCode}; use hyper_util::rt::TokioIo; use log::{debug, info, warn}; use percent_encoding::{ percent_decode, percent_decode_str, percent_encode, utf8_percent_encode, NON_ALPHANUMERIC, }; use rand::random; use rustls::RootCertStore; use serde::Deserialize; use serde_yml::Value; use sha2::{Digest, Sha256}; use std::{collections::HashSet, io::Read, pin::Pin, str::FromStr, sync::Arc, time::SystemTime}; use tokio::net::TcpStream; use webpki::types::ServerName; pub struct OpenIDAuthKind; impl NodeKind for OpenIDAuthKind { fn name(&self) -> &'static str { "openid_auth" } fn instanciate(&self, config: Value) -> anyhow::Result> { Ok(Arc::new(serde_yml::from_value::(config)?)) } } #[derive(Deserialize)] pub struct OpenIDAuth { salt: String, client_id: String, client_secret: String, authorize_endpoint: String, token_endpoint: String, scope: String, authorized_emails: HashSet, next: DynNode, } impl Node for OpenIDAuth { fn handle<'a>( &'a self, context: &'a mut NodeContext, request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { if let Some(cookie) = request.headers().typed_get::() { if let Some(auth) = cookie.get("gnix_oauth") { let username = percent_decode_str(cookie.get("gnix_oauth_email").unwrap_or("default")) .decode_utf8()?; let auth = BASE64_URL_SAFE_NO_PAD.decode(auth)?; if auth.len() < 12 { return Err(ServiceError::BadAuth); } let (msg, nonce) = auth.split_at(auth.len() - 12); let plaintext = context.state.crypto_key.decrypt( Nonce::from_slice(nonce), Payload { msg, aad: username.as_bytes(), }, ); if let Ok(plaintext) = plaintext { if let Some(expire) = plaintext.strip_prefix(self.salt.as_bytes()) { if let Some(expire) = expire.strip_prefix(&[0]) { let expire = u64::from_be_bytes(expire[0..8].try_into().unwrap()); if expire >= unix_seconds() { return self.next.handle(context, request).await; } else { debug!("auth expired"); } } else { return Err(ServiceError::CustomStatic("salt sep invalid")); } } else { warn!("salt invalid"); } } else { debug!("aead invalid"); } } else { debug!("no auth cookie"); } } if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" { let mut state = None; let mut code = None; for entry in request.uri().query().unwrap_or_default().split("&") { let (key, value) = entry.split_once("=").unwrap_or((entry, "")); match key { "state" => { state = Some( percent_decode(value.as_bytes()) .decode_utf8_lossy() .to_string(), ) } "code" => code = Some(value.to_owned()), _ => (), } } let state = state.ok_or(ServiceError::CustomStatic("state parameter missing"))?; let code = code.ok_or(ServiceError::CustomStatic("code parameter missing"))?; let (verif_cipher, return_path) = state .split_once("_") .ok_or(ServiceError::CustomStatic("state malformed"))?; let verif_cipher = hex::decode(verif_cipher) .map_err(|_| ServiceError::CustomStatic("invalid hex in code verifier"))?; if verif_cipher.len() < 12 { return Err(ServiceError::BadAuth); } let verif_plain = { let (msg, nonce) = verif_cipher.split_at(verif_cipher.len() - 12); let verif_plain = context .state .crypto_key .decrypt(Nonce::from_slice(nonce), Payload { msg, aad: &[] }) .map_err(|_| ServiceError::CustomStatic("authentication invalid"))?; String::from_utf8(verif_plain)? }; let redirect_uri = redirect_uri(&request)?.to_string(); let resp = token_request( &self.token_endpoint, &self.client_id, &self.client_secret, &redirect_uri, &code, &verif_plain, ) .await?; let jwt_pay = parse_jwt(&resp.id_token)?; if !self.authorized_emails.contains(&jwt_pay.email) { return Err(ServiceError::Unauthorized); } let nonce = [(); 12].map(|_| random::()); let mut plaintext = Vec::new(); plaintext.extend(self.salt.as_bytes()); plaintext.push(0); plaintext.extend(jwt_pay.exp.to_be_bytes()); let mut ciphertext = context .state .crypto_key .encrypt( Nonce::from_slice(&nonce), Payload { msg: &plaintext, aad: jwt_pay.email.as_bytes(), }, ) .unwrap(); ciphertext.extend(nonce); let auth = BASE64_URL_SAFE_NO_PAD.encode(ciphertext); let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; resp.headers_mut().append( SET_COOKIE, HeaderValue::from_str(&format!( "gnix_oauth_email={}; Secure", percent_encode(jwt_pay.email.as_bytes(), NON_ALPHANUMERIC) )) .map_err(|_| ServiceError::InvalidHeader)?, ); resp.headers_mut().append( SET_COOKIE, HeaderValue::from_str(&format!("gnix_oauth={auth}; Secure")) .map_err(|_| ServiceError::InvalidHeader)?, ); resp.headers_mut().insert( LOCATION, HeaderValue::from_str(&return_path).map_err(|_| ServiceError::InvalidHeader)?, ); Ok(resp) } else if request.method() == Method::GET && request.uri().path() == "/favicon.ico" { let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::NO_CONTENT; Ok(resp) } else { let (chal, verif_cipher): (Vec, Vec) = { let r = [(); 32].map(|()| random::()); let r = BASE64_URL_SAFE_NO_PAD.encode(r); let r = r.as_bytes(); let nonce = [(); 12].map(|_| random::()); let mut v = context .state .crypto_key .encrypt(Nonce::from_slice(&nonce), Payload { msg: r, aad: &[] }) .unwrap(); v.extend(nonce); let mut hasher = Sha256::new(); hasher.update(r); (hasher.finalize().to_vec(), v) }; let redirect_uri = redirect_uri(&request)?.to_string(); let uri = format!( "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope={}", self.authorize_endpoint, utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC), utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC), hex::encode(verif_cipher), utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), BASE64_URL_SAFE_NO_PAD.encode(chal), utf8_percent_encode(&self.scope, NON_ALPHANUMERIC), ); info!("redirect {uri:?}"); let mut resp = Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; resp.headers_mut().insert( LOCATION, HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?, ); Ok(resp) } }) } } fn redirect_uri(request: &NodeRequest) -> Result { let mut redirect_uri = Parts::default(); redirect_uri.scheme = Some(Scheme::HTTPS); redirect_uri.path_and_query = Some(PathAndQuery::from_static("/_gnix_auth_callback")); redirect_uri.authority = Authority::from_str( request .headers() .get(HOST) .ok_or(ServiceError::InvalidHeader)? .to_str()?, ) .ok(); Uri::from_parts(redirect_uri).map_err(|_| ServiceError::InvalidUri) } async fn token_request( endpoint: &str, client_id: &str, client_secret: &str, redirect_uri: &str, code: &str, verifier: &str, ) -> Result { let url = Uri::from_str(endpoint).unwrap(); let body = format!( "client_id={}&client_secret={}&redirect_uri={}&code={}&code_verifier={}&grant_type=authorization_code", utf8_percent_encode(client_id, NON_ALPHANUMERIC), utf8_percent_encode(client_secret, NON_ALPHANUMERIC), utf8_percent_encode(redirect_uri, NON_ALPHANUMERIC), utf8_percent_encode(code, NON_ALPHANUMERIC), utf8_percent_encode(verifier, NON_ALPHANUMERIC), ); info!("token {url} {body:?}"); let authority = url.authority().unwrap().clone(); eprintln!("connect {}", authority.as_str()); let use_tls = url.scheme() == Some(&Scheme::HTTPS); let stream = TcpStream::connect(format!( "{}:{}", authority.host(), authority .port_u16() .unwrap_or(if use_tls { 443 } else { 80 }) )) .await .map_err(|_| ServiceError::CustomStatic("token request connect failed"))?; let config = rustls::ClientConfig::builder() .with_root_certificates(RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), }) .with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); let name = ServerName::try_from(authority.host().to_owned()).unwrap(); let stream = connector.connect(name, stream).await.unwrap(); let io = TokioIo::new(stream); let (mut sender, conn) = hyper::client::conn::http1::handshake(io) .await .map_err(|_| ServiceError::CustomStatic("token request handshake failed"))?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); } }); let req = Request::builder() .method(Method::POST) .uri(url) .header(HOST, authority.as_str()) .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body(body) .unwrap(); let res = sender.send_request(req).await.unwrap(); let body = res.collect().await.unwrap().aggregate(); let mut buf = String::new(); body.reader().read_to_string(&mut buf).unwrap(); serde_json::from_str(&buf).map_err(|_| ServiceError::CustomStatic("invalid token response")) } fn parse_jwt(s: &str) -> Result { let (header, rest) = s .split_once(".") .ok_or(ServiceError::CustomStatic("jwt invalid format"))?; let (payload, signature) = rest .split_once(".") .ok_or(ServiceError::CustomStatic("jwt invalid format"))?; let header: JwtHeader = serde_json::from_slice(&BASE64_URL_SAFE_NO_PAD.decode(header)?) .map_err(|_| ServiceError::CustomStatic("jwt invalid header"))?; let payload: JwtPayload = serde_json::from_slice(&BASE64_URL_SAFE_NO_PAD.decode(payload)?) .map_err(|_| ServiceError::CustomStatic("jwt invalid payload"))?; if header.typ != "JWT" { return Err(ServiceError::CustomStatic("jwt type is not jwt (duh)")); } let _ = signature; Ok(payload) } #[derive(Debug, Deserialize)] struct JwtHeader { #[allow(unused)] alg: String, typ: String, } #[derive(Debug, Deserialize)] struct JwtPayload { email: String, exp: u64, } #[derive(Debug, Deserialize)] struct OAuthTokenResponse { #[allow(unused)] access_token: String, #[allow(unused)] expires_in: i64, #[allow(unused)] token_type: String, id_token: String, } fn unix_seconds() -> u64 { SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs() }