summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-11-14 11:54:01 +0100
committermetamuffin <metamuffin@disroot.org>2023-11-14 11:54:01 +0100
commit3b1afad1d1a697e82c003e146ef2b7d5742e5210 (patch)
tree3a9e02470b4f78c4c34c0573c788da301a9e544e
parent4a7bd84594fb8d159a0a2af02818f283eab3e716 (diff)
downloadgnix-3b1afad1d1a697e82c003e146ef2b7d5742e5210.tar
gnix-3b1afad1d1a697e82c003e146ef2b7d5742e5210.tar.bz2
gnix-3b1afad1d1a697e82c003e146ef2b7d5742e5210.tar.zst
refactor architecture and start on http basic auth
-rw-r--r--Cargo.lock7
-rw-r--r--Cargo.toml1
-rw-r--r--src/auth.rs41
-rw-r--r--src/config.rs82
-rw-r--r--src/error.rs13
-rw-r--r--src/files.rs4
-rw-r--r--src/main.rs69
-rw-r--r--src/proxy.rs2
8 files changed, 194 insertions, 25 deletions
diff --git a/Cargo.lock b/Cargo.lock
index abd05b7..dbe3193 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -61,9 +61,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
-version = "0.21.2"
+version = "0.21.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
+checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9"
[[package]]
name = "bincode"
@@ -311,6 +311,7 @@ name = "gnix"
version = "1.0.0"
dependencies = [
"anyhow",
+ "base64 0.21.5",
"bytes",
"env_logger",
"futures",
@@ -841,7 +842,7 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
dependencies = [
- "base64 0.21.2",
+ "base64 0.21.5",
]
[[package]]
diff --git a/Cargo.toml b/Cargo.toml
index bad1d80..dd59b8b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ hyper-util = "0.0.0"
http-body-util = "0.1.0-rc.3"
headers = "0.3.8"
percent-encoding = "2.3.0"
+base64 = "0.21.5"
# TLS
rustls-pemfile = "1.0.3"
diff --git a/src/auth.rs b/src/auth.rs
new file mode 100644
index 0000000..92a9ba3
--- /dev/null
+++ b/src/auth.rs
@@ -0,0 +1,41 @@
+use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut};
+use base64::Engine;
+use http_body_util::{combinators::BoxBody, BodyExt};
+use hyper::{
+ header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE},
+ Response, StatusCode,
+};
+use log::debug;
+use std::ops::ControlFlow;
+
+pub fn http_basic(
+ config: &HttpBasicAuthConfig,
+ req: &FilterRequest,
+ resp: &mut FilterResponseOut,
+) -> Result<ControlFlow<()>, ServiceError> {
+ if let Some(auth) = req.headers().get(AUTHORIZATION) {
+ let k = auth
+ .as_bytes()
+ .strip_prefix(b"Basic ")
+ .ok_or(ServiceError::BadAuth)?;
+ let k = base64::engine::general_purpose::STANDARD.decode(k)?;
+ let k = String::from_utf8(k)?;
+ if config.valid.contains(&k) {
+ debug!("valid auth");
+ return Ok(ControlFlow::Continue(()));
+ } else {
+ debug!("invalid auth");
+ }
+ }
+ debug!("unauthorized; sending auth challenge");
+ let mut r = Response::new(BoxBody::<_, ServiceError>::new(
+ String::new().map_err(|_| unreachable!()),
+ ));
+ *r.status_mut() = StatusCode::UNAUTHORIZED;
+ r.headers_mut().insert(
+ WWW_AUTHENTICATE,
+ HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(),
+ );
+ *resp = Some(r);
+ Ok(ControlFlow::Break(()))
+}
diff --git a/src/config.rs b/src/config.rs
index c686930..a7abf3b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -3,7 +3,14 @@ use serde::{
de::{value, Error, SeqAccess, Visitor},
Deserialize, Deserializer, Serialize,
};
-use std::{collections::HashMap, fmt, fs::read_to_string, net::SocketAddr, path::PathBuf};
+use std::{
+ collections::{HashMap, HashSet},
+ fmt,
+ fs::read_to_string,
+ marker::PhantomData,
+ net::SocketAddr,
+ path::PathBuf,
+};
#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
@@ -12,7 +19,7 @@ pub struct Config {
#[serde(default)]
pub limits: Limits,
#[serde(default)]
- pub hosts: HashMap<String, HostConfig>,
+ pub hosts: HashMap<String, Route>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -37,10 +44,28 @@ pub struct HttpsConfig {
}
#[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged)]
-pub enum HostConfig {
- Backend { backend: SocketAddr },
- Files { files: FileserverConfig },
+pub struct Route(#[serde(deserialize_with = "seq_or_not")] pub Vec<RouteFilter>);
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum RouteFilter {
+ HttpBasicAuth {
+ #[serde(flatten)]
+ config: HttpBasicAuthConfig,
+ },
+ Proxy {
+ backend: SocketAddr,
+ },
+ Files {
+ #[serde(flatten)]
+ config: FileserverConfig,
+ },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct HttpBasicAuthConfig {
+ pub realm: String,
+ pub valid: HashSet<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -50,6 +75,51 @@ pub struct FileserverConfig {
pub index: bool,
}
+// try deser Vec<T> but fall back to deser T and putting that in Vec
+fn seq_or_not<'de, D, V: Deserialize<'de>>(des: D) -> Result<Vec<V>, D::Error>
+where
+ D: Deserializer<'de>,
+{
+ struct SeqOrNot<V>(PhantomData<V>);
+ impl<'de, V: Deserialize<'de>> Visitor<'de> for SeqOrNot<V> {
+ type Value = Vec<V>;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a sequence or not a sequence")
+ }
+ fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
+ where
+ A: serde::de::EnumAccess<'de>,
+ {
+ Ok(vec![V::deserialize(value::EnumAccessDeserializer::new(
+ data,
+ ))?])
+ }
+ fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
+ where
+ A: serde::de::MapAccess<'de>,
+ {
+ Ok(vec![V::deserialize(value::MapAccessDeserializer::new(
+ map,
+ ))?])
+ }
+ fn visit_str<E>(self, val: &str) -> Result<Vec<V>, E>
+ where
+ E: Error,
+ {
+ Ok(vec![V::deserialize(value::StrDeserializer::new(val))?])
+ }
+
+ fn visit_seq<A>(self, val: A) -> Result<Vec<V>, A::Error>
+ where
+ A: SeqAccess<'de>,
+ {
+ Vec::<V>::deserialize(value::SeqAccessDeserializer::new(val))
+ }
+ }
+ des.deserialize_any(SeqOrNot::<V>(PhantomData))
+}
+
// fall back to expecting a single string and putting that in a 1-length vector
fn string_or_seq<'de, D>(des: D) -> Result<Vec<SocketAddr>, D::Error>
where
diff --git a/src/error.rs b/src/error.rs
index 83a1ffa..b78999f 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -2,6 +2,11 @@ use tokio::sync::TryAcquireError;
#[derive(Debug, thiserror::Error)]
pub enum ServiceError {
+ #[error("no response generated; the proxy is misconfigured")]
+ NoResponse,
+ #[error("request taken; the proxy is misconfigured")]
+ RequestTaken,
+
#[error("limit reached. try again")]
Limit(#[from] TryAcquireError),
#[error("hyper error")]
@@ -18,8 +23,14 @@ pub enum ServiceError {
BadRange,
#[error("bad utf8")]
BadUtf8(#[from] std::str::Utf8Error),
+ #[error("bad utf8")]
+ BadUtf82(#[from] std::string::FromUtf8Error),
#[error("bad path")]
BadPath,
- #[error("ohh. i didn't expect that this error can be generated.")]
+ #[error("bad auth")]
+ BadAuth,
+ #[error("bad base64: {0}")]
+ BadBase64(#[from] base64::DecodeError),
+ #[error("impossible error")]
Other,
}
diff --git a/src/files.rs b/src/files.rs
index 68a3807..cf53942 100644
--- a/src/files.rs
+++ b/src/files.rs
@@ -21,7 +21,7 @@ use tokio::{
use tokio_util::io::poll_read_buf;
pub async fn serve_files(
- req: Request<Incoming>,
+ req: &Request<Incoming>,
config: &FileserverConfig,
) -> Result<hyper::Response<BoxBody<Bytes, ServiceError>>, ServiceError> {
let rpath = req.uri().path();
@@ -56,6 +56,7 @@ pub async fn serve_files(
if !config.index {
return Err(ServiceError::NotFound);
}
+ debug!("sending index for {path:?}");
if !rpath.ends_with("/") {
let mut r = Response::new(String::new());
@@ -78,6 +79,7 @@ pub async fn serve_files(
let range = req.headers().typed_get::<headers::Range>();
let range = bytes_range(range, metadata.len())?;
+ debug!("sending file {path:?}");
let file = File::open(path.clone()).await?;
let mut r = Response::new(BoxBody::new(StreamBody::new(
diff --git a/src/main.rs b/src/main.rs
index 99374eb..059bdf0 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,8 @@
#![feature(try_trait_v2)]
#![feature(exclusive_range_pattern)]
+#![feature(slice_split_once)]
+pub mod auth;
pub mod config;
pub mod error;
pub mod files;
@@ -10,11 +12,12 @@ pub mod proxy;
pub mod reporting;
use crate::{
- config::{Config, HostConfig},
+ config::{Config, RouteFilter},
files::serve_files,
proxy::proxy_request,
};
use anyhow::{anyhow, bail, Context, Result};
+use bytes::Bytes;
use error::ServiceError;
use futures::future::try_join_all;
use helper::TokioIo;
@@ -30,7 +33,10 @@ use hyper::{
use log::{debug, error, info, warn};
#[cfg(feature = "mond")]
use reporting::Reporting;
-use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc};
+use std::{
+ fs::File, io::BufReader, net::SocketAddr, ops::ControlFlow, path::Path, process::exit,
+ sync::Arc,
+};
use tokio::{net::TcpListener, signal::ctrl_c, sync::Semaphore};
use tokio_rustls::TlsAcceptor;
@@ -42,6 +48,10 @@ pub struct State {
pub reporting: Reporting,
}
+pub type FilterRequest = Request<Incoming>;
+pub type FilterResponseOut = Option<Response<BoxBody<Bytes, ServiceError>>>;
+pub type FilterResponse = Option<Response<BoxBody<Bytes, ServiceError>>>;
+
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init_from_env("LOG");
@@ -96,12 +106,13 @@ async fn serve_http(state: Arc<State>) -> Result<()> {
let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind.iter().map(|e| async {
let l = TcpListener::bind(e.clone()).await?;
+ info!("HTTP listener bound to {}", l.local_addr().unwrap());
loop {
let (stream, addr) = l.accept().await.context("accepting connection")?;
debug!("connection from {addr}");
let stream = TokioIo(stream);
- let config = state.clone();
- tokio::spawn(async move { serve_stream(config, stream, addr).await });
+ let state = state.clone();
+ tokio::spawn(async move { serve_stream(state, stream, addr).await });
}
}))
.await;
@@ -131,9 +142,9 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
Arc::new(cfg)
};
let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config));
-
let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind.iter().map(|e| async {
let l = TcpListener::bind(e.clone()).await?;
+ info!("HTTPS listener bound to {}", l.local_addr().unwrap());
loop {
let (stream, addr) = l.accept().await.context("accepting connection")?;
let state = state.clone();
@@ -148,8 +159,6 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
}
}))
.await;
-
- info!("serving https");
listen_futs?;
Ok(())
}
@@ -170,11 +179,14 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r
Ok(r) => Ok(r),
Err(ServiceError::Hyper(e)) => Err(e),
Err(error) => Ok({
- let mut resp =
- Response::new(format!("gnix encountered an issue: {error}"));
+ let mut resp = Response::new(format!(
+ "Sorry, we were unable to process your request: {error}"
+ ));
*resp.status_mut() = StatusCode::BAD_REQUEST;
resp.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
+ resp.headers_mut()
+ .insert(SERVER, HeaderValue::from_static("gnix"));
resp
}
.map(|b| b.map_err(|e| match e {}).boxed())),
@@ -226,10 +238,41 @@ async fn service(
#[cfg(feature = "mond")]
state.reporting.hosts.get(host).unwrap().requests_in.inc();
- let mut resp = match route {
- HostConfig::Backend { backend } => proxy_request(&state, req, addr, backend).await,
- HostConfig::Files { files } => serve_files(req, files).await,
- }?;
+ let mut req = Some(req);
+ let mut resp = None;
+ for filter in &route.0 {
+ let cf = match filter {
+ RouteFilter::Proxy { backend } => {
+ resp = Some(
+ proxy_request(
+ &state,
+ req.take().ok_or(ServiceError::RequestTaken)?,
+ addr,
+ backend,
+ )
+ .await?,
+ );
+ ControlFlow::Continue(())
+ }
+ RouteFilter::Files { config } => {
+ resp = Some(
+ serve_files(req.as_ref().ok_or(ServiceError::RequestTaken)?, config).await?,
+ );
+ ControlFlow::Continue(())
+ }
+ RouteFilter::HttpBasicAuth { config } => auth::http_basic(
+ config,
+ req.as_ref().ok_or(ServiceError::RequestTaken)?,
+ &mut resp,
+ )?,
+ };
+ match cf {
+ ControlFlow::Continue(_) => continue,
+ ControlFlow::Break(_) => break,
+ }
+ }
+
+ let mut resp = resp.ok_or(ServiceError::NoResponse)?;
let server_header = resp.headers().get(SERVER).cloned();
resp.headers_mut().insert(
diff --git a/src/proxy.rs b/src/proxy.rs
index e8d2467..40ebf17 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -77,7 +77,7 @@ pub async fn proxy_request(
if do_upgrade {
let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
tokio::task::spawn(async move {
- debug!("about upgrading connection, sending empty response");
+ debug!("about to upgrade connection, sending empty response");
match (
on_upgrade_upstream.unwrap().await,
on_upgrade_downstream.unwrap().await,