diff options
author | metamuffin <metamuffin@disroot.org> | 2023-11-14 11:54:01 +0100 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-11-14 11:54:01 +0100 |
commit | 3b1afad1d1a697e82c003e146ef2b7d5742e5210 (patch) | |
tree | 3a9e02470b4f78c4c34c0573c788da301a9e544e | |
parent | 4a7bd84594fb8d159a0a2af02818f283eab3e716 (diff) | |
download | gnix-3b1afad1d1a697e82c003e146ef2b7d5742e5210.tar gnix-3b1afad1d1a697e82c003e146ef2b7d5742e5210.tar.bz2 gnix-3b1afad1d1a697e82c003e146ef2b7d5742e5210.tar.zst |
refactor architecture and start on http basic auth
-rw-r--r-- | Cargo.lock | 7 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/auth.rs | 41 | ||||
-rw-r--r-- | src/config.rs | 82 | ||||
-rw-r--r-- | src/error.rs | 13 | ||||
-rw-r--r-- | src/files.rs | 4 | ||||
-rw-r--r-- | src/main.rs | 69 | ||||
-rw-r--r-- | src/proxy.rs | 2 |
8 files changed, 194 insertions, 25 deletions
@@ -61,9 +61,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "bincode" @@ -311,6 +311,7 @@ name = "gnix" version = "1.0.0" dependencies = [ "anyhow", + "base64 0.21.5", "bytes", "env_logger", "futures", @@ -841,7 +842,7 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", ] [[package]] @@ -11,6 +11,7 @@ hyper-util = "0.0.0" http-body-util = "0.1.0-rc.3" headers = "0.3.8" percent-encoding = "2.3.0" +base64 = "0.21.5" # TLS rustls-pemfile = "1.0.3" diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..92a9ba3 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,41 @@ +use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; +use base64::Engine; +use http_body_util::{combinators::BoxBody, BodyExt}; +use hyper::{ + header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, + Response, StatusCode, +}; +use log::debug; +use std::ops::ControlFlow; + +pub fn http_basic( + config: &HttpBasicAuthConfig, + req: &FilterRequest, + resp: &mut FilterResponseOut, +) -> Result<ControlFlow<()>, ServiceError> { + if let Some(auth) = req.headers().get(AUTHORIZATION) { + let k = auth + .as_bytes() + .strip_prefix(b"Basic ") + .ok_or(ServiceError::BadAuth)?; + let k = base64::engine::general_purpose::STANDARD.decode(k)?; + let k = String::from_utf8(k)?; + if config.valid.contains(&k) { + debug!("valid auth"); + return Ok(ControlFlow::Continue(())); + } else { + debug!("invalid auth"); + } + } + debug!("unauthorized; sending auth challenge"); + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + String::new().map_err(|_| unreachable!()), + )); + *r.status_mut() = StatusCode::UNAUTHORIZED; + r.headers_mut().insert( + WWW_AUTHENTICATE, + HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), + ); + *resp = Some(r); + Ok(ControlFlow::Break(())) +} diff --git a/src/config.rs b/src/config.rs index c686930..a7abf3b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,14 @@ use serde::{ de::{value, Error, SeqAccess, Visitor}, Deserialize, Deserializer, Serialize, }; -use std::{collections::HashMap, fmt, fs::read_to_string, net::SocketAddr, path::PathBuf}; +use std::{ + collections::{HashMap, HashSet}, + fmt, + fs::read_to_string, + marker::PhantomData, + net::SocketAddr, + path::PathBuf, +}; #[derive(Debug, Serialize, Deserialize)] pub struct Config { @@ -12,7 +19,7 @@ pub struct Config { #[serde(default)] pub limits: Limits, #[serde(default)] - pub hosts: HashMap<String, HostConfig>, + pub hosts: HashMap<String, Route>, } #[derive(Debug, Serialize, Deserialize)] @@ -37,10 +44,28 @@ pub struct HttpsConfig { } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum HostConfig { - Backend { backend: SocketAddr }, - Files { files: FileserverConfig }, +pub struct Route(#[serde(deserialize_with = "seq_or_not")] pub Vec<RouteFilter>); + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RouteFilter { + HttpBasicAuth { + #[serde(flatten)] + config: HttpBasicAuthConfig, + }, + Proxy { + backend: SocketAddr, + }, + Files { + #[serde(flatten)] + config: FileserverConfig, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct HttpBasicAuthConfig { + pub realm: String, + pub valid: HashSet<String>, } #[derive(Debug, Serialize, Deserialize)] @@ -50,6 +75,51 @@ pub struct FileserverConfig { pub index: bool, } +// try deser Vec<T> but fall back to deser T and putting that in Vec +fn seq_or_not<'de, D, V: Deserialize<'de>>(des: D) -> Result<Vec<V>, D::Error> +where + D: Deserializer<'de>, +{ + struct SeqOrNot<V>(PhantomData<V>); + impl<'de, V: Deserialize<'de>> Visitor<'de> for SeqOrNot<V> { + type Value = Vec<V>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence or not a sequence") + } + fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error> + where + A: serde::de::EnumAccess<'de>, + { + Ok(vec![V::deserialize(value::EnumAccessDeserializer::new( + data, + ))?]) + } + fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error> + where + A: serde::de::MapAccess<'de>, + { + Ok(vec![V::deserialize(value::MapAccessDeserializer::new( + map, + ))?]) + } + fn visit_str<E>(self, val: &str) -> Result<Vec<V>, E> + where + E: Error, + { + Ok(vec![V::deserialize(value::StrDeserializer::new(val))?]) + } + + fn visit_seq<A>(self, val: A) -> Result<Vec<V>, A::Error> + where + A: SeqAccess<'de>, + { + Vec::<V>::deserialize(value::SeqAccessDeserializer::new(val)) + } + } + des.deserialize_any(SeqOrNot::<V>(PhantomData)) +} + // fall back to expecting a single string and putting that in a 1-length vector fn string_or_seq<'de, D>(des: D) -> Result<Vec<SocketAddr>, D::Error> where diff --git a/src/error.rs b/src/error.rs index 83a1ffa..b78999f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,11 @@ use tokio::sync::TryAcquireError; #[derive(Debug, thiserror::Error)] pub enum ServiceError { + #[error("no response generated; the proxy is misconfigured")] + NoResponse, + #[error("request taken; the proxy is misconfigured")] + RequestTaken, + #[error("limit reached. try again")] Limit(#[from] TryAcquireError), #[error("hyper error")] @@ -18,8 +23,14 @@ pub enum ServiceError { BadRange, #[error("bad utf8")] BadUtf8(#[from] std::str::Utf8Error), + #[error("bad utf8")] + BadUtf82(#[from] std::string::FromUtf8Error), #[error("bad path")] BadPath, - #[error("ohh. i didn't expect that this error can be generated.")] + #[error("bad auth")] + BadAuth, + #[error("bad base64: {0}")] + BadBase64(#[from] base64::DecodeError), + #[error("impossible error")] Other, } diff --git a/src/files.rs b/src/files.rs index 68a3807..cf53942 100644 --- a/src/files.rs +++ b/src/files.rs @@ -21,7 +21,7 @@ use tokio::{ use tokio_util::io::poll_read_buf; pub async fn serve_files( - req: Request<Incoming>, + req: &Request<Incoming>, config: &FileserverConfig, ) -> Result<hyper::Response<BoxBody<Bytes, ServiceError>>, ServiceError> { let rpath = req.uri().path(); @@ -56,6 +56,7 @@ pub async fn serve_files( if !config.index { return Err(ServiceError::NotFound); } + debug!("sending index for {path:?}"); if !rpath.ends_with("/") { let mut r = Response::new(String::new()); @@ -78,6 +79,7 @@ pub async fn serve_files( let range = req.headers().typed_get::<headers::Range>(); let range = bytes_range(range, metadata.len())?; + debug!("sending file {path:?}"); let file = File::open(path.clone()).await?; let mut r = Response::new(BoxBody::new(StreamBody::new( diff --git a/src/main.rs b/src/main.rs index 99374eb..059bdf0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ #![feature(try_trait_v2)] #![feature(exclusive_range_pattern)] +#![feature(slice_split_once)] +pub mod auth; pub mod config; pub mod error; pub mod files; @@ -10,11 +12,12 @@ pub mod proxy; pub mod reporting; use crate::{ - config::{Config, HostConfig}, + config::{Config, RouteFilter}, files::serve_files, proxy::proxy_request, }; use anyhow::{anyhow, bail, Context, Result}; +use bytes::Bytes; use error::ServiceError; use futures::future::try_join_all; use helper::TokioIo; @@ -30,7 +33,10 @@ use hyper::{ use log::{debug, error, info, warn}; #[cfg(feature = "mond")] use reporting::Reporting; -use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc}; +use std::{ + fs::File, io::BufReader, net::SocketAddr, ops::ControlFlow, path::Path, process::exit, + sync::Arc, +}; use tokio::{net::TcpListener, signal::ctrl_c, sync::Semaphore}; use tokio_rustls::TlsAcceptor; @@ -42,6 +48,10 @@ pub struct State { pub reporting: Reporting, } +pub type FilterRequest = Request<Incoming>; +pub type FilterResponseOut = Option<Response<BoxBody<Bytes, ServiceError>>>; +pub type FilterResponse = Option<Response<BoxBody<Bytes, ServiceError>>>; + #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init_from_env("LOG"); @@ -96,12 +106,13 @@ async fn serve_http(state: Arc<State>) -> Result<()> { let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; + info!("HTTP listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; debug!("connection from {addr}"); let stream = TokioIo(stream); - let config = state.clone(); - tokio::spawn(async move { serve_stream(config, stream, addr).await }); + let state = state.clone(); + tokio::spawn(async move { serve_stream(state, stream, addr).await }); } })) .await; @@ -131,9 +142,9 @@ async fn serve_https(state: Arc<State>) -> Result<()> { Arc::new(cfg) }; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); - let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind.iter().map(|e| async { let l = TcpListener::bind(e.clone()).await?; + info!("HTTPS listener bound to {}", l.local_addr().unwrap()); loop { let (stream, addr) = l.accept().await.context("accepting connection")?; let state = state.clone(); @@ -148,8 +159,6 @@ async fn serve_https(state: Arc<State>) -> Result<()> { } })) .await; - - info!("serving https"); listen_futs?; Ok(()) } @@ -170,11 +179,14 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok({ - let mut resp = - Response::new(format!("gnix encountered an issue: {error}")); + let mut resp = Response::new(format!( + "Sorry, we were unable to process your request: {error}" + )); *resp.status_mut() = StatusCode::BAD_REQUEST; resp.headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + resp.headers_mut() + .insert(SERVER, HeaderValue::from_static("gnix")); resp } .map(|b| b.map_err(|e| match e {}).boxed())), @@ -226,10 +238,41 @@ async fn service( #[cfg(feature = "mond")] state.reporting.hosts.get(host).unwrap().requests_in.inc(); - let mut resp = match route { - HostConfig::Backend { backend } => proxy_request(&state, req, addr, backend).await, - HostConfig::Files { files } => serve_files(req, files).await, - }?; + let mut req = Some(req); + let mut resp = None; + for filter in &route.0 { + let cf = match filter { + RouteFilter::Proxy { backend } => { + resp = Some( + proxy_request( + &state, + req.take().ok_or(ServiceError::RequestTaken)?, + addr, + backend, + ) + .await?, + ); + ControlFlow::Continue(()) + } + RouteFilter::Files { config } => { + resp = Some( + serve_files(req.as_ref().ok_or(ServiceError::RequestTaken)?, config).await?, + ); + ControlFlow::Continue(()) + } + RouteFilter::HttpBasicAuth { config } => auth::http_basic( + config, + req.as_ref().ok_or(ServiceError::RequestTaken)?, + &mut resp, + )?, + }; + match cf { + ControlFlow::Continue(_) => continue, + ControlFlow::Break(_) => break, + } + } + + let mut resp = resp.ok_or(ServiceError::NoResponse)?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( diff --git a/src/proxy.rs b/src/proxy.rs index e8d2467..40ebf17 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -77,7 +77,7 @@ pub async fn proxy_request( if do_upgrade { let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>(); tokio::task::spawn(async move { - debug!("about upgrading connection, sending empty response"); + debug!("about to upgrade connection, sending empty response"); match ( on_upgrade_upstream.unwrap().await, on_upgrade_downstream.unwrap().await, |