author     metamuffin <metamuffin@disroot.org>  2024-05-29 16:37:44 +0200
committer  metamuffin <metamuffin@disroot.org>  2024-05-29 16:37:44 +0200
commit     886a18e0c67624d0882f04c7f6659bcfee6b4d8d (patch)
tree       32a5389076b199c4e06fa10ce6b54d165d5466c5
parent     6cebab912dcf01bbe225c20ec2e7656f61ba160e (diff)
download   gnix-886a18e0c67624d0882f04c7f6659bcfee6b4d8d.tar
           gnix-886a18e0c67624d0882f04c7f6659bcfee6b4d8d.tar.bz2
           gnix-886a18e0c67624d0882f04c7f6659bcfee6b4d8d.tar.zst
refactor filter system
-rw-r--r--  Cargo.lock               |  44
-rw-r--r--  Cargo.toml               |   7
-rw-r--r--  readme.md                |  62
-rw-r--r--  src/config.rs            | 126
-rw-r--r--  src/error.rs             |   6
-rw-r--r--  src/filters/accesslog.rs | 106
-rw-r--r--  src/filters/auth.rs      |  86
-rw-r--r--  src/filters/error.rs     |  32
-rw-r--r--  src/filters/files.rs     | 285
-rw-r--r--  src/filters/hosts.rs     |  45
-rw-r--r--  src/filters/mod.rs       |  47
-rw-r--r--  src/filters/proxy.rs     | 157
-rw-r--r--  src/helper.rs            |   1
-rw-r--r--  src/main.rs              |  96
-rw-r--r--  src/reporting.rs         |  39
15 files changed, 612 insertions, 527 deletions
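
The visible effect of the refactor on configuration: the flat `hosts` map is replaced by a single `handler` module tree, and modules compose by pointing at a `next` module. The sketch below follows the syntax of the updated readme in this diff; every hostname, path and credential in it is a placeholder, not something this commit ships.

```yaml
http:
  bind: "[::1]:8080"

# The whole request pipeline is one module tree rooted at `handler`.
handler: !hosts
  # plain reverse proxy for one hostname
  "example.org": !proxy { backend: "127.0.0.1:8000" }

  # log every request, then require a login, then serve static files
  "files.example.org": !access_log
    file: "/var/log/gnix/files.log"
    flush: true
    next: !http_basic_auth
      realm: "files"
      valid: ["muffin:hunter2"]
      next: !files
        root: "/srv/http"
        index: true

  # reject anything else addressed to this name with a custom message
  "old.example.org": !error "this host has moved"
```
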
@@ -141,25 +141,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" [[package]] -name = "bincode" -version = "2.0.0-rc.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f11ea1a0346b94ef188834a65c068a03aec181c94896d481d7a0a40d85b0ce95" -dependencies = [ - "bincode_derive", - "serde", -] - -[[package]] -name = "bincode_derive" -version = "2.0.0-rc.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e30759b3b99a1b802a7a3aa21c85c3ded5c28e1c83170d82d70f08bbf7f3e4c" -dependencies = [ - "virtue", -] - -[[package]] name = "bindgen" version = "0.69.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -495,7 +476,6 @@ dependencies = [ "log", "markup", "mime_guess", - "mond-client", "percent-encoding", "pin-project", "rustls", @@ -861,24 +841,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" [[package]] -name = "mond-client" -version = "0.1.0" -source = "git+https://codeberg.org/metamuffin/mond#6019696ebc243f1b743823f076af0201ab0d27d5" -dependencies = [ - "log", - "mond-protocol", -] - -[[package]] -name = "mond-protocol" -version = "0.1.0" -source = "git+https://codeberg.org/metamuffin/mond#6019696ebc243f1b743823f076af0201ab0d27d5" -dependencies = [ - "bincode", - "serde", -] - -[[package]] name = "nom" version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1383,12 +1345,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] -name = "virtue" -version = "0.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dcc60c0624df774c82a0ef104151231d37da4962957d691c011c852b2473314" - -[[package]] name = "want" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4,7 +4,6 @@ version = "1.0.0" edition = "2021" [dependencies] - # HTTP hyper = { version = "1.3.1", features = ["full"] } hyper-util = "0.1.3" @@ -44,9 +43,3 @@ mime_guess = "2.0.4" bytes = "1.6.0" anyhow = "1.0.82" thiserror = "1.0.59" - -mond-client = { git = "https://codeberg.org/metamuffin/mond", optional = true } - -[features] -default = [] -mond = ["dep:mond-client"] @@ -6,8 +6,9 @@ a simple stupid reverse proxy - Simple to configure (see below) - Handles connection upgrades correctly by default (websocket, etc.) +- Composable modules - TLS support -- _TODO: h2; match on uris; connection pools_ +- _TODO: h2; match on uris; connection pooling_ ## Quick Start @@ -18,31 +19,36 @@ configuration file is written in YAML and could look like this: # Both the 'http' and 'https' sections are optional http: # the value for 'bind' can either be a string or a list of strings - bind: [ "127.0.0.1:8080", "[::1]:8080" ] + bind: "[::1]:8080" https: - bind: "127.0.0.1:8443" + bind: "[::1]:8443" tls_cert: "ssl/cert.pem" - tls_key: "ssl/key.pem" # only accepts pkcs8 for now + tls_key: "ssl/key.pem" # only accepts pkcs8 -# this is a lookup table from hostnames to a list of filters -# in this case, requests for `testdomain.local` are forwarded to 127.0.0.1:3000 -hosts: - "testdomain.local": !proxy { backend: "127.0.0.1:8000" } - "192.168.178.39": !proxy { backend: "127.0.0.1:8000" } - "localhost": !files - root: "/home/muffin/videos" +# !hosts multiplexes requests for different hostnames. 
+handler: !hosts + # requests for `example.org` are forwarded to 127.0.0.1:8000 + "example.org": !proxy { backend: "127.0.0.1:8000" } + # requests for `mydomain.com` will access files from /srv/http + "mydomain.com": !files + root: "/srv/http" index: true + + "panel.mydomain.com": !access_log + ``` ## Reference - **section `http`** - `bind`: string or list of strings with addresses to listen on. + - **section `https`** - `bind`: string or list of strings with addresses to listen on. - `tls_cert`: path to the SSL certificate. (Sometimes called `fullchain.pem`) - `tls_key`: path to the SSL key. (Often called `key.pem` or `privkey.pem`) + - **section `limits`** - Note: Make sure you do not exceed the maximum file descriptor limit on your platform. @@ -50,21 +56,28 @@ hosts: connections. excess connections are rejected. Default: 512 - `max_outgoing_connections` number of maximum outgoing (upstream) connections. excess connections are rejected. Default: 256 -- **section `hosts`** - - A map from hostname (a string) to a _filter_ or a list of _filters_ + +- **section `handler`** + - A module to handle all requests. Usually an instance of `hosts`. + - `watch_config`: boolean if to watch the configuration file for changes and apply them accordingly. Default: true (Note: This will watch the entire parent directory of the config since most editors first move the file. Currently any change will trigger a reload. TODO) -### Filters +### Modules -- **filter `proxy`** +- **module `hosts`** + - Hands over the requests to different modules depending on the `host` header. + - Takes a map from hostname (string) to handler (module) + +- **module `proxy`** - Forwards the request as-is to some other server. the `x-real-ip` header is injected into the request. Connection upgrades are handled by direct forwarding of network traffic. - `backend`: socket address (string) to the backend server -- **filter `files`** + +- **module `files`** - Provides a simple built-in fileserver. The server handles `accept-ranges`. The `content-type` header is inferred from the file extension and falls back to `application/octet-stream`. If a directory is requested `index.html` will @@ -72,12 +85,25 @@ hosts: prepended to the response. - `root`: root directory to be served (string) - `index`: enables directory indexing (boolean) -- **filter `http_basic_auth`** + +- **module `http_basic_auth`** - Filters requests via HTTP Basic Authentification. Unauthorized clients will be challenged on every request. - - `realm`: string that does essentially nothing + - `realm`: describes what the user is logging into (most modern browsers dont show this anymore -_-) - `valid`: list of valid logins (string) in the format `<username>:<password>` (password in plain text). TODO: hashing + - `next`: a module to handle this request on successfully authentificated. (module) + +- **module `access_log`** + - Logs requests to a file. + - `file`: file path to log (string) + - `reject_on_fail`: rejects requests if log could not be written (boolean) + - `flush`: flushes log on every request (boolean) + - `next`: module for further handling of the request (module) + +- **module `error`** + - Rejects every request with a custom error message. 
+ - Takes an error message (string) ## License diff --git a/src/config.rs b/src/config.rs index e661996..dfc4e73 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,34 +1,38 @@ -use crate::State; -use anyhow::Context; +use crate::{ + filters::{Node, NodeKind}, + State, +}; +use anyhow::{anyhow, Context}; use inotify::{EventMask, Inotify, WatchMask}; use log::{error, info}; use serde::{ de::{value, Error, SeqAccess, Visitor}, Deserialize, Deserializer, Serialize, }; +use serde_yaml::value::TaggedValue; use std::{ - collections::{HashMap, HashSet}, + collections::BTreeMap, fmt, fs::read_to_string, marker::PhantomData, net::SocketAddr, + ops::Deref, path::{Path, PathBuf}, - sync::Arc, + sync::{Arc, RwLock}, }; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Deserialize)] pub struct Config { - #[serde(default = "true_default")] + #[serde(default = "return_true")] pub watch_config: bool, pub http: Option<HttpConfig>, pub https: Option<HttpsConfig>, #[serde(default)] pub limits: Limits, - #[serde(default)] - pub hosts: HashMap<String, Route>, + pub handler: DynNode, } -fn true_default() -> bool { +pub fn return_true() -> bool { true } @@ -53,72 +57,8 @@ pub struct HttpsConfig { pub tls_key: PathBuf, } -#[derive(Debug, Serialize, Deserialize)] -pub struct Route(#[serde(deserialize_with = "seq_or_not")] pub Vec<RouteFilter>); - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum RouteFilter { - HttpBasicAuth { - #[serde(flatten)] - config: HttpBasicAuthConfig, - }, - Proxy { - backend: SocketAddr, - }, - Files { - #[serde(flatten)] - config: FileserverConfig, - }, - AccessLog { - #[serde(flatten)] - config: AccessLogConfig, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct AccessLogConfig { - pub file: PathBuf, - #[serde(default)] - pub flush: bool, - #[serde(default)] - pub reject_on_fail: bool, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct HttpBasicAuthConfig { - pub realm: String, - pub valid: HashSet<String>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct FileserverConfig { - pub root: PathBuf, - #[serde(default)] - pub index: bool, - #[serde(default = "return_true")] - pub last_modified: bool, - #[serde(default = "return_true")] - pub etag: bool, - #[serde(default)] - pub cache: CacheConfig, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CacheConfig { - #[default] - Public, - Private, - NoStore, -} - -fn return_true() -> bool { - true -} - // try deser Vec<T> but fall back to deser T and putting that in Vec -fn seq_or_not<'de, D, V: Deserialize<'de>>(des: D) -> Result<Vec<V>, D::Error> +pub fn seq_or_not<'de, D, V: Deserialize<'de>>(des: D) -> Result<Vec<V>, D::Error> where D: Deserializer<'de>, { @@ -194,11 +134,45 @@ where des.deserialize_any(StringOrList) } +pub static NODE_KINDS: RwLock<BTreeMap<String, &'static dyn NodeKind>> = + RwLock::new(BTreeMap::new()); + +#[derive(Clone)] +pub struct DynNode(Arc<dyn Node>); + +impl<'de> Deserialize<'de> for DynNode { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + let tv = TaggedValue::deserialize(deserializer)?; + let s = tv.tag.to_string(); + let s = s.strip_prefix("!").unwrap_or(s.as_str()); + let inst = NODE_KINDS + .read() + .unwrap() + .get(s) + .ok_or(serde::de::Error::unknown_variant(s, &[]))? 
+ .instanciate(tv.value) + .map_err(|e| { + serde::de::Error::custom(e.context(anyhow!("instanciating modules {s:?}"))) + })?; + + Ok(Self(inst)) + } +} +impl Deref for DynNode { + type Target = dyn Node; + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + impl Config { pub fn load(path: &Path) -> anyhow::Result<Config> { info!("loading config from {path:?}"); let raw = read_to_string(path).context("reading config file")?; - let config: Config = serde_yaml::from_str(&raw).context("parsing config")?; + let config: Config = serde_yaml::from_str(&raw).context("during parsing")?; Ok(config) } } @@ -235,7 +209,7 @@ pub fn setup_file_watch(config_path: PathBuf, state: Arc<State>) { let mut r = state.config.blocking_write(); *r = Arc::new(conf) } - Err(e) => error!("config has errors: {e}"), + Err(e) => error!("config has errors: {e:?}"), } } } diff --git a/src/error.rs b/src/error.rs index 1675b3c..e7e5af2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,8 +11,10 @@ pub enum ServiceError { Limit(#[from] TryAcquireError), #[error("hyper error")] Hyper(hyper::Error), - #[error("unknown host")] + #[error("no host")] NoHost, + #[error("unknown host")] + UnknownHost, #[error("can't connect to the backend")] CantConnect, #[error("not found")] @@ -33,6 +35,8 @@ pub enum ServiceError { BadBase64(#[from] base64::DecodeError), #[error("connection upgrade failed")] UpgradeFailed, + #[error("{0}")] + Custom(String), #[error("impossible error")] Other, } diff --git a/src/filters/accesslog.rs b/src/filters/accesslog.rs index 9a33762..1da6e5d 100644 --- a/src/filters/accesslog.rs +++ b/src/filters/accesslog.rs @@ -1,49 +1,81 @@ -use crate::{config::AccessLogConfig, error::ServiceError, FilterRequest, State}; -use futures::executor::block_on; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use futures::Future; use log::error; -use std::{net::SocketAddr, ops::ControlFlow, time::SystemTime}; +use serde::Deserialize; +use std::{path::PathBuf, pin::Pin, sync::Arc, time::SystemTime}; use tokio::{ - fs::OpenOptions, + fs::{File, OpenOptions}, io::{AsyncWriteExt, BufWriter}, + sync::RwLock, }; -pub async fn access_log( - state: &State, - host: &str, - addr: SocketAddr, - config: &AccessLogConfig, - req: &FilterRequest, -) -> Result<ControlFlow<()>, ServiceError> { - let mut g = state.access_logs.write().await; +pub struct AccessLogKind; - let log = g.entry(host.to_owned()).or_insert_with(|| { - BufWriter::new( - // TODO aaahh dont block the runtime and dont panic in any case.... - block_on( - OpenOptions::new() - .append(true) - .create(true) - .open(&config.file), - ) - .unwrap(), - ) - }); +#[derive(Deserialize)] +struct AccessLogConfig { + file: PathBuf, + #[serde(default)] + flush: bool, + #[serde(default)] + reject_on_fail: bool, + next: DynNode, +} - let method = req.method().as_str(); - let time = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros(); - let mut res = log - .write_all(format!("{time}\t{addr}\t{method}\t{:?}\n", req.uri()).as_bytes()) - .await; +struct AccessLog { + config: AccessLogConfig, + file: RwLock<Option<BufWriter<File>>>, +} - if config.flush && res.is_ok() { - res = log.flush().await; +impl NodeKind for AccessLogKind { + fn name(&self) -> &'static str { + "access_log" } - - if config.reject_on_fail { - res? 
- } else if let Err(e) = res { - error!("failed to write log: {e:?}") + fn instanciate(&self, config: serde_yaml::Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(AccessLog { + config: serde_yaml::from_value::<AccessLogConfig>(config)?, + file: Default::default(), + })) } +} + +impl Node for AccessLog { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let mut g = self.file.write().await; + let log = match g.as_mut() { + Some(r) => r, + None => g.insert(BufWriter::new( + OpenOptions::new() + .append(true) + .create(true) + .open(&self.config.file) + .await?, + )), + }; - Ok(ControlFlow::Continue(())) + let method = request.method().as_str(); + let time = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros(); + let addr = context.addr; + let mut res = log + .write_all(format!("{time}\t{addr}\t{method}\t{:?}\n", request.uri()).as_bytes()) + .await; + + if self.config.flush && res.is_ok() { + res = log.flush().await; + } + + if self.config.reject_on_fail { + res? + } else if let Err(e) = res { + error!("failed to write log: {e:?}") + } + + self.config.next.handle(context, request).await + }) + } } diff --git a/src/filters/auth.rs b/src/filters/auth.rs index 92a9ba3..7d5b03e 100644 --- a/src/filters/auth.rs +++ b/src/filters/auth.rs @@ -1,41 +1,65 @@ -use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; +use super::{Node, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; use base64::Engine; +use futures::Future; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, Response, StatusCode, }; use log::debug; -use std::ops::ControlFlow; +use serde::Deserialize; +use serde_yaml::Value; +use std::{collections::HashSet, pin::Pin, sync::Arc}; -pub fn http_basic( - config: &HttpBasicAuthConfig, - req: &FilterRequest, - resp: &mut FilterResponseOut, -) -> Result<ControlFlow<()>, ServiceError> { - if let Some(auth) = req.headers().get(AUTHORIZATION) { - let k = auth - .as_bytes() - .strip_prefix(b"Basic ") - .ok_or(ServiceError::BadAuth)?; - let k = base64::engine::general_purpose::STANDARD.decode(k)?; - let k = String::from_utf8(k)?; - if config.valid.contains(&k) { - debug!("valid auth"); - return Ok(ControlFlow::Continue(())); - } else { - debug!("invalid auth"); - } +pub struct HttpBasicAuthKind; +impl NodeKind for HttpBasicAuthKind { + fn name(&self) -> &'static str { + "http_basic_auth" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn super::Node>> { + Ok(Arc::new(serde_yaml::from_value::<HttpBasicAuth>(config)?)) + } +} + +#[derive(Deserialize)] +pub struct HttpBasicAuth { + realm: String, + valid: HashSet<String>, + next: DynNode, +} + +impl Node for HttpBasicAuth { + fn handle<'a>( + &'a self, + context: &'a mut super::NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + if let Some(auth) = request.headers().get(AUTHORIZATION) { + let k = auth + .as_bytes() + .strip_prefix(b"Basic ") + .ok_or(ServiceError::BadAuth)?; + let k = base64::engine::general_purpose::STANDARD.decode(k)?; + let k = String::from_utf8(k)?; + if self.valid.contains(&k) { + debug!("valid auth"); + return self.next.handle(context, request).await; + } else { + debug!("invalid auth"); + } + } + 
debug!("unauthorized; sending auth challenge"); + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + String::new().map_err(|_| unreachable!()), + )); + *r.status_mut() = StatusCode::UNAUTHORIZED; + r.headers_mut().insert( + WWW_AUTHENTICATE, + HeaderValue::from_str(&format!("Basic realm=\"{}\"", self.realm)).unwrap(), + ); + Ok(r) + }) } - debug!("unauthorized; sending auth challenge"); - let mut r = Response::new(BoxBody::<_, ServiceError>::new( - String::new().map_err(|_| unreachable!()), - )); - *r.status_mut() = StatusCode::UNAUTHORIZED; - r.headers_mut().insert( - WWW_AUTHENTICATE, - HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), - ); - *resp = Some(r); - Ok(ControlFlow::Break(())) } diff --git a/src/filters/error.rs b/src/filters/error.rs new file mode 100644 index 0000000..504802f --- /dev/null +++ b/src/filters/error.rs @@ -0,0 +1,32 @@ +use crate::error::ServiceError; + +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use futures::Future; +use serde::Deserialize; +use serde_yaml::Value; +use std::{pin::Pin, sync::Arc}; + +pub struct ErrorKind; + +#[derive(Deserialize)] +#[serde(transparent)] +struct Error(String); + +impl NodeKind for ErrorKind { + fn name(&self) -> &'static str { + "error" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Error>(config)?)) + } +} + +impl Node for Error { + fn handle<'a>( + &'a self, + _context: &'a mut NodeContext, + _request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { Err(ServiceError::Custom(self.0.clone())) }) + } +} diff --git a/src/filters/files.rs b/src/filters/files.rs index ee40d70..fc3d63b 100644 --- a/src/filters/files.rs +++ b/src/filters/files.rs @@ -1,8 +1,7 @@ -use crate::{ - config::{CacheConfig, FileserverConfig}, - ServiceError, -}; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::return_true, ServiceError}; use bytes::{Bytes, BytesMut}; +use futures::Future; use futures_util::{future, future::Either, ready, stream, FutureExt, Stream, StreamExt}; use headers::{ AcceptRanges, CacheControl, ContentLength, ContentRange, ContentType, HeaderMapExt, @@ -11,154 +10,204 @@ use headers::{ use http_body_util::{combinators::BoxBody, BodyExt, StreamBody}; use humansize::FormatSizeOptions; use hyper::{ - body::{Frame, Incoming}, + body::Frame, header::{CONTENT_TYPE, LOCATION}, http::HeaderValue, - Request, Response, StatusCode, + Response, StatusCode, }; use log::debug; use markup::Render; use percent_encoding::percent_decode_str; -use std::{fs::Metadata, io, ops::Range, path::Path, pin::Pin, task::Poll}; +use serde::Deserialize; +use serde_yaml::Value; +use std::{ + fs::Metadata, + io, + ops::Range, + path::{Path, PathBuf}, + pin::Pin, + sync::Arc, + task::Poll, +}; use tokio::{ fs::{read_to_string, File}, io::AsyncSeekExt, }; use tokio_util::io::poll_read_buf; -pub async fn serve_files( - req: &Request<Incoming>, - config: &FileserverConfig, -) -> Result<hyper::Response<BoxBody<Bytes, ServiceError>>, ServiceError> { - let rpath = req.uri().path(); +pub struct FilesKind; - let mut path = config.root.clone(); - let mut user_path_depth = 0; - for seg in rpath.split("/") { - let seg = percent_decode_str(seg).decode_utf8()?; +#[derive(Debug, Deserialize)] +struct Files { + root: PathBuf, + #[serde(default)] + index: bool, + #[serde(default = "return_true")] + last_modified: bool, + // 
#[serde(default = "return_true")] + // etag: bool, + #[serde(default)] + cache: CacheMode, +} - if seg == "" || seg == "." { - continue; - } +#[derive(Debug, Default, Deserialize)] +#[serde(rename_all = "snake_case")] +enum CacheMode { + #[default] + Public, + Private, + NoStore, +} - if seg == ".." { - if user_path_depth <= 0 { - return Err(ServiceError::BadPath); - } - path.pop(); - user_path_depth -= 1; - } else { - path.push(seg.as_ref()); - user_path_depth += 1; - } +impl NodeKind for FilesKind { + fn name(&self) -> &'static str { + "files" } - if !path.exists() { - return Err(ServiceError::NotFound); + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Files>(config)?)) } +} - let metadata = path.metadata()?; +impl Node for Files { + fn handle<'a>( + &'a self, + _context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let rpath = request.uri().path(); - if metadata.file_type().is_dir() { - debug!("sending index for {path:?}"); - if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { - return Ok(html_string_response(indexhtml)); - } + let mut path = self.root.clone(); + let mut user_path_depth = 0; + for seg in rpath.split("/") { + let seg = percent_decode_str(seg).decode_utf8()?; + + if seg == "" || seg == "." { + continue; + } - if config.index { - if !rpath.ends_with("/") { - let mut r = Response::new(String::new()); - *r.status_mut() = StatusCode::FOUND; - r.headers_mut().insert( - LOCATION, - HeaderValue::from_str(&format!("{}/", rpath)) - .map_err(|_| ServiceError::Other)?, - ); - return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); + if seg == ".." { + if user_path_depth <= 0 { + return Err(ServiceError::BadPath); + } + path.pop(); + user_path_depth -= 1; + } else { + path.push(seg.as_ref()); + user_path_depth += 1; + } + } + if !path.exists() { + return Err(ServiceError::NotFound); } - return index(&path, rpath.to_string()) - .await - .map(html_string_response); - } else { - return Err(ServiceError::NotFound); - } - } + let metadata = path.metadata()?; - let modified = metadata.modified()?; + if metadata.file_type().is_dir() { + debug!("sending index for {path:?}"); + if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { + return Ok(html_string_response(indexhtml)); + } - let not_modified = if config.last_modified { - req.headers() - .typed_get::<headers::IfModifiedSince>() - .map(|if_modified_since| { - Ok::<_, ServiceError>(!if_modified_since.is_modified(modified)) - }) - .transpose()? 
- .unwrap_or_default() - } else { - false - }; + if self.index { + if !rpath.ends_with("/") { + let mut r = Response::new(String::new()); + *r.status_mut() = StatusCode::FOUND; + r.headers_mut().insert( + LOCATION, + HeaderValue::from_str(&format!("{}/", rpath)) + .map_err(|_| ServiceError::Other)?, + ); + return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); + } + + return index(&path, rpath.to_string()) + .await + .map(html_string_response); + } else { + return Err(ServiceError::NotFound); + } + } - // let etag = ETag::from_str(&calc_etag(modified)).map_err(|_| ServiceError::Other)?; - // let etag_matches = if config.etag { - // req.headers() - // .typed_get::<headers::IfNoneMatch>() - // .map(|if_none_match| if_none_match.precondition_passes(&etag)) - // .unwrap_or_default() - // } else { - // false - // }; + let modified = metadata.modified()?; - let range = req.headers().typed_get::<headers::Range>(); - let range = bytes_range(range, metadata.len())?; + let not_modified = if self.last_modified { + request + .headers() + .typed_get::<headers::IfModifiedSince>() + .map(|if_modified_since| { + Ok::<_, ServiceError>(!if_modified_since.is_modified(modified)) + }) + .transpose()? + .unwrap_or_default() + } else { + false + }; - debug!("sending file {path:?}"); - let file = File::open(path.clone()).await?; + // let etag = ETag::from_str(&calc_etag(modified)).map_err(|_| ServiceError::Other)?; + // let etag_matches = if self.etag { + // request.headers() + // .typed_get::<headers::IfNoneMatch>() + // .map(|if_none_match| if_none_match.precondition_passes(&etag)) + // .unwrap_or_default() + // } else { + // false + // }; - // let skip_body = not_modified || etag_matches; - let skip_body = not_modified; - let mut r = if skip_body { - Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()) - } else { - Response::new(BoxBody::new(StreamBody::new( - StreamBody::new(file_stream(file, 4096, range.clone())) - .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), - ))) - }; + let range = request.headers().typed_get::<headers::Range>(); + let range = bytes_range(range, metadata.len())?; - if !skip_body { - if range.end - range.start != metadata.len() { - *r.status_mut() = StatusCode::PARTIAL_CONTENT; - r.headers_mut().typed_insert( - ContentRange::bytes(range.clone(), metadata.len()).expect("valid ContentRange"), - ); - } - } - // if not_modified || etag_matches { - if not_modified { - *r.status_mut() = StatusCode::NOT_MODIFIED; - } + debug!("sending file {path:?}"); + let file = File::open(path.clone()).await?; - r.headers_mut().typed_insert(AcceptRanges::bytes()); - r.headers_mut() - .typed_insert(ContentLength(range.end - range.start)); + // let skip_body = not_modified || etag_matches; + let skip_body = not_modified; + let mut r = if skip_body { + Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()) + } else { + Response::new(BoxBody::new(StreamBody::new( + StreamBody::new(file_stream(file, 4096, range.clone())) + .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), + ))) + }; - let mime = mime_guess::from_path(path).first_or_octet_stream(); - r.headers_mut().typed_insert(ContentType::from(mime)); + if !skip_body { + if range.end - range.start != metadata.len() { + *r.status_mut() = StatusCode::PARTIAL_CONTENT; + r.headers_mut().typed_insert( + ContentRange::bytes(range.clone(), metadata.len()) + .expect("valid ContentRange"), + ); + } + } + // if not_modified || etag_matches { + if not_modified { + *r.status_mut() = 
StatusCode::NOT_MODIFIED; + } - r.headers_mut().typed_insert(match config.cache { - CacheConfig::Public => CacheControl::new().with_public(), - CacheConfig::Private => CacheControl::new().with_private(), - CacheConfig::NoStore => CacheControl::new().with_no_store(), - }); + r.headers_mut().typed_insert(AcceptRanges::bytes()); + r.headers_mut() + .typed_insert(ContentLength(range.end - range.start)); - // if config.etag { - // r.headers_mut().typed_insert(etag); - // } - if config.last_modified { - r.headers_mut().typed_insert(LastModified::from(modified)); - } + let mime = mime_guess::from_path(path).first_or_octet_stream(); + r.headers_mut().typed_insert(ContentType::from(mime)); - Ok(r) + r.headers_mut().typed_insert(match self.cache { + CacheMode::Public => CacheControl::new().with_public(), + CacheMode::Private => CacheControl::new().with_private(), + CacheMode::NoStore => CacheControl::new().with_no_store(), + }); + + // if self.etag { + // r.headers_mut().typed_insert(etag); + // } + if self.last_modified { + r.headers_mut().typed_insert(LastModified::from(modified)); + } + + Ok(r) + }) + } } // Adapted from warp (https://github.com/seanmonstar/warp/blob/master/src/filters/fs.rs). Thanks! diff --git a/src/filters/hosts.rs b/src/filters/hosts.rs new file mode 100644 index 0000000..286d478 --- /dev/null +++ b/src/filters/hosts.rs @@ -0,0 +1,45 @@ +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use futures::Future; +use hyper::header::HOST; +use serde::Deserialize; +use serde_yaml::Value; +use std::{collections::HashMap, pin::Pin, sync::Arc}; + +#[derive(Deserialize)] +#[serde(transparent)] +struct Hosts(HashMap<String, DynNode>); + +pub struct HostsKind; +impl NodeKind for HostsKind { + fn name(&self) -> &'static str { + "hosts" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Hosts>(config)?)) + } +} +impl Node for Hosts { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let host = request + .headers() + .get(HOST) + .and_then(|e| e.to_str().ok()) + .ok_or(ServiceError::NoHost)?; + + let host = remove_port(&host); + let node = self.0.get(host).ok_or(ServiceError::UnknownHost)?; + + node.handle(context, request).await + }) + } +} + +pub fn remove_port(s: &str) -> &str { + s.split_once(":").map(|(s, _)| s).unwrap_or(s) +} diff --git a/src/filters/mod.rs b/src/filters/mod.rs index fdeed51..10520a3 100644 --- a/src/filters/mod.rs +++ b/src/filters/mod.rs @@ -1,5 +1,50 @@ +use crate::error::ServiceError; +use crate::State; +use accesslog::AccessLogKind; +use auth::HttpBasicAuthKind; +use bytes::Bytes; +use error::ErrorKind; +use files::FilesKind; +use futures::Future; +use hosts::HostsKind; +use http_body_util::combinators::BoxBody; +use hyper::{body::Incoming, Request, Response}; +use proxy::ProxyKind; +use serde_yaml::Value; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; +pub mod accesslog; pub mod auth; +pub mod error; pub mod files; +pub mod hosts; pub mod proxy; -pub mod accesslog;
\ No newline at end of file + +pub type NodeRequest = Request<Incoming>; +pub type NodeResponse = Response<BoxBody<Bytes, ServiceError>>; + +pub static MODULES: &'static [&'static dyn NodeKind] = &[ + &HttpBasicAuthKind, + &ProxyKind, + &HostsKind, + &FilesKind, + &AccessLogKind, + &ErrorKind, +]; + +pub struct NodeContext { + pub state: Arc<State>, + pub addr: SocketAddr, +} + +pub trait NodeKind: Send + Sync + 'static { + fn name(&self) -> &'static str; + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>>; +} +pub trait Node: Send + Sync + 'static { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>>; +} diff --git a/src/filters/proxy.rs b/src/filters/proxy.rs index ad959ad..ce72f65 100644 --- a/src/filters/proxy.rs +++ b/src/filters/proxy.rs @@ -1,87 +1,98 @@ -use crate::{helper::TokioIo, ServiceError, State}; -use http_body_util::{combinators::BoxBody, BodyExt}; -use hyper::{body::Incoming, http::HeaderValue, upgrade::OnUpgrade, Request, StatusCode}; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{helper::TokioIo, ServiceError}; +use futures::Future; +use http_body_util::BodyExt; +use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode}; use log::{debug, warn}; -use std::{net::SocketAddr, sync::Arc}; +use serde::Deserialize; +use serde_yaml::Value; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; use tokio::net::TcpStream; -pub async fn proxy_request( - state: &Arc<State>, - mut req: Request<Incoming>, - addr: SocketAddr, - backend: &SocketAddr, -) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> { - #[cfg(feature = "mond")] - state.reporting.request_out.inc(); +#[derive(Default)] +pub struct ProxyKind; - //? Do we know what this did? 
- // *req.uri_mut() = Uri::builder() - // .path_and_query( - // req.uri() - // .clone() - // .path_and_query() - // .cloned() - // .unwrap_or(PathAndQuery::from_static("/")), - // ) - // .build() - // .unwrap(); - - req.headers_mut().insert( - "x-real-ip", - HeaderValue::from_str(&format!("{}", addr.ip())).unwrap(), - ); +#[derive(Debug, Deserialize)] +struct Proxy { + backend: SocketAddr, +} - let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>(); +impl NodeKind for ProxyKind { + fn name(&self) -> &'static str { + "proxy" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Proxy>(config)?)) + } +} +impl Node for Proxy { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + mut request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + request.headers_mut().insert( + "x-real-ip", + HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), + ); - let _limit_guard = state.l_outgoing.try_acquire()?; - debug!("\tforwarding to {}", backend); - let mut resp = { - let client_stream = TokioIo( - TcpStream::connect(backend) - .await - .map_err(|_| ServiceError::CantConnect)?, - ); + let on_upgrade_downstream = request.extensions_mut().remove::<OnUpgrade>(); - let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) - .await - .map_err(ServiceError::Hyper)?; - tokio::task::spawn(async move { - if let Err(err) = conn.with_upgrades().await { - warn!("connection failed: {:?}", err); - } - }); - sender - .send_request(req) - .await - .map_err(ServiceError::Hyper)? - }; + let _limit_guard = context.state.l_outgoing.try_acquire()?; + debug!("\tforwarding to {}", self.backend); + let mut resp = { + let client_stream = TokioIo( + TcpStream::connect(self.backend) + .await + .map_err(|_| ServiceError::CantConnect)?, + ); - if resp.status() == StatusCode::SWITCHING_PROTOCOLS { - let on_upgrade_upstream = resp - .extensions_mut() - .remove::<OnUpgrade>() - .ok_or(ServiceError::UpgradeFailed)?; - let on_upgrade_downstream = on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; - tokio::task::spawn(async move { - debug!("about to upgrade connection"); - match (on_upgrade_upstream.await, on_upgrade_downstream.await) { - (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { - debug!("upgrade successful"); - match tokio::io::copy_bidirectional( - &mut TokioIo(upgraded_downstream), - &mut TokioIo(upgraded_upstream), - ) + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) + .await + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.with_upgrades().await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(request) .await - { - Ok((from_client, from_server)) => { - debug!("proxy socket terminated: {from_server} sent, {from_client} received") + .map_err(ServiceError::Hyper)? 
+ }; + + if resp.status() == StatusCode::SWITCHING_PROTOCOLS { + let on_upgrade_upstream = resp + .extensions_mut() + .remove::<OnUpgrade>() + .ok_or(ServiceError::UpgradeFailed)?; + let on_upgrade_downstream = + on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; + tokio::task::spawn(async move { + debug!("about to upgrade connection"); + match (on_upgrade_upstream.await, on_upgrade_downstream.await) { + (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { + debug!("upgrade successful"); + match tokio::io::copy_bidirectional( + &mut TokioIo(upgraded_downstream), + &mut TokioIo(upgraded_upstream), + ) + .await + { + Ok((from_client, from_server)) => { + debug!("proxy socket terminated: {from_server} sent, {from_client} received") + } + Err(e) => warn!("proxy socket error: {e}"), + } } - Err(e) => warn!("proxy socket error: {e}"), + (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), } - } - (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), + }); } - }); + + let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed()); + Ok(resp) + }) } - Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) } diff --git a/src/helper.rs b/src/helper.rs index 5914be3..0daa3d5 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,5 +1,4 @@ // From https://github.com/hyperium/hyper/blob/master/benches/support/tokiort.rs - use pin_project::pin_project; use std::{ pin::Pin, diff --git a/src/main.rs b/src/main.rs index 0109a62..7c74f70 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,14 +10,10 @@ pub mod helper; #[cfg(feature = "mond")] pub mod reporting; -use crate::{ - config::{Config, RouteFilter}, - filters::{files::serve_files, proxy::proxy_request}, -}; use anyhow::{anyhow, Context, Result}; -use bytes::Bytes; -use config::setup_file_watch; +use config::{setup_file_watch, Config, NODE_KINDS}; use error::ServiceError; +use filters::{NodeContext, MODULES}; use futures::future::try_join_all; use helper::TokioIo; use http_body_util::{combinators::BoxBody, BodyExt}; @@ -30,14 +26,11 @@ use hyper::{ Request, Response, StatusCode, }; use log::{debug, error, info, warn, LevelFilter}; -#[cfg(feature = "mond")] -use reporting::Reporting; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use std::{ collections::HashMap, io::BufReader, net::SocketAddr, - ops::ControlFlow, path::{Path, PathBuf}, process::exit, str::FromStr, @@ -57,14 +50,7 @@ pub struct State { pub access_logs: RwLock<HashMap<String, BufWriter<File>>>, pub l_incoming: Semaphore, pub l_outgoing: Semaphore, - #[cfg(feature = "mond")] - pub reporting: Reporting, } -pub struct HostState {} - -pub type FilterRequest = Request<Incoming>; -pub type FilterResponseOut = Option<Response<BoxBody<Bytes, ServiceError>>>; -pub type FilterResponse = Option<Response<BoxBody<Bytes, ServiceError>>>; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -73,6 +59,11 @@ async fn main() -> anyhow::Result<()> { .parse_env("LOG") .init(); + NODE_KINDS + .write() + .unwrap() + .extend(MODULES.iter().map(|m| (m.name().to_owned(), *m))); + let Some(config_path) = std::env::args().skip(1).next() else { eprintln!("error: first argument is expected to be the configuration file"); exit(1) @@ -249,70 +240,17 @@ fn load_private_key(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> { async fn service( state: Arc<State>, config: Arc<Config>, - req: Request<Incoming>, + request: Request<Incoming>, addr: SocketAddr, ) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> { - debug!("{addr} ~> {:?} {}", 
req.headers().get(HOST), req.uri()); - #[cfg(feature = "mond")] - state.reporting.request_in.inc(); - - let host = req - .headers() - .get(HOST) - .and_then(|e| e.to_str().ok()) - .map(String::from) - .unwrap_or(String::from("")); - let host = remove_port(&host); - let route = config.hosts.get(host).ok_or(ServiceError::NoHost)?; - #[cfg(feature = "mond")] - state.reporting.hosts.get(host).unwrap().requests_in.inc(); - - // TODO this code is horrible - let mut req = Some(req); - let mut resp = None; - for filter in &route.0 { - let cf = match filter { - RouteFilter::Proxy { backend } => { - resp = Some( - proxy_request( - &state, - req.take().ok_or(ServiceError::RequestTaken)?, - addr, - backend, - ) - .await?, - ); - ControlFlow::Continue(()) - } - RouteFilter::Files { config } => { - resp = Some( - serve_files(req.as_ref().ok_or(ServiceError::RequestTaken)?, config).await?, - ); - ControlFlow::Continue(()) - } - RouteFilter::HttpBasicAuth { config } => filters::auth::http_basic( - config, - req.as_ref().ok_or(ServiceError::RequestTaken)?, - &mut resp, - )?, - RouteFilter::AccessLog { config } => { - filters::accesslog::access_log( - &state, - host, - addr, - config, - req.as_ref().ok_or(ServiceError::RequestTaken)?, - ) - .await? - } - }; - match cf { - ControlFlow::Continue(_) => continue, - ControlFlow::Break(_) => break, - } - } + debug!( + "{addr} ~> {:?} {}", + request.headers().get(HOST), + request.uri() + ); - let mut resp = resp.ok_or(ServiceError::NoResponse)?; + let mut context = NodeContext { addr, state }; + let mut resp = config.handler.handle(&mut context, request).await?; let server_header = resp.headers().get(SERVER).cloned(); resp.headers_mut().insert( @@ -327,7 +265,3 @@ async fn service( return Ok(resp); } - -pub fn remove_port(s: &str) -> &str { - s.split_once(":").map(|(s, _)| s).unwrap_or(s) -} diff --git a/src/reporting.rs b/src/reporting.rs deleted file mode 100644 index ee30ac5..0000000 --- a/src/reporting.rs +++ /dev/null @@ -1,39 +0,0 @@ -use crate::config::Config; -use mond_client::{make_ident, Aspect, Push, Rate, Reporter}; -use std::{collections::HashMap, marker::PhantomData}; - -pub struct Reporting { - pub request_in: Aspect<Rate<i64>>, - pub request_out: Aspect<Rate<i64>>, - pub hosts: HashMap<String, HostReporting>, - // pub connections: Aspect<State<i64>>, -} -pub struct HostReporting { - pub requests_in: Aspect<Rate<i64>>, -} - -impl Reporting { - pub fn new(config: &Config) -> Self { - let mut rep = Reporter::new(); - Self { - request_in: rep.create(make_ident!("requests-in"), Push(Rate(PhantomData::<i64>))), - request_out: rep.create(make_ident!("requests-out"), Push(Rate(PhantomData::<i64>))), - // connections: rep.create(make_ident!("connections"), Push()), - hosts: config - .hosts - .iter() - .map(|(k, _v)| { - ( - k.to_owned(), - HostReporting { - requests_in: rep.create( - make_ident!("host", k, "request-in"), - Push(Rate(PhantomData::<i64>)), - ), - }, - ) - }) - .collect(), - } - } -} |
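
To make the new plugin surface concrete, here is a minimal sketch of an extra module written against the `NodeKind`/`Node` traits this commit adds in `src/filters/mod.rs`, following the same shape as the `error` and `http_basic_auth` modules above. The `redirect` name, its config, and its behaviour are hypothetical (gnix does not include such a module); a real module would also have to be listed in the `MODULES` slice so `NODE_KINDS` can resolve its YAML tag at startup.

```rust
// Hypothetical src/filters/redirect.rs — illustration only, not part of this commit.
use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
use crate::error::ServiceError;
use futures::Future;
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{header::LOCATION, http::HeaderValue, Response, StatusCode};
use serde::Deserialize;
use serde_yaml::Value;
use std::{pin::Pin, sync::Arc};

pub struct RedirectKind;

// Configured as `!redirect "https://example.org/"` — the tag's value is the target.
#[derive(Deserialize)]
#[serde(transparent)]
struct Redirect(String);

impl NodeKind for RedirectKind {
    fn name(&self) -> &'static str {
        "redirect"
    }
    // `instanciate` is the trait method name as spelled in this commit.
    fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> {
        Ok(Arc::new(serde_yaml::from_value::<Redirect>(config)?))
    }
}

impl Node for Redirect {
    fn handle<'a>(
        &'a self,
        _context: &'a mut NodeContext,
        _request: NodeRequest,
    ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
        Box::pin(async move {
            // Same empty-body construction the auth module uses for its 401 challenge.
            let mut r = Response::new(BoxBody::<_, ServiceError>::new(
                String::new().map_err(|_| unreachable!()),
            ));
            *r.status_mut() = StatusCode::FOUND;
            r.headers_mut().insert(
                LOCATION,
                HeaderValue::from_str(&self.0).map_err(|_| ServiceError::Other)?,
            );
            Ok(r)
        })
    }
}
```
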