From 3b1afad1d1a697e82c003e146ef2b7d5742e5210 Mon Sep 17 00:00:00 2001 From: metamuffin Date: Tue, 14 Nov 2023 11:54:01 +0100 Subject: refactor architecture and start on http basic auth --- src/auth.rs | 41 ++++++++++++++++++++++++++++++ src/config.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++----- src/error.rs | 13 +++++++++- src/files.rs | 4 ++- src/main.rs | 69 +++++++++++++++++++++++++++++++++++++++---------- src/proxy.rs | 2 +- 6 files changed, 189 insertions(+), 22 deletions(-) create mode 100644 src/auth.rs (limited to 'src') diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..92a9ba3 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,41 @@ +use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; +use base64::Engine; +use http_body_util::{combinators::BoxBody, BodyExt}; +use hyper::{ + header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, + Response, StatusCode, +}; +use log::debug; +use std::ops::ControlFlow; + +pub fn http_basic( + config: &HttpBasicAuthConfig, + req: &FilterRequest, + resp: &mut FilterResponseOut, +) -> Result, ServiceError> { + if let Some(auth) = req.headers().get(AUTHORIZATION) { + let k = auth + .as_bytes() + .strip_prefix(b"Basic ") + .ok_or(ServiceError::BadAuth)?; + let k = base64::engine::general_purpose::STANDARD.decode(k)?; + let k = String::from_utf8(k)?; + if config.valid.contains(&k) { + debug!("valid auth"); + return Ok(ControlFlow::Continue(())); + } else { + debug!("invalid auth"); + } + } + debug!("unauthorized; sending auth challenge"); + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + String::new().map_err(|_| unreachable!()), + )); + *r.status_mut() = StatusCode::UNAUTHORIZED; + r.headers_mut().insert( + WWW_AUTHENTICATE, + HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), + ); + *resp = Some(r); + Ok(ControlFlow::Break(())) +} diff --git a/src/config.rs b/src/config.rs index c686930..a7abf3b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,14 @@ use serde::{ de::{value, Error, SeqAccess, Visitor}, Deserialize, Deserializer, Serialize, }; -use std::{collections::HashMap, fmt, fs::read_to_string, net::SocketAddr, path::PathBuf}; +use std::{ + collections::{HashMap, HashSet}, + fmt, + fs::read_to_string, + marker::PhantomData, + net::SocketAddr, + path::PathBuf, +}; #[derive(Debug, Serialize, Deserialize)] pub struct Config { @@ -12,7 +19,7 @@ pub struct Config { #[serde(default)] pub limits: Limits, #[serde(default)] - pub hosts: HashMap, + pub hosts: HashMap, } #[derive(Debug, Serialize, Deserialize)] @@ -37,10 +44,28 @@ pub struct HttpsConfig { } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum HostConfig { - Backend { backend: SocketAddr }, - Files { files: FileserverConfig }, +pub struct Route(#[serde(deserialize_with = "seq_or_not")] pub Vec); + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RouteFilter { + HttpBasicAuth { + #[serde(flatten)] + config: HttpBasicAuthConfig, + }, + Proxy { + backend: SocketAddr, + }, + Files { + #[serde(flatten)] + config: FileserverConfig, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct HttpBasicAuthConfig { + pub realm: String, + pub valid: HashSet, } #[derive(Debug, Serialize, Deserialize)] @@ -50,6 +75,51 @@ pub struct FileserverConfig { pub index: bool, } +// try deser Vec but fall back to deser T and putting that in Vec +fn seq_or_not<'de, D, V: Deserialize<'de>>(des: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct SeqOrNot(PhantomData); + impl<'de, V: Deserialize<'de>> Visitor<'de> for SeqOrNot { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence or not a sequence") + } + fn visit_enum(self, data: A) -> Result + where + A: serde::de::EnumAccess<'de>, + { + Ok(vec![V::deserialize(value::EnumAccessDeserializer::new( + data, + ))?]) + } + fn visit_map(self, map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + Ok(vec![V::deserialize(value::MapAccessDeserializer::new( + map, + ))?]) + } + fn visit_str(self, val: &str) -> Result, E> + where + E: Error, + { + Ok(vec![V::deserialize(value::StrDeserializer::new(val))?]) + } + + fn visit_seq(self, val: A) -> Result, A::Error> + where + A: SeqAccess<'de>, + { + Vec::::deserialize(value::SeqAccessDeserializer::new(val)) + } + } + des.deserialize_any(SeqOrNot::(PhantomData)) +} + // fall back to expecting a single string and putting that in a 1-length vector fn string_or_seq<'de, D>(des: D) -> Result, D::Error> where diff --git a/src/error.rs b/src/error.rs index 83a1ffa..b78999f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,11 @@ use tokio::sync::TryAcquireError; #[derive(Debug, thiserror::Error)] pub enum ServiceError { + #[error("no response generated; the proxy is misconfigured")] + NoResponse, + #[error("request taken; the proxy is misconfigured")] + RequestTaken, + #[error("limit reached. try again")] Limit(#[from] TryAcquireError), #[error("hyper error")] @@ -18,8 +23,14 @@ pub enum ServiceError { BadRange, #[error("bad utf8")] BadUtf8(#[from] std::str::Utf8Error), + #[error("bad utf8")] + BadUtf82(#[from] std::string::FromUtf8Error), #[error("bad path")] BadPath, - #[error("ohh. i didn't expect that this error can be generated.")] + #[error("bad auth")] + BadAuth, + #[error("bad base64: {0}")] + BadBase64(#[from] base64::DecodeError), + #[error("impossible error")] Other, } diff --git a/src/files.rs b/src/files.rs index 68a3807..cf53942 100644 --- a/src/files.rs +++ b/src/files.rs @@ -21,7 +21,7 @@ use tokio::{ use tokio_util::io::poll_read_buf; pub async fn serve_files( - req: Request, + req: &Request, config: &FileserverConfig, ) -> Result>, ServiceError> { let rpath = req.uri().path(); @@ -56,6 +56,7 @@ pub async fn serve_files( if !config.index { return Err(ServiceError::NotFound); } + debug!("sending index for {path:?}"); if !rpath.ends_with("/") { let mut r = Response::new(String::new()); @@ -78,6 +79,7 @@ pub async fn serve_files( let range = req.headers().typed_get::(); let range = bytes_range(range, metadata.len())?; + debug!("sending file {path:?}"); let file = File::open(path.clone()).await?; let mut r = Response::new(BoxBody::new(StreamBody::new( diff --git a/src/main.rs b/src/main.rs index 99374eb..059bdf0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ #![feature(try_trait_v2)] #![feature(exclusive_range_pattern)] +#![feature(slice_split_once)] +pub mod auth; pub mod config; pub mod error; pub mod files; @@ -10,11 +12,12 @@ pub mod proxy; pub mod reporting; use crate::{ - config::{Config, HostConfig}, + config::{Config, RouteFilter}, files::serve_files, proxy::proxy_request, }; use anyhow::{anyhow, bail, Context, Result}; +use bytes::Bytes; use error::ServiceError; use futures::future::try_join_all; use helper::TokioIo; @@ -30,7 +33,10 @@ use hyper::{ use log::{debug, error, info, warn}; #[cfg(feature = "mond")] use reporting::Reporting; -use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc}; +use std::{ + fs::File, io::BufReader, net::SocketAddr, ops::ControlFlow, path::Path, process::exit, + sync::Arc, +}; use tokio::{net::TcpListener, signal::ctrl_c, sync::Semaphore}; use tokio_rustls::TlsAcceptor; @@ -42,6 +48,10 @@ pub struct State { pub reporting: Reporting, } +pub type FilterRequest = Request; +pub type FilterResponseOut = Option>>; +pub type FilterResponse = Option>>; + #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init_from_env("LOG"); @@ -96,12 +106,13 @@ async fn serve_http(state: Arc) -> Result<()> { let listen_futs: Result> = try_join_all(http_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; + info!("HTTP listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; debug!("connection from {addr}"); let stream = TokioIo(stream); - let config = state.clone(); - tokio::spawn(async move { serve_stream(config, stream, addr).await }); + let state = state.clone(); + tokio::spawn(async move { serve_stream(state, stream, addr).await }); } })) .await; @@ -131,9 +142,9 @@ async fn serve_https(state: Arc) -> Result<()> { Arc::new(cfg) }; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); - let listen_futs: Result> = try_join_all(https_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; + info!("HTTPS listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; let state = state.clone(); @@ -148,8 +159,6 @@ async fn serve_https(state: Arc) -> Result<()> { } })) .await; - - info!("serving https"); listen_futs?; Ok(()) } @@ -170,11 +179,14 @@ pub async fn serve_stream Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok({ - let mut resp = - Response::new(format!("gnix encountered an issue: {error}")); + let mut resp = Response::new(format!( + "Sorry, we were unable to process your request: {error}" + )); *resp.status_mut() = StatusCode::BAD_REQUEST; resp.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + resp.headers_mut() + .insert(SERVER, HeaderValue::from_static("gnix")); resp } .map(|b| b.map_err(|e| match e {}).boxed())), @@ -226,10 +238,41 @@ async fn service( #[cfg(feature = "mond")] state.reporting.hosts.get(host).unwrap().requests_in.inc(); - let mut resp = match route { - HostConfig::Backend { backend } => proxy_request(&state, req, addr, backend).await, - HostConfig::Files { files } => serve_files(req, files).await, - }?; + let mut req = Some(req); + let mut resp = None; + for filter in &route.0 { + let cf = match filter { + RouteFilter::Proxy { backend } => { + resp = Some( + proxy_request( + &state, + req.take().ok_or(ServiceError::RequestTaken)?, + addr, + backend, + ) + .await?, + ); + ControlFlow::Continue(()) + } + RouteFilter::Files { config } => { + resp = Some( + serve_files(req.as_ref().ok_or(ServiceError::RequestTaken)?, config).await?, + ); + ControlFlow::Continue(()) + } + RouteFilter::HttpBasicAuth { config } => auth::http_basic( + config, + req.as_ref().ok_or(ServiceError::RequestTaken)?, + &mut resp, + )?, + }; + match cf { + ControlFlow::Continue(_) => continue, + ControlFlow::Break(_) => break, + } + } + + let mut resp = resp.ok_or(ServiceError::NoResponse)?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( diff --git a/src/proxy.rs b/src/proxy.rs index e8d2467..40ebf17 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -77,7 +77,7 @@ pub async fn proxy_request( if do_upgrade { let on_upgrade_upstream = resp.extensions_mut().remove::(); tokio::task::spawn(async move { - debug!("about upgrading connection, sending empty response"); + debug!("about to upgrade connection, sending empty response"); match ( on_upgrade_upstream.unwrap().await, on_upgrade_downstream.unwrap().await, -- cgit v1.2.3-70-g09d2