From 886a18e0c67624d0882f04c7f6659bcfee6b4d8d Mon Sep 17 00:00:00 2001 From: metamuffin Date: Wed, 29 May 2024 16:37:44 +0200 Subject: refactor filter system --- Cargo.lock | 44 ------- Cargo.toml | 7 -- readme.md | 64 +++++++--- src/config.rs | 126 ++++++++------------ src/error.rs | 6 +- src/filters/accesslog.rs | 112 +++++++++++------- src/filters/auth.rs | 86 +++++++++----- src/filters/error.rs | 32 +++++ src/filters/files.rs | 299 +++++++++++++++++++++++++++-------------------- src/filters/hosts.rs | 45 +++++++ src/filters/mod.rs | 47 +++++++- src/filters/proxy.rs | 157 +++++++++++++------------ src/helper.rs | 1 - src/main.rs | 96 +++------------ src/reporting.rs | 39 ------- 15 files changed, 623 insertions(+), 538 deletions(-) create mode 100644 src/filters/error.rs create mode 100644 src/filters/hosts.rs delete mode 100644 src/reporting.rs diff --git a/Cargo.lock b/Cargo.lock index dc89f19..b18a3f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,25 +140,6 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" -[[package]] -name = "bincode" -version = "2.0.0-rc.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f11ea1a0346b94ef188834a65c068a03aec181c94896d481d7a0a40d85b0ce95" -dependencies = [ - "bincode_derive", - "serde", -] - -[[package]] -name = "bincode_derive" -version = "2.0.0-rc.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e30759b3b99a1b802a7a3aa21c85c3ded5c28e1c83170d82d70f08bbf7f3e4c" -dependencies = [ - "virtue", -] - [[package]] name = "bindgen" version = "0.69.4" @@ -495,7 +476,6 @@ dependencies = [ "log", "markup", "mime_guess", - "mond-client", "percent-encoding", "pin-project", "rustls", @@ -860,24 +840,6 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" -[[package]] -name = "mond-client" -version = "0.1.0" -source = "git+https://codeberg.org/metamuffin/mond#6019696ebc243f1b743823f076af0201ab0d27d5" -dependencies = [ - "log", - "mond-protocol", -] - -[[package]] -name = "mond-protocol" -version = "0.1.0" -source = "git+https://codeberg.org/metamuffin/mond#6019696ebc243f1b743823f076af0201ab0d27d5" -dependencies = [ - "bincode", - "serde", -] - [[package]] name = "nom" version = "7.1.3" @@ -1382,12 +1344,6 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" -[[package]] -name = "virtue" -version = "0.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dcc60c0624df774c82a0ef104151231d37da4962957d691c011c852b2473314" - [[package]] name = "want" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index 5176f3a..a1501f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ version = "1.0.0" edition = "2021" [dependencies] - # HTTP hyper = { version = "1.3.1", features = ["full"] } hyper-util = "0.1.3" @@ -44,9 +43,3 @@ mime_guess = "2.0.4" bytes = "1.6.0" anyhow = "1.0.82" thiserror = "1.0.59" - -mond-client = { git = "https://codeberg.org/metamuffin/mond", optional = true } - -[features] -default = [] -mond = ["dep:mond-client"] diff --git a/readme.md b/readme.md index 3f30b9c..2ad4faa 100644 --- a/readme.md +++ b/readme.md @@ -6,8 +6,9 @@ a simple stupid reverse proxy - Simple to configure (see below) - Handles connection upgrades correctly by default (websocket, etc.) +- Composable modules - TLS support -- _TODO: h2; match on uris; connection pools_ +- _TODO: h2; match on uris; connection pooling_ ## Quick Start @@ -18,31 +19,36 @@ configuration file is written in YAML and could look like this: # Both the 'http' and 'https' sections are optional http: # the value for 'bind' can either be a string or a list of strings - bind: [ "127.0.0.1:8080", "[::1]:8080" ] + bind: "[::1]:8080" https: - bind: "127.0.0.1:8443" + bind: "[::1]:8443" tls_cert: "ssl/cert.pem" - tls_key: "ssl/key.pem" # only accepts pkcs8 for now - -# this is a lookup table from hostnames to a list of filters -# in this case, requests for `testdomain.local` are forwarded to 127.0.0.1:3000 -hosts: - "testdomain.local": !proxy { backend: "127.0.0.1:8000" } - "192.168.178.39": !proxy { backend: "127.0.0.1:8000" } - "localhost": !files - root: "/home/muffin/videos" + tls_key: "ssl/key.pem" # only accepts pkcs8 + +# !hosts multiplexes requests for different hostnames. +handler: !hosts + # requests for `example.org` are forwarded to 127.0.0.1:8000 + "example.org": !proxy { backend: "127.0.0.1:8000" } + # requests for `mydomain.com` will access files from /srv/http + "mydomain.com": !files + root: "/srv/http" index: true + + "panel.mydomain.com": !access_log + ``` ## Reference - **section `http`** - `bind`: string or list of strings with addresses to listen on. + - **section `https`** - `bind`: string or list of strings with addresses to listen on. - `tls_cert`: path to the SSL certificate. (Sometimes called `fullchain.pem`) - `tls_key`: path to the SSL key. (Often called `key.pem` or `privkey.pem`) + - **section `limits`** - Note: Make sure you do not exceed the maximum file descriptor limit on your platform. @@ -50,21 +56,28 @@ hosts: connections. excess connections are rejected. Default: 512 - `max_outgoing_connections` number of maximum outgoing (upstream) connections. excess connections are rejected. Default: 256 -- **section `hosts`** - - A map from hostname (a string) to a _filter_ or a list of _filters_ + +- **section `handler`** + - A module to handle all requests. Usually an instance of `hosts`. + - `watch_config`: boolean if to watch the configuration file for changes and apply them accordingly. Default: true (Note: This will watch the entire parent directory of the config since most editors first move the file. Currently any change will trigger a reload. TODO) -### Filters +### Modules -- **filter `proxy`** +- **module `hosts`** + - Hands over the requests to different modules depending on the `host` header. + - Takes a map from hostname (string) to handler (module) + +- **module `proxy`** - Forwards the request as-is to some other server. the `x-real-ip` header is injected into the request. Connection upgrades are handled by direct forwarding of network traffic. - `backend`: socket address (string) to the backend server -- **filter `files`** + +- **module `files`** - Provides a simple built-in fileserver. The server handles `accept-ranges`. The `content-type` header is inferred from the file extension and falls back to `application/octet-stream`. If a directory is requested `index.html` will @@ -72,12 +85,25 @@ hosts: prepended to the response. - `root`: root directory to be served (string) - `index`: enables directory indexing (boolean) -- **filter `http_basic_auth`** + +- **module `http_basic_auth`** - Filters requests via HTTP Basic Authentification. Unauthorized clients will be challenged on every request. - - `realm`: string that does essentially nothing + - `realm`: describes what the user is logging into (most modern browsers dont show this anymore -_-) - `valid`: list of valid logins (string) in the format `:` (password in plain text). TODO: hashing + - `next`: a module to handle this request on successfully authentificated. (module) + +- **module `access_log`** + - Logs requests to a file. + - `file`: file path to log (string) + - `reject_on_fail`: rejects requests if log could not be written (boolean) + - `flush`: flushes log on every request (boolean) + - `next`: module for further handling of the request (module) + +- **module `error`** + - Rejects every request with a custom error message. + - Takes an error message (string) ## License diff --git a/src/config.rs b/src/config.rs index e661996..dfc4e73 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,34 +1,38 @@ -use crate::State; -use anyhow::Context; +use crate::{ + filters::{Node, NodeKind}, + State, +}; +use anyhow::{anyhow, Context}; use inotify::{EventMask, Inotify, WatchMask}; use log::{error, info}; use serde::{ de::{value, Error, SeqAccess, Visitor}, Deserialize, Deserializer, Serialize, }; +use serde_yaml::value::TaggedValue; use std::{ - collections::{HashMap, HashSet}, + collections::BTreeMap, fmt, fs::read_to_string, marker::PhantomData, net::SocketAddr, + ops::Deref, path::{Path, PathBuf}, - sync::Arc, + sync::{Arc, RwLock}, }; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Deserialize)] pub struct Config { - #[serde(default = "true_default")] + #[serde(default = "return_true")] pub watch_config: bool, pub http: Option, pub https: Option, #[serde(default)] pub limits: Limits, - #[serde(default)] - pub hosts: HashMap, + pub handler: DynNode, } -fn true_default() -> bool { +pub fn return_true() -> bool { true } @@ -53,72 +57,8 @@ pub struct HttpsConfig { pub tls_key: PathBuf, } -#[derive(Debug, Serialize, Deserialize)] -pub struct Route(#[serde(deserialize_with = "seq_or_not")] pub Vec); - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum RouteFilter { - HttpBasicAuth { - #[serde(flatten)] - config: HttpBasicAuthConfig, - }, - Proxy { - backend: SocketAddr, - }, - Files { - #[serde(flatten)] - config: FileserverConfig, - }, - AccessLog { - #[serde(flatten)] - config: AccessLogConfig, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct AccessLogConfig { - pub file: PathBuf, - #[serde(default)] - pub flush: bool, - #[serde(default)] - pub reject_on_fail: bool, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct HttpBasicAuthConfig { - pub realm: String, - pub valid: HashSet, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct FileserverConfig { - pub root: PathBuf, - #[serde(default)] - pub index: bool, - #[serde(default = "return_true")] - pub last_modified: bool, - #[serde(default = "return_true")] - pub etag: bool, - #[serde(default)] - pub cache: CacheConfig, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CacheConfig { - #[default] - Public, - Private, - NoStore, -} - -fn return_true() -> bool { - true -} - // try deser Vec but fall back to deser T and putting that in Vec -fn seq_or_not<'de, D, V: Deserialize<'de>>(des: D) -> Result, D::Error> +pub fn seq_or_not<'de, D, V: Deserialize<'de>>(des: D) -> Result, D::Error> where D: Deserializer<'de>, { @@ -194,11 +134,45 @@ where des.deserialize_any(StringOrList) } +pub static NODE_KINDS: RwLock> = + RwLock::new(BTreeMap::new()); + +#[derive(Clone)] +pub struct DynNode(Arc); + +impl<'de> Deserialize<'de> for DynNode { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let tv = TaggedValue::deserialize(deserializer)?; + let s = tv.tag.to_string(); + let s = s.strip_prefix("!").unwrap_or(s.as_str()); + let inst = NODE_KINDS + .read() + .unwrap() + .get(s) + .ok_or(serde::de::Error::unknown_variant(s, &[]))? + .instanciate(tv.value) + .map_err(|e| { + serde::de::Error::custom(e.context(anyhow!("instanciating modules {s:?}"))) + })?; + + Ok(Self(inst)) + } +} +impl Deref for DynNode { + type Target = dyn Node; + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + impl Config { pub fn load(path: &Path) -> anyhow::Result { info!("loading config from {path:?}"); let raw = read_to_string(path).context("reading config file")?; - let config: Config = serde_yaml::from_str(&raw).context("parsing config")?; + let config: Config = serde_yaml::from_str(&raw).context("during parsing")?; Ok(config) } } @@ -235,7 +209,7 @@ pub fn setup_file_watch(config_path: PathBuf, state: Arc) { let mut r = state.config.blocking_write(); *r = Arc::new(conf) } - Err(e) => error!("config has errors: {e}"), + Err(e) => error!("config has errors: {e:?}"), } } } diff --git a/src/error.rs b/src/error.rs index 1675b3c..e7e5af2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,8 +11,10 @@ pub enum ServiceError { Limit(#[from] TryAcquireError), #[error("hyper error")] Hyper(hyper::Error), - #[error("unknown host")] + #[error("no host")] NoHost, + #[error("unknown host")] + UnknownHost, #[error("can't connect to the backend")] CantConnect, #[error("not found")] @@ -33,6 +35,8 @@ pub enum ServiceError { BadBase64(#[from] base64::DecodeError), #[error("connection upgrade failed")] UpgradeFailed, + #[error("{0}")] + Custom(String), #[error("impossible error")] Other, } diff --git a/src/filters/accesslog.rs b/src/filters/accesslog.rs index 9a33762..1da6e5d 100644 --- a/src/filters/accesslog.rs +++ b/src/filters/accesslog.rs @@ -1,49 +1,81 @@ -use crate::{config::AccessLogConfig, error::ServiceError, FilterRequest, State}; -use futures::executor::block_on; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use futures::Future; use log::error; -use std::{net::SocketAddr, ops::ControlFlow, time::SystemTime}; +use serde::Deserialize; +use std::{path::PathBuf, pin::Pin, sync::Arc, time::SystemTime}; use tokio::{ - fs::OpenOptions, + fs::{File, OpenOptions}, io::{AsyncWriteExt, BufWriter}, + sync::RwLock, }; -pub async fn access_log( - state: &State, - host: &str, - addr: SocketAddr, - config: &AccessLogConfig, - req: &FilterRequest, -) -> Result, ServiceError> { - let mut g = state.access_logs.write().await; - - let log = g.entry(host.to_owned()).or_insert_with(|| { - BufWriter::new( - // TODO aaahh dont block the runtime and dont panic in any case.... - block_on( - OpenOptions::new() - .append(true) - .create(true) - .open(&config.file), - ) - .unwrap(), - ) - }); - - let method = req.method().as_str(); - let time = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros(); - let mut res = log - .write_all(format!("{time}\t{addr}\t{method}\t{:?}\n", req.uri()).as_bytes()) - .await; - - if config.flush && res.is_ok() { - res = log.flush().await; - } +pub struct AccessLogKind; + +#[derive(Deserialize)] +struct AccessLogConfig { + file: PathBuf, + #[serde(default)] + flush: bool, + #[serde(default)] + reject_on_fail: bool, + next: DynNode, +} - if config.reject_on_fail { - res? - } else if let Err(e) = res { - error!("failed to write log: {e:?}") +struct AccessLog { + config: AccessLogConfig, + file: RwLock>>, +} + +impl NodeKind for AccessLogKind { + fn name(&self) -> &'static str { + "access_log" + } + fn instanciate(&self, config: serde_yaml::Value) -> anyhow::Result> { + Ok(Arc::new(AccessLog { + config: serde_yaml::from_value::(config)?, + file: Default::default(), + })) } +} + +impl Node for AccessLog { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>> { + Box::pin(async move { + let mut g = self.file.write().await; + let log = match g.as_mut() { + Some(r) => r, + None => g.insert(BufWriter::new( + OpenOptions::new() + .append(true) + .create(true) + .open(&self.config.file) + .await?, + )), + }; + + let method = request.method().as_str(); + let time = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros(); + let addr = context.addr; + let mut res = log + .write_all(format!("{time}\t{addr}\t{method}\t{:?}\n", request.uri()).as_bytes()) + .await; - Ok(ControlFlow::Continue(())) + if self.config.flush && res.is_ok() { + res = log.flush().await; + } + + if self.config.reject_on_fail { + res? + } else if let Err(e) = res { + error!("failed to write log: {e:?}") + } + + self.config.next.handle(context, request).await + }) + } } diff --git a/src/filters/auth.rs b/src/filters/auth.rs index 92a9ba3..7d5b03e 100644 --- a/src/filters/auth.rs +++ b/src/filters/auth.rs @@ -1,41 +1,65 @@ -use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; +use super::{Node, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; use base64::Engine; +use futures::Future; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, Response, StatusCode, }; use log::debug; -use std::ops::ControlFlow; +use serde::Deserialize; +use serde_yaml::Value; +use std::{collections::HashSet, pin::Pin, sync::Arc}; -pub fn http_basic( - config: &HttpBasicAuthConfig, - req: &FilterRequest, - resp: &mut FilterResponseOut, -) -> Result, ServiceError> { - if let Some(auth) = req.headers().get(AUTHORIZATION) { - let k = auth - .as_bytes() - .strip_prefix(b"Basic ") - .ok_or(ServiceError::BadAuth)?; - let k = base64::engine::general_purpose::STANDARD.decode(k)?; - let k = String::from_utf8(k)?; - if config.valid.contains(&k) { - debug!("valid auth"); - return Ok(ControlFlow::Continue(())); - } else { - debug!("invalid auth"); - } +pub struct HttpBasicAuthKind; +impl NodeKind for HttpBasicAuthKind { + fn name(&self) -> &'static str { + "http_basic_auth" + } + fn instanciate(&self, config: Value) -> anyhow::Result> { + Ok(Arc::new(serde_yaml::from_value::(config)?)) + } +} + +#[derive(Deserialize)] +pub struct HttpBasicAuth { + realm: String, + valid: HashSet, + next: DynNode, +} + +impl Node for HttpBasicAuth { + fn handle<'a>( + &'a self, + context: &'a mut super::NodeContext, + request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>> { + Box::pin(async move { + if let Some(auth) = request.headers().get(AUTHORIZATION) { + let k = auth + .as_bytes() + .strip_prefix(b"Basic ") + .ok_or(ServiceError::BadAuth)?; + let k = base64::engine::general_purpose::STANDARD.decode(k)?; + let k = String::from_utf8(k)?; + if self.valid.contains(&k) { + debug!("valid auth"); + return self.next.handle(context, request).await; + } else { + debug!("invalid auth"); + } + } + debug!("unauthorized; sending auth challenge"); + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + String::new().map_err(|_| unreachable!()), + )); + *r.status_mut() = StatusCode::UNAUTHORIZED; + r.headers_mut().insert( + WWW_AUTHENTICATE, + HeaderValue::from_str(&format!("Basic realm=\"{}\"", self.realm)).unwrap(), + ); + Ok(r) + }) } - debug!("unauthorized; sending auth challenge"); - let mut r = Response::new(BoxBody::<_, ServiceError>::new( - String::new().map_err(|_| unreachable!()), - )); - *r.status_mut() = StatusCode::UNAUTHORIZED; - r.headers_mut().insert( - WWW_AUTHENTICATE, - HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), - ); - *resp = Some(r); - Ok(ControlFlow::Break(())) } diff --git a/src/filters/error.rs b/src/filters/error.rs new file mode 100644 index 0000000..504802f --- /dev/null +++ b/src/filters/error.rs @@ -0,0 +1,32 @@ +use crate::error::ServiceError; + +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use futures::Future; +use serde::Deserialize; +use serde_yaml::Value; +use std::{pin::Pin, sync::Arc}; + +pub struct ErrorKind; + +#[derive(Deserialize)] +#[serde(transparent)] +struct Error(String); + +impl NodeKind for ErrorKind { + fn name(&self) -> &'static str { + "error" + } + fn instanciate(&self, config: Value) -> anyhow::Result> { + Ok(Arc::new(serde_yaml::from_value::(config)?)) + } +} + +impl Node for Error { + fn handle<'a>( + &'a self, + _context: &'a mut NodeContext, + _request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>> { + Box::pin(async move { Err(ServiceError::Custom(self.0.clone())) }) + } +} diff --git a/src/filters/files.rs b/src/filters/files.rs index ee40d70..fc3d63b 100644 --- a/src/filters/files.rs +++ b/src/filters/files.rs @@ -1,8 +1,7 @@ -use crate::{ - config::{CacheConfig, FileserverConfig}, - ServiceError, -}; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::return_true, ServiceError}; use bytes::{Bytes, BytesMut}; +use futures::Future; use futures_util::{future, future::Either, ready, stream, FutureExt, Stream, StreamExt}; use headers::{ AcceptRanges, CacheControl, ContentLength, ContentRange, ContentType, HeaderMapExt, @@ -11,154 +10,204 @@ use headers::{ use http_body_util::{combinators::BoxBody, BodyExt, StreamBody}; use humansize::FormatSizeOptions; use hyper::{ - body::{Frame, Incoming}, + body::Frame, header::{CONTENT_TYPE, LOCATION}, http::HeaderValue, - Request, Response, StatusCode, + Response, StatusCode, }; use log::debug; use markup::Render; use percent_encoding::percent_decode_str; -use std::{fs::Metadata, io, ops::Range, path::Path, pin::Pin, task::Poll}; +use serde::Deserialize; +use serde_yaml::Value; +use std::{ + fs::Metadata, + io, + ops::Range, + path::{Path, PathBuf}, + pin::Pin, + sync::Arc, + task::Poll, +}; use tokio::{ fs::{read_to_string, File}, io::AsyncSeekExt, }; use tokio_util::io::poll_read_buf; -pub async fn serve_files( - req: &Request, - config: &FileserverConfig, -) -> Result>, ServiceError> { - let rpath = req.uri().path(); - - let mut path = config.root.clone(); - let mut user_path_depth = 0; - for seg in rpath.split("/") { - let seg = percent_decode_str(seg).decode_utf8()?; +pub struct FilesKind; + +#[derive(Debug, Deserialize)] +struct Files { + root: PathBuf, + #[serde(default)] + index: bool, + #[serde(default = "return_true")] + last_modified: bool, + // #[serde(default = "return_true")] + // etag: bool, + #[serde(default)] + cache: CacheMode, +} - if seg == "" || seg == "." { - continue; - } +#[derive(Debug, Default, Deserialize)] +#[serde(rename_all = "snake_case")] +enum CacheMode { + #[default] + Public, + Private, + NoStore, +} - if seg == ".." { - if user_path_depth <= 0 { - return Err(ServiceError::BadPath); - } - path.pop(); - user_path_depth -= 1; - } else { - path.push(seg.as_ref()); - user_path_depth += 1; - } +impl NodeKind for FilesKind { + fn name(&self) -> &'static str { + "files" } - if !path.exists() { - return Err(ServiceError::NotFound); + fn instanciate(&self, config: Value) -> anyhow::Result> { + Ok(Arc::new(serde_yaml::from_value::(config)?)) } +} - let metadata = path.metadata()?; +impl Node for Files { + fn handle<'a>( + &'a self, + _context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>> { + Box::pin(async move { + let rpath = request.uri().path(); + + let mut path = self.root.clone(); + let mut user_path_depth = 0; + for seg in rpath.split("/") { + let seg = percent_decode_str(seg).decode_utf8()?; + + if seg == "" || seg == "." { + continue; + } - if metadata.file_type().is_dir() { - debug!("sending index for {path:?}"); - if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { - return Ok(html_string_response(indexhtml)); - } + if seg == ".." { + if user_path_depth <= 0 { + return Err(ServiceError::BadPath); + } + path.pop(); + user_path_depth -= 1; + } else { + path.push(seg.as_ref()); + user_path_depth += 1; + } + } + if !path.exists() { + return Err(ServiceError::NotFound); + } + + let metadata = path.metadata()?; - if config.index { - if !rpath.ends_with("/") { - let mut r = Response::new(String::new()); - *r.status_mut() = StatusCode::FOUND; - r.headers_mut().insert( - LOCATION, - HeaderValue::from_str(&format!("{}/", rpath)) - .map_err(|_| ServiceError::Other)?, - ); - return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); + if metadata.file_type().is_dir() { + debug!("sending index for {path:?}"); + if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { + return Ok(html_string_response(indexhtml)); + } + + if self.index { + if !rpath.ends_with("/") { + let mut r = Response::new(String::new()); + *r.status_mut() = StatusCode::FOUND; + r.headers_mut().insert( + LOCATION, + HeaderValue::from_str(&format!("{}/", rpath)) + .map_err(|_| ServiceError::Other)?, + ); + return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); + } + + return index(&path, rpath.to_string()) + .await + .map(html_string_response); + } else { + return Err(ServiceError::NotFound); + } } - return index(&path, rpath.to_string()) - .await - .map(html_string_response); - } else { - return Err(ServiceError::NotFound); - } - } + let modified = metadata.modified()?; + + let not_modified = if self.last_modified { + request + .headers() + .typed_get::() + .map(|if_modified_since| { + Ok::<_, ServiceError>(!if_modified_since.is_modified(modified)) + }) + .transpose()? + .unwrap_or_default() + } else { + false + }; - let modified = metadata.modified()?; + // let etag = ETag::from_str(&calc_etag(modified)).map_err(|_| ServiceError::Other)?; + // let etag_matches = if self.etag { + // request.headers() + // .typed_get::() + // .map(|if_none_match| if_none_match.precondition_passes(&etag)) + // .unwrap_or_default() + // } else { + // false + // }; + + let range = request.headers().typed_get::(); + let range = bytes_range(range, metadata.len())?; + + debug!("sending file {path:?}"); + let file = File::open(path.clone()).await?; + + // let skip_body = not_modified || etag_matches; + let skip_body = not_modified; + let mut r = if skip_body { + Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()) + } else { + Response::new(BoxBody::new(StreamBody::new( + StreamBody::new(file_stream(file, 4096, range.clone())) + .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), + ))) + }; - let not_modified = if config.last_modified { - req.headers() - .typed_get::() - .map(|if_modified_since| { - Ok::<_, ServiceError>(!if_modified_since.is_modified(modified)) - }) - .transpose()? - .unwrap_or_default() - } else { - false - }; + if !skip_body { + if range.end - range.start != metadata.len() { + *r.status_mut() = StatusCode::PARTIAL_CONTENT; + r.headers_mut().typed_insert( + ContentRange::bytes(range.clone(), metadata.len()) + .expect("valid ContentRange"), + ); + } + } + // if not_modified || etag_matches { + if not_modified { + *r.status_mut() = StatusCode::NOT_MODIFIED; + } - // let etag = ETag::from_str(&calc_etag(modified)).map_err(|_| ServiceError::Other)?; - // let etag_matches = if config.etag { - // req.headers() - // .typed_get::() - // .map(|if_none_match| if_none_match.precondition_passes(&etag)) - // .unwrap_or_default() - // } else { - // false - // }; - - let range = req.headers().typed_get::(); - let range = bytes_range(range, metadata.len())?; - - debug!("sending file {path:?}"); - let file = File::open(path.clone()).await?; - - // let skip_body = not_modified || etag_matches; - let skip_body = not_modified; - let mut r = if skip_body { - Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()) - } else { - Response::new(BoxBody::new(StreamBody::new( - StreamBody::new(file_stream(file, 4096, range.clone())) - .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), - ))) - }; + r.headers_mut().typed_insert(AcceptRanges::bytes()); + r.headers_mut() + .typed_insert(ContentLength(range.end - range.start)); - if !skip_body { - if range.end - range.start != metadata.len() { - *r.status_mut() = StatusCode::PARTIAL_CONTENT; - r.headers_mut().typed_insert( - ContentRange::bytes(range.clone(), metadata.len()).expect("valid ContentRange"), - ); - } - } - // if not_modified || etag_matches { - if not_modified { - *r.status_mut() = StatusCode::NOT_MODIFIED; - } + let mime = mime_guess::from_path(path).first_or_octet_stream(); + r.headers_mut().typed_insert(ContentType::from(mime)); - r.headers_mut().typed_insert(AcceptRanges::bytes()); - r.headers_mut() - .typed_insert(ContentLength(range.end - range.start)); - - let mime = mime_guess::from_path(path).first_or_octet_stream(); - r.headers_mut().typed_insert(ContentType::from(mime)); - - r.headers_mut().typed_insert(match config.cache { - CacheConfig::Public => CacheControl::new().with_public(), - CacheConfig::Private => CacheControl::new().with_private(), - CacheConfig::NoStore => CacheControl::new().with_no_store(), - }); - - // if config.etag { - // r.headers_mut().typed_insert(etag); - // } - if config.last_modified { - r.headers_mut().typed_insert(LastModified::from(modified)); - } + r.headers_mut().typed_insert(match self.cache { + CacheMode::Public => CacheControl::new().with_public(), + CacheMode::Private => CacheControl::new().with_private(), + CacheMode::NoStore => CacheControl::new().with_no_store(), + }); + + // if self.etag { + // r.headers_mut().typed_insert(etag); + // } + if self.last_modified { + r.headers_mut().typed_insert(LastModified::from(modified)); + } - Ok(r) + Ok(r) + }) + } } // Adapted from warp (https://github.com/seanmonstar/warp/blob/master/src/filters/fs.rs). Thanks! diff --git a/src/filters/hosts.rs b/src/filters/hosts.rs new file mode 100644 index 0000000..286d478 --- /dev/null +++ b/src/filters/hosts.rs @@ -0,0 +1,45 @@ +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use futures::Future; +use hyper::header::HOST; +use serde::Deserialize; +use serde_yaml::Value; +use std::{collections::HashMap, pin::Pin, sync::Arc}; + +#[derive(Deserialize)] +#[serde(transparent)] +struct Hosts(HashMap); + +pub struct HostsKind; +impl NodeKind for HostsKind { + fn name(&self) -> &'static str { + "hosts" + } + fn instanciate(&self, config: Value) -> anyhow::Result> { + Ok(Arc::new(serde_yaml::from_value::(config)?)) + } +} +impl Node for Hosts { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>> { + Box::pin(async move { + let host = request + .headers() + .get(HOST) + .and_then(|e| e.to_str().ok()) + .ok_or(ServiceError::NoHost)?; + + let host = remove_port(&host); + let node = self.0.get(host).ok_or(ServiceError::UnknownHost)?; + + node.handle(context, request).await + }) + } +} + +pub fn remove_port(s: &str) -> &str { + s.split_once(":").map(|(s, _)| s).unwrap_or(s) +} diff --git a/src/filters/mod.rs b/src/filters/mod.rs index fdeed51..10520a3 100644 --- a/src/filters/mod.rs +++ b/src/filters/mod.rs @@ -1,5 +1,50 @@ +use crate::error::ServiceError; +use crate::State; +use accesslog::AccessLogKind; +use auth::HttpBasicAuthKind; +use bytes::Bytes; +use error::ErrorKind; +use files::FilesKind; +use futures::Future; +use hosts::HostsKind; +use http_body_util::combinators::BoxBody; +use hyper::{body::Incoming, Request, Response}; +use proxy::ProxyKind; +use serde_yaml::Value; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; +pub mod accesslog; pub mod auth; +pub mod error; pub mod files; +pub mod hosts; pub mod proxy; -pub mod accesslog; \ No newline at end of file + +pub type NodeRequest = Request; +pub type NodeResponse = Response>; + +pub static MODULES: &'static [&'static dyn NodeKind] = &[ + &HttpBasicAuthKind, + &ProxyKind, + &HostsKind, + &FilesKind, + &AccessLogKind, + &ErrorKind, +]; + +pub struct NodeContext { + pub state: Arc, + pub addr: SocketAddr, +} + +pub trait NodeKind: Send + Sync + 'static { + fn name(&self) -> &'static str; + fn instanciate(&self, config: Value) -> anyhow::Result>; +} +pub trait Node: Send + Sync + 'static { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>>; +} diff --git a/src/filters/proxy.rs b/src/filters/proxy.rs index ad959ad..ce72f65 100644 --- a/src/filters/proxy.rs +++ b/src/filters/proxy.rs @@ -1,87 +1,98 @@ -use crate::{helper::TokioIo, ServiceError, State}; -use http_body_util::{combinators::BoxBody, BodyExt}; -use hyper::{body::Incoming, http::HeaderValue, upgrade::OnUpgrade, Request, StatusCode}; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{helper::TokioIo, ServiceError}; +use futures::Future; +use http_body_util::BodyExt; +use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode}; use log::{debug, warn}; -use std::{net::SocketAddr, sync::Arc}; +use serde::Deserialize; +use serde_yaml::Value; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; use tokio::net::TcpStream; -pub async fn proxy_request( - state: &Arc, - mut req: Request, - addr: SocketAddr, - backend: &SocketAddr, -) -> Result>, ServiceError> { - #[cfg(feature = "mond")] - state.reporting.request_out.inc(); +#[derive(Default)] +pub struct ProxyKind; - //? Do we know what this did? - // *req.uri_mut() = Uri::builder() - // .path_and_query( - // req.uri() - // .clone() - // .path_and_query() - // .cloned() - // .unwrap_or(PathAndQuery::from_static("/")), - // ) - // .build() - // .unwrap(); - - req.headers_mut().insert( - "x-real-ip", - HeaderValue::from_str(&format!("{}", addr.ip())).unwrap(), - ); +#[derive(Debug, Deserialize)] +struct Proxy { + backend: SocketAddr, +} - let on_upgrade_downstream = req.extensions_mut().remove::(); +impl NodeKind for ProxyKind { + fn name(&self) -> &'static str { + "proxy" + } + fn instanciate(&self, config: Value) -> anyhow::Result> { + Ok(Arc::new(serde_yaml::from_value::(config)?)) + } +} +impl Node for Proxy { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + mut request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>> { + Box::pin(async move { + request.headers_mut().insert( + "x-real-ip", + HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), + ); - let _limit_guard = state.l_outgoing.try_acquire()?; - debug!("\tforwarding to {}", backend); - let mut resp = { - let client_stream = TokioIo( - TcpStream::connect(backend) - .await - .map_err(|_| ServiceError::CantConnect)?, - ); + let on_upgrade_downstream = request.extensions_mut().remove::(); - let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) - .await - .map_err(ServiceError::Hyper)?; - tokio::task::spawn(async move { - if let Err(err) = conn.with_upgrades().await { - warn!("connection failed: {:?}", err); - } - }); - sender - .send_request(req) - .await - .map_err(ServiceError::Hyper)? - }; + let _limit_guard = context.state.l_outgoing.try_acquire()?; + debug!("\tforwarding to {}", self.backend); + let mut resp = { + let client_stream = TokioIo( + TcpStream::connect(self.backend) + .await + .map_err(|_| ServiceError::CantConnect)?, + ); - if resp.status() == StatusCode::SWITCHING_PROTOCOLS { - let on_upgrade_upstream = resp - .extensions_mut() - .remove::() - .ok_or(ServiceError::UpgradeFailed)?; - let on_upgrade_downstream = on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; - tokio::task::spawn(async move { - debug!("about to upgrade connection"); - match (on_upgrade_upstream.await, on_upgrade_downstream.await) { - (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { - debug!("upgrade successful"); - match tokio::io::copy_bidirectional( - &mut TokioIo(upgraded_downstream), - &mut TokioIo(upgraded_upstream), - ) + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) + .await + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.with_upgrades().await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(request) .await - { - Ok((from_client, from_server)) => { - debug!("proxy socket terminated: {from_server} sent, {from_client} received") + .map_err(ServiceError::Hyper)? + }; + + if resp.status() == StatusCode::SWITCHING_PROTOCOLS { + let on_upgrade_upstream = resp + .extensions_mut() + .remove::() + .ok_or(ServiceError::UpgradeFailed)?; + let on_upgrade_downstream = + on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; + tokio::task::spawn(async move { + debug!("about to upgrade connection"); + match (on_upgrade_upstream.await, on_upgrade_downstream.await) { + (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { + debug!("upgrade successful"); + match tokio::io::copy_bidirectional( + &mut TokioIo(upgraded_downstream), + &mut TokioIo(upgraded_upstream), + ) + .await + { + Ok((from_client, from_server)) => { + debug!("proxy socket terminated: {from_server} sent, {from_client} received") + } + Err(e) => warn!("proxy socket error: {e}"), + } } - Err(e) => warn!("proxy socket error: {e}"), + (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), } - } - (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), + }); } - }); + + let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed()); + Ok(resp) + }) } - Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) } diff --git a/src/helper.rs b/src/helper.rs index 5914be3..0daa3d5 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,5 +1,4 @@ // From https://github.com/hyperium/hyper/blob/master/benches/support/tokiort.rs - use pin_project::pin_project; use std::{ pin::Pin, diff --git a/src/main.rs b/src/main.rs index 0109a62..7c74f70 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,14 +10,10 @@ pub mod helper; #[cfg(feature = "mond")] pub mod reporting; -use crate::{ - config::{Config, RouteFilter}, - filters::{files::serve_files, proxy::proxy_request}, -}; use anyhow::{anyhow, Context, Result}; -use bytes::Bytes; -use config::setup_file_watch; +use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; +use filters::{NodeContext, MODULES}; use futures::future::try_join_all; use helper::TokioIo; use http_body_util::{combinators::BoxBody, BodyExt}; @@ -30,14 +26,11 @@ use hyper::{ Request, Response, StatusCode, }; use log::{debug, error, info, warn, LevelFilter}; -#[cfg(feature = "mond")] -use reporting::Reporting; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use std::{ collections::HashMap, io::BufReader, net::SocketAddr, - ops::ControlFlow, path::{Path, PathBuf}, process::exit, str::FromStr, @@ -57,14 +50,7 @@ pub struct State { pub access_logs: RwLock>>, pub l_incoming: Semaphore, pub l_outgoing: Semaphore, - #[cfg(feature = "mond")] - pub reporting: Reporting, } -pub struct HostState {} - -pub type FilterRequest = Request; -pub type FilterResponseOut = Option>>; -pub type FilterResponse = Option>>; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -73,6 +59,11 @@ async fn main() -> anyhow::Result<()> { .parse_env("LOG") .init(); + NODE_KINDS + .write() + .unwrap() + .extend(MODULES.iter().map(|m| (m.name().to_owned(), *m))); + let Some(config_path) = std::env::args().skip(1).next() else { eprintln!("error: first argument is expected to be the configuration file"); exit(1) @@ -249,70 +240,17 @@ fn load_private_key(path: &Path) -> anyhow::Result> { async fn service( state: Arc, config: Arc, - req: Request, + request: Request, addr: SocketAddr, ) -> Result>, ServiceError> { - debug!("{addr} ~> {:?} {}", req.headers().get(HOST), req.uri()); - #[cfg(feature = "mond")] - state.reporting.request_in.inc(); - - let host = req - .headers() - .get(HOST) - .and_then(|e| e.to_str().ok()) - .map(String::from) - .unwrap_or(String::from("")); - let host = remove_port(&host); - let route = config.hosts.get(host).ok_or(ServiceError::NoHost)?; - #[cfg(feature = "mond")] - state.reporting.hosts.get(host).unwrap().requests_in.inc(); - - // TODO this code is horrible - let mut req = Some(req); - let mut resp = None; - for filter in &route.0 { - let cf = match filter { - RouteFilter::Proxy { backend } => { - resp = Some( - proxy_request( - &state, - req.take().ok_or(ServiceError::RequestTaken)?, - addr, - backend, - ) - .await?, - ); - ControlFlow::Continue(()) - } - RouteFilter::Files { config } => { - resp = Some( - serve_files(req.as_ref().ok_or(ServiceError::RequestTaken)?, config).await?, - ); - ControlFlow::Continue(()) - } - RouteFilter::HttpBasicAuth { config } => filters::auth::http_basic( - config, - req.as_ref().ok_or(ServiceError::RequestTaken)?, - &mut resp, - )?, - RouteFilter::AccessLog { config } => { - filters::accesslog::access_log( - &state, - host, - addr, - config, - req.as_ref().ok_or(ServiceError::RequestTaken)?, - ) - .await? - } - }; - match cf { - ControlFlow::Continue(_) => continue, - ControlFlow::Break(_) => break, - } - } + debug!( + "{addr} ~> {:?} {}", + request.headers().get(HOST), + request.uri() + ); - let mut resp = resp.ok_or(ServiceError::NoResponse)?; + let mut context = NodeContext { addr, state }; + let mut resp = config.handler.handle(&mut context, request).await?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( @@ -327,7 +265,3 @@ async fn service( return Ok(resp); } - -pub fn remove_port(s: &str) -> &str { - s.split_once(":").map(|(s, _)| s).unwrap_or(s) -} diff --git a/src/reporting.rs b/src/reporting.rs deleted file mode 100644 index ee30ac5..0000000 --- a/src/reporting.rs +++ /dev/null @@ -1,39 +0,0 @@ -use crate::config::Config; -use mond_client::{make_ident, Aspect, Push, Rate, Reporter}; -use std::{collections::HashMap, marker::PhantomData}; - -pub struct Reporting { - pub request_in: Aspect>, - pub request_out: Aspect>, - pub hosts: HashMap, - // pub connections: Aspect>, -} -pub struct HostReporting { - pub requests_in: Aspect>, -} - -impl Reporting { - pub fn new(config: &Config) -> Self { - let mut rep = Reporter::new(); - Self { - request_in: rep.create(make_ident!("requests-in"), Push(Rate(PhantomData::))), - request_out: rep.create(make_ident!("requests-out"), Push(Rate(PhantomData::))), - // connections: rep.create(make_ident!("connections"), Push()), - hosts: config - .hosts - .iter() - .map(|(k, _v)| { - ( - k.to_owned(), - HostReporting { - requests_in: rep.create( - make_ident!("host", k, "request-in"), - Push(Rate(PhantomData::)), - ), - }, - ) - }) - .collect(), - } - } -} -- cgit v1.2.3-70-g09d2