diff options
Diffstat (limited to 'src/filters/auth.rs')
-rw-r--r-- | src/filters/auth.rs | 86 |
1 files changed, 55 insertions, 31 deletions
diff --git a/src/filters/auth.rs b/src/filters/auth.rs index 92a9ba3..7d5b03e 100644 --- a/src/filters/auth.rs +++ b/src/filters/auth.rs @@ -1,41 +1,65 @@ -use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; +use super::{Node, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; use base64::Engine; +use futures::Future; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, Response, StatusCode, }; use log::debug; -use std::ops::ControlFlow; +use serde::Deserialize; +use serde_yaml::Value; +use std::{collections::HashSet, pin::Pin, sync::Arc}; -pub fn http_basic( - config: &HttpBasicAuthConfig, - req: &FilterRequest, - resp: &mut FilterResponseOut, -) -> Result<ControlFlow<()>, ServiceError> { - if let Some(auth) = req.headers().get(AUTHORIZATION) { - let k = auth - .as_bytes() - .strip_prefix(b"Basic ") - .ok_or(ServiceError::BadAuth)?; - let k = base64::engine::general_purpose::STANDARD.decode(k)?; - let k = String::from_utf8(k)?; - if config.valid.contains(&k) { - debug!("valid auth"); - return Ok(ControlFlow::Continue(())); - } else { - debug!("invalid auth"); - } +pub struct HttpBasicAuthKind; +impl NodeKind for HttpBasicAuthKind { + fn name(&self) -> &'static str { + "http_basic_auth" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn super::Node>> { + Ok(Arc::new(serde_yaml::from_value::<HttpBasicAuth>(config)?)) + } +} + +#[derive(Deserialize)] +pub struct HttpBasicAuth { + realm: String, + valid: HashSet<String>, + next: DynNode, +} + +impl Node for HttpBasicAuth { + fn handle<'a>( + &'a self, + context: &'a mut super::NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + if let Some(auth) = request.headers().get(AUTHORIZATION) { + let k = auth + .as_bytes() + .strip_prefix(b"Basic ") + .ok_or(ServiceError::BadAuth)?; + let k = base64::engine::general_purpose::STANDARD.decode(k)?; + let k = String::from_utf8(k)?; + if self.valid.contains(&k) { + debug!("valid auth"); + return self.next.handle(context, request).await; + } else { + debug!("invalid auth"); + } + } + debug!("unauthorized; sending auth challenge"); + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + String::new().map_err(|_| unreachable!()), + )); + *r.status_mut() = StatusCode::UNAUTHORIZED; + r.headers_mut().insert( + WWW_AUTHENTICATE, + HeaderValue::from_str(&format!("Basic realm=\"{}\"", self.realm)).unwrap(), + ); + Ok(r) + }) } - debug!("unauthorized; sending auth challenge"); - let mut r = Response::new(BoxBody::<_, ServiceError>::new( - String::new().map_err(|_| unreachable!()), - )); - *r.status_mut() = StatusCode::UNAUTHORIZED; - r.headers_mut().insert( - WWW_AUTHENTICATE, - HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), - ); - *resp = Some(r); - Ok(ControlFlow::Break(())) } |