diff options
Diffstat (limited to 'src/modules')
| -rw-r--r-- | src/modules/auth/openid.rs | 217 | 
1 files changed, 175 insertions, 42 deletions
diff --git a/src/modules/auth/openid.rs b/src/modules/auth/openid.rs index a6a6288..4bf0de7 100644 --- a/src/modules/auth/openid.rs +++ b/src/modules/auth/openid.rs @@ -12,25 +12,30 @@ use aes_gcm_siv::{      aead::{Aead, Payload},      Nonce,  }; -use base64::Engine; +use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};  use bytes::Buf;  use futures::Future; +use headers::{Cookie, HeaderMapExt};  use http::{ -    header::{CONTENT_TYPE, HOST}, +    header::{CONTENT_TYPE, HOST, LOCATION, SET_COOKIE},      uri::{Authority, Parts, PathAndQuery, Scheme},      HeaderValue, Method, Request, Uri,  }; -use http_body_util::{combinators::BoxBody, BodyExt}; +use http_body_util::BodyExt;  use hyper::{Response, StatusCode};  use hyper_util::rt::TokioIo; -use log::info; -use percent_encoding::{percent_decode, utf8_percent_encode, NON_ALPHANUMERIC}; +use log::{debug, info}; +use percent_encoding::{ +    percent_decode, percent_decode_str, percent_encode, utf8_percent_encode, NON_ALPHANUMERIC, +};  use rand::random; +use rustls::RootCertStore;  use serde::Deserialize;  use serde_yml::Value;  use sha2::{Digest, Sha256}; -use std::{io::Read, pin::Pin, str::FromStr, sync::Arc}; +use std::{io::Read, pin::Pin, str::FromStr, sync::Arc, time::SystemTime};  use tokio::net::TcpStream; +use webpki::types::ServerName;  pub struct OpenIDAuthKind;  impl NodeKind for OpenIDAuthKind { @@ -44,11 +49,11 @@ impl NodeKind for OpenIDAuthKind {  #[derive(Deserialize)]  pub struct OpenIDAuth { +    salt: String,      client_id: String,      authorize_endpoint: String,      token_endpoint: String,      scope: String, -    #[allow(unused)]      next: DynNode,  } @@ -59,6 +64,46 @@ impl Node for OpenIDAuth {          request: NodeRequest,      ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {          Box::pin(async move { +            if let Some(cookie) = request.headers().typed_get::<Cookie>() { +                if let Some(auth) = cookie.get("gnix_oauth") { +                    let username = +                        percent_decode_str(cookie.get("gnix_oauth_email").unwrap_or("default")) +                            .decode_utf8()?; + +                    let auth = BASE64_URL_SAFE_NO_PAD.decode(auth)?; +                    if auth.len() < 12 { +                        return Err(ServiceError::BadAuth); +                    } +                    let (msg, nonce) = auth.split_at(auth.len() - 12); +                    let plaintext = context.state.crypto_key.decrypt( +                        Nonce::from_slice(nonce), +                        Payload { +                            msg, +                            aad: username.as_bytes(), +                        }, +                    ); +                    if let Ok(plaintext) = plaintext { +                        if let Some(expire) = plaintext.strip_prefix(self.salt.as_bytes()) { +                            if let Some(expire) = expire.strip_prefix(&[0]) { +                                let expire = u64::from_be_bytes(expire[0..8].try_into().unwrap()); +                                if expire >= unix_seconds() { +                                    return self.next.handle(context, request).await; +                                } else { +                                    debug!("auth expired"); +                                } +                            } else { +                                return Err(ServiceError::CustomStatic("salt sep invalid")); +                            } +                        } else { +                            return Err(ServiceError::CustomStatic("salt invalid")); +                        } +                    } else { +                        return Err(ServiceError::CustomStatic("aead invalid")); +                    } +                } else { +                    debug!("no auth cookie"); +                } +            }              if request.method() == Method::GET && request.uri().path() == "/_gnix_auth_callback" {                  let mut state = None;                  let mut code = None; @@ -100,12 +145,7 @@ impl Node for OpenIDAuth {                  };                  let redirect_uri = redirect_uri(&request)?.to_string(); -                let OAuthTokenResponse { -                    access_token, -                    expires_in, -                    token_type, -                    id_token, -                } = token_request( +                let resp = token_request(                      &self.token_endpoint,                      &self.client_id,                      &redirect_uri, @@ -114,37 +154,60 @@ impl Node for OpenIDAuth {                  )                  .await?; -                let mut r = Response::new(BoxBody::<_, ServiceError>::new( -                    format!( -                        r#"Response: +                let jwt_pay = parse_jwt(&resp.id_token)?; -state={state:?} -code={code:?} -return_path={return_path:?} -access_token={access_token:?} -token_type={token_type:?} -expires_in={expires_in:?} -id_token={id_token:?}"# +                let nonce = [(); 12].map(|_| random::<u8>()); +                let mut plaintext = Vec::new(); +                plaintext.extend(self.salt.as_bytes()); +                plaintext.push(0); +                plaintext.extend(jwt_pay.exp.to_be_bytes()); +                let mut ciphertext = context +                    .state +                    .crypto_key +                    .encrypt( +                        Nonce::from_slice(&nonce), +                        Payload { +                            msg: &plaintext, +                            aad: jwt_pay.email.as_bytes(), +                        },                      ) -                    .map_err(|_| unreachable!()), -                )); -                r.headers_mut() -                    .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); -                return Ok(r); -            } -            if request.method() == Method::GET && request.uri().path() == "/favicon.ico" { +                    .unwrap(); +                ciphertext.extend(nonce); +                let auth = base64::engine::general_purpose::URL_SAFE.encode(ciphertext); + +                let mut resp = +                    Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()); +                *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT; +                resp.headers_mut().append( +                    SET_COOKIE, +                    HeaderValue::from_str(&format!( +                        "gnix_oauth_username={}; Secure", +                        percent_encode(jwt_pay.email.as_bytes(), NON_ALPHANUMERIC) +                    )) +                    .map_err(|_| ServiceError::InvalidHeader)?, +                ); +                resp.headers_mut().append( +                    SET_COOKIE, +                    HeaderValue::from_str(&format!("gnix_oauth={auth}; Secure")) +                        .map_err(|_| ServiceError::InvalidHeader)?, +                ); +                resp.headers_mut().insert( +                    LOCATION, +                    HeaderValue::from_str(&return_path).map_err(|_| ServiceError::InvalidHeader)?, +                ); +                Ok(resp) +            } else if request.method() == Method::GET && request.uri().path() == "/favicon.ico" {                  let mut resp =                      Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed());                  *resp.status_mut() = StatusCode::NO_CONTENT;                  Ok(resp)              } else {                  let (chal, verif_cipher): (Vec<u8>, Vec<u8>) = { -                    let r = [(); 16].map(|()| random::<u8>()); -                    let r = base64::engine::general_purpose::URL_SAFE.encode(r); +                    let r = [(); 32].map(|()| random::<u8>()); +                    let r = BASE64_URL_SAFE_NO_PAD.encode(r);                      let r = r.as_bytes();                      let nonce = [(); 12].map(|_| random::<u8>()); -                    // gcm and siv are overkill but its fine                      let mut v = context                          .state                          .crypto_key @@ -160,13 +223,13 @@ id_token={id_token:?}"#                  let redirect_uri = redirect_uri(&request)?.to_string();                  let uri = format!( -                    "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope=openid {}", +                    "{}?client_id={}&redirect_uri={}&state={}_{}&code_challenge={}&code_challenge_method=S256&response_type=code&scope={}",                      self.authorize_endpoint,                      utf8_percent_encode(&self.client_id, NON_ALPHANUMERIC),                      utf8_percent_encode(&redirect_uri, NON_ALPHANUMERIC),                      hex::encode(verif_cipher),                      utf8_percent_encode(&request.uri().to_string(), NON_ALPHANUMERIC), -                    base64::engine::general_purpose::URL_SAFE.encode(chal).trim_end_matches("="), +                    BASE64_URL_SAFE_NO_PAD.encode(chal),                      utf8_percent_encode(&self.scope, NON_ALPHANUMERIC),                  );                  info!("redirect {uri:?}"); @@ -174,7 +237,7 @@ id_token={id_token:?}"#                      Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed());                  *resp.status_mut() = StatusCode::TEMPORARY_REDIRECT;                  resp.headers_mut().insert( -                    "Location", +                    LOCATION,                      HeaderValue::from_str(&uri).map_err(|_| ServiceError::InvalidHeader)?,                  );                  Ok(resp) @@ -185,7 +248,7 @@ id_token={id_token:?}"#  fn redirect_uri(request: &NodeRequest) -> Result<Uri, ServiceError> {      let mut redirect_uri = Parts::default(); -    redirect_uri.scheme = Some(Scheme::HTTP); +    redirect_uri.scheme = Some(Scheme::HTTPS);      redirect_uri.path_and_query = Some(PathAndQuery::from_static("/_gnix_auth_callback"));      redirect_uri.authority = Authority::from_str(          request @@ -216,10 +279,34 @@ async fn token_request(      info!("validate {url} {body:?}");      let authority = url.authority().unwrap().clone(); -    let stream = TcpStream::connect(authority.as_str()).await.unwrap(); +    eprintln!("connect {}", authority.as_str()); + +    let use_tls = url.scheme() == Some(&Scheme::HTTPS); + +    let stream = TcpStream::connect(format!( +        "{}:{}", +        authority.host(), +        authority +            .port_u16() +            .unwrap_or(if use_tls { 443 } else { 80 }) +    )) +    .await +    .map_err(|_| ServiceError::CustomStatic("token request connect failed"))?; + +    let config = rustls::ClientConfig::builder() +        .with_root_certificates(RootCertStore { +            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), +        }) +        .with_no_client_auth(); +    let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); +    let name = ServerName::try_from(authority.host().to_owned()).unwrap(); +    let stream = connector.connect(name, stream).await.unwrap(); +      let io = TokioIo::new(stream); -    let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); +    let (mut sender, conn) = hyper::client::conn::http1::handshake(io) +        .await +        .map_err(|_| ServiceError::CustomStatic("token request handshake failed"))?;      tokio::task::spawn(async move {          if let Err(err) = conn.await {              println!("Connection failed: {:?}", err); @@ -238,15 +325,61 @@ async fn token_request(      let body = res.collect().await.unwrap().aggregate();      let mut buf = String::new();      body.reader().read_to_string(&mut buf).unwrap(); -    eprintln!("{buf}"); -    serde_json::from_str(&buf) -        .map_err(|_| ServiceError::CustomStatic("invalid token response")) + +    serde_json::from_str(&buf).map_err(|_| ServiceError::CustomStatic("invalid token response")) +} + +fn parse_jwt(s: &str) -> Result<JwtPayload, ServiceError> { +    let (header, rest) = s +        .split_once(".") +        .ok_or(ServiceError::CustomStatic("jwt invalid format"))?; +    let (payload, signature) = rest +        .split_once(".") +        .ok_or(ServiceError::CustomStatic("jwt invalid format"))?; + +    eprintln!("{header:?}"); +    let header: JwtHeader = serde_json::from_slice(&BASE64_URL_SAFE_NO_PAD.decode(header)?) +        .map_err(|_| ServiceError::CustomStatic("jwt invalid header"))?; +    eprintln!("{payload:?}"); +    let payload: JwtPayload = serde_json::from_slice(&BASE64_URL_SAFE_NO_PAD.decode(payload)?) +        .map_err(|_| ServiceError::CustomStatic("jwt invalid payload"))?; +    eprintln!("a"); + +    if header.typ != "JWT" { +        return Err(ServiceError::CustomStatic("jwt type is not jwt (duh)")); +    } + +    let _ = signature; + +    Ok(payload) +} + +#[derive(Debug, Deserialize)] +struct JwtHeader { +    #[allow(unused)] +    alg: String, +    typ: String, +} +#[derive(Debug, Deserialize)] +struct JwtPayload { +    email: String, +    exp: u64,  }  #[derive(Debug, Deserialize)]  struct OAuthTokenResponse { +    #[allow(unused)]      access_token: String, +    #[allow(unused)]      expires_in: i64, +    #[allow(unused)]      token_type: String,      id_token: String,  } + +fn unix_seconds() -> u64 { +    SystemTime::now() +        .duration_since(SystemTime::UNIX_EPOCH) +        .unwrap() +        .as_secs() +}  |