diff options
author | metamuffin <metamuffin@disroot.org> | 2024-05-29 16:37:44 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2024-05-29 16:37:44 +0200 |
commit | 886a18e0c67624d0882f04c7f6659bcfee6b4d8d (patch) | |
tree | 32a5389076b199c4e06fa10ce6b54d165d5466c5 /src/filters | |
parent | 6cebab912dcf01bbe225c20ec2e7656f61ba160e (diff) | |
download | gnix-886a18e0c67624d0882f04c7f6659bcfee6b4d8d.tar gnix-886a18e0c67624d0882f04c7f6659bcfee6b4d8d.tar.bz2 gnix-886a18e0c67624d0882f04c7f6659bcfee6b4d8d.tar.zst |
refactor filter system
Diffstat (limited to 'src/filters')
-rw-r--r-- | src/filters/accesslog.rs | 106 | ||||
-rw-r--r-- | src/filters/auth.rs | 86 | ||||
-rw-r--r-- | src/filters/error.rs | 32 | ||||
-rw-r--r-- | src/filters/files.rs | 285 | ||||
-rw-r--r-- | src/filters/hosts.rs | 45 | ||||
-rw-r--r-- | src/filters/mod.rs | 47 | ||||
-rw-r--r-- | src/filters/proxy.rs | 157 |
7 files changed, 498 insertions, 260 deletions
diff --git a/src/filters/accesslog.rs b/src/filters/accesslog.rs index 9a33762..1da6e5d 100644 --- a/src/filters/accesslog.rs +++ b/src/filters/accesslog.rs @@ -1,49 +1,81 @@ -use crate::{config::AccessLogConfig, error::ServiceError, FilterRequest, State}; -use futures::executor::block_on; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use futures::Future; use log::error; -use std::{net::SocketAddr, ops::ControlFlow, time::SystemTime}; +use serde::Deserialize; +use std::{path::PathBuf, pin::Pin, sync::Arc, time::SystemTime}; use tokio::{ - fs::OpenOptions, + fs::{File, OpenOptions}, io::{AsyncWriteExt, BufWriter}, + sync::RwLock, }; -pub async fn access_log( - state: &State, - host: &str, - addr: SocketAddr, - config: &AccessLogConfig, - req: &FilterRequest, -) -> Result<ControlFlow<()>, ServiceError> { - let mut g = state.access_logs.write().await; +pub struct AccessLogKind; - let log = g.entry(host.to_owned()).or_insert_with(|| { - BufWriter::new( - // TODO aaahh dont block the runtime and dont panic in any case.... - block_on( - OpenOptions::new() - .append(true) - .create(true) - .open(&config.file), - ) - .unwrap(), - ) - }); +#[derive(Deserialize)] +struct AccessLogConfig { + file: PathBuf, + #[serde(default)] + flush: bool, + #[serde(default)] + reject_on_fail: bool, + next: DynNode, +} - let method = req.method().as_str(); - let time = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros(); - let mut res = log - .write_all(format!("{time}\t{addr}\t{method}\t{:?}\n", req.uri()).as_bytes()) - .await; +struct AccessLog { + config: AccessLogConfig, + file: RwLock<Option<BufWriter<File>>>, +} - if config.flush && res.is_ok() { - res = log.flush().await; +impl NodeKind for AccessLogKind { + fn name(&self) -> &'static str { + "access_log" } - - if config.reject_on_fail { - res? - } else if let Err(e) = res { - error!("failed to write log: {e:?}") + fn instanciate(&self, config: serde_yaml::Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(AccessLog { + config: serde_yaml::from_value::<AccessLogConfig>(config)?, + file: Default::default(), + })) } +} + +impl Node for AccessLog { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let mut g = self.file.write().await; + let log = match g.as_mut() { + Some(r) => r, + None => g.insert(BufWriter::new( + OpenOptions::new() + .append(true) + .create(true) + .open(&self.config.file) + .await?, + )), + }; - Ok(ControlFlow::Continue(())) + let method = request.method().as_str(); + let time = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros(); + let addr = context.addr; + let mut res = log + .write_all(format!("{time}\t{addr}\t{method}\t{:?}\n", request.uri()).as_bytes()) + .await; + + if self.config.flush && res.is_ok() { + res = log.flush().await; + } + + if self.config.reject_on_fail { + res? + } else if let Err(e) = res { + error!("failed to write log: {e:?}") + } + + self.config.next.handle(context, request).await + }) + } } diff --git a/src/filters/auth.rs b/src/filters/auth.rs index 92a9ba3..7d5b03e 100644 --- a/src/filters/auth.rs +++ b/src/filters/auth.rs @@ -1,41 +1,65 @@ -use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; +use super::{Node, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; use base64::Engine; +use futures::Future; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, Response, StatusCode, }; use log::debug; -use std::ops::ControlFlow; +use serde::Deserialize; +use serde_yaml::Value; +use std::{collections::HashSet, pin::Pin, sync::Arc}; -pub fn http_basic( - config: &HttpBasicAuthConfig, - req: &FilterRequest, - resp: &mut FilterResponseOut, -) -> Result<ControlFlow<()>, ServiceError> { - if let Some(auth) = req.headers().get(AUTHORIZATION) { - let k = auth - .as_bytes() - .strip_prefix(b"Basic ") - .ok_or(ServiceError::BadAuth)?; - let k = base64::engine::general_purpose::STANDARD.decode(k)?; - let k = String::from_utf8(k)?; - if config.valid.contains(&k) { - debug!("valid auth"); - return Ok(ControlFlow::Continue(())); - } else { - debug!("invalid auth"); - } +pub struct HttpBasicAuthKind; +impl NodeKind for HttpBasicAuthKind { + fn name(&self) -> &'static str { + "http_basic_auth" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn super::Node>> { + Ok(Arc::new(serde_yaml::from_value::<HttpBasicAuth>(config)?)) + } +} + +#[derive(Deserialize)] +pub struct HttpBasicAuth { + realm: String, + valid: HashSet<String>, + next: DynNode, +} + +impl Node for HttpBasicAuth { + fn handle<'a>( + &'a self, + context: &'a mut super::NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + if let Some(auth) = request.headers().get(AUTHORIZATION) { + let k = auth + .as_bytes() + .strip_prefix(b"Basic ") + .ok_or(ServiceError::BadAuth)?; + let k = base64::engine::general_purpose::STANDARD.decode(k)?; + let k = String::from_utf8(k)?; + if self.valid.contains(&k) { + debug!("valid auth"); + return self.next.handle(context, request).await; + } else { + debug!("invalid auth"); + } + } + debug!("unauthorized; sending auth challenge"); + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + String::new().map_err(|_| unreachable!()), + )); + *r.status_mut() = StatusCode::UNAUTHORIZED; + r.headers_mut().insert( + WWW_AUTHENTICATE, + HeaderValue::from_str(&format!("Basic realm=\"{}\"", self.realm)).unwrap(), + ); + Ok(r) + }) } - debug!("unauthorized; sending auth challenge"); - let mut r = Response::new(BoxBody::<_, ServiceError>::new( - String::new().map_err(|_| unreachable!()), - )); - *r.status_mut() = StatusCode::UNAUTHORIZED; - r.headers_mut().insert( - WWW_AUTHENTICATE, - HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), - ); - *resp = Some(r); - Ok(ControlFlow::Break(())) } diff --git a/src/filters/error.rs b/src/filters/error.rs new file mode 100644 index 0000000..504802f --- /dev/null +++ b/src/filters/error.rs @@ -0,0 +1,32 @@ +use crate::error::ServiceError; + +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use futures::Future; +use serde::Deserialize; +use serde_yaml::Value; +use std::{pin::Pin, sync::Arc}; + +pub struct ErrorKind; + +#[derive(Deserialize)] +#[serde(transparent)] +struct Error(String); + +impl NodeKind for ErrorKind { + fn name(&self) -> &'static str { + "error" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Error>(config)?)) + } +} + +impl Node for Error { + fn handle<'a>( + &'a self, + _context: &'a mut NodeContext, + _request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { Err(ServiceError::Custom(self.0.clone())) }) + } +} diff --git a/src/filters/files.rs b/src/filters/files.rs index ee40d70..fc3d63b 100644 --- a/src/filters/files.rs +++ b/src/filters/files.rs @@ -1,8 +1,7 @@ -use crate::{ - config::{CacheConfig, FileserverConfig}, - ServiceError, -}; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::return_true, ServiceError}; use bytes::{Bytes, BytesMut}; +use futures::Future; use futures_util::{future, future::Either, ready, stream, FutureExt, Stream, StreamExt}; use headers::{ AcceptRanges, CacheControl, ContentLength, ContentRange, ContentType, HeaderMapExt, @@ -11,154 +10,204 @@ use headers::{ use http_body_util::{combinators::BoxBody, BodyExt, StreamBody}; use humansize::FormatSizeOptions; use hyper::{ - body::{Frame, Incoming}, + body::Frame, header::{CONTENT_TYPE, LOCATION}, http::HeaderValue, - Request, Response, StatusCode, + Response, StatusCode, }; use log::debug; use markup::Render; use percent_encoding::percent_decode_str; -use std::{fs::Metadata, io, ops::Range, path::Path, pin::Pin, task::Poll}; +use serde::Deserialize; +use serde_yaml::Value; +use std::{ + fs::Metadata, + io, + ops::Range, + path::{Path, PathBuf}, + pin::Pin, + sync::Arc, + task::Poll, +}; use tokio::{ fs::{read_to_string, File}, io::AsyncSeekExt, }; use tokio_util::io::poll_read_buf; -pub async fn serve_files( - req: &Request<Incoming>, - config: &FileserverConfig, -) -> Result<hyper::Response<BoxBody<Bytes, ServiceError>>, ServiceError> { - let rpath = req.uri().path(); +pub struct FilesKind; - let mut path = config.root.clone(); - let mut user_path_depth = 0; - for seg in rpath.split("/") { - let seg = percent_decode_str(seg).decode_utf8()?; +#[derive(Debug, Deserialize)] +struct Files { + root: PathBuf, + #[serde(default)] + index: bool, + #[serde(default = "return_true")] + last_modified: bool, + // #[serde(default = "return_true")] + // etag: bool, + #[serde(default)] + cache: CacheMode, +} - if seg == "" || seg == "." { - continue; - } +#[derive(Debug, Default, Deserialize)] +#[serde(rename_all = "snake_case")] +enum CacheMode { + #[default] + Public, + Private, + NoStore, +} - if seg == ".." { - if user_path_depth <= 0 { - return Err(ServiceError::BadPath); - } - path.pop(); - user_path_depth -= 1; - } else { - path.push(seg.as_ref()); - user_path_depth += 1; - } +impl NodeKind for FilesKind { + fn name(&self) -> &'static str { + "files" } - if !path.exists() { - return Err(ServiceError::NotFound); + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Files>(config)?)) } +} - let metadata = path.metadata()?; +impl Node for Files { + fn handle<'a>( + &'a self, + _context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let rpath = request.uri().path(); - if metadata.file_type().is_dir() { - debug!("sending index for {path:?}"); - if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { - return Ok(html_string_response(indexhtml)); - } + let mut path = self.root.clone(); + let mut user_path_depth = 0; + for seg in rpath.split("/") { + let seg = percent_decode_str(seg).decode_utf8()?; + + if seg == "" || seg == "." { + continue; + } - if config.index { - if !rpath.ends_with("/") { - let mut r = Response::new(String::new()); - *r.status_mut() = StatusCode::FOUND; - r.headers_mut().insert( - LOCATION, - HeaderValue::from_str(&format!("{}/", rpath)) - .map_err(|_| ServiceError::Other)?, - ); - return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); + if seg == ".." { + if user_path_depth <= 0 { + return Err(ServiceError::BadPath); + } + path.pop(); + user_path_depth -= 1; + } else { + path.push(seg.as_ref()); + user_path_depth += 1; + } + } + if !path.exists() { + return Err(ServiceError::NotFound); } - return index(&path, rpath.to_string()) - .await - .map(html_string_response); - } else { - return Err(ServiceError::NotFound); - } - } + let metadata = path.metadata()?; - let modified = metadata.modified()?; + if metadata.file_type().is_dir() { + debug!("sending index for {path:?}"); + if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { + return Ok(html_string_response(indexhtml)); + } - let not_modified = if config.last_modified { - req.headers() - .typed_get::<headers::IfModifiedSince>() - .map(|if_modified_since| { - Ok::<_, ServiceError>(!if_modified_since.is_modified(modified)) - }) - .transpose()? - .unwrap_or_default() - } else { - false - }; + if self.index { + if !rpath.ends_with("/") { + let mut r = Response::new(String::new()); + *r.status_mut() = StatusCode::FOUND; + r.headers_mut().insert( + LOCATION, + HeaderValue::from_str(&format!("{}/", rpath)) + .map_err(|_| ServiceError::Other)?, + ); + return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); + } + + return index(&path, rpath.to_string()) + .await + .map(html_string_response); + } else { + return Err(ServiceError::NotFound); + } + } - // let etag = ETag::from_str(&calc_etag(modified)).map_err(|_| ServiceError::Other)?; - // let etag_matches = if config.etag { - // req.headers() - // .typed_get::<headers::IfNoneMatch>() - // .map(|if_none_match| if_none_match.precondition_passes(&etag)) - // .unwrap_or_default() - // } else { - // false - // }; + let modified = metadata.modified()?; - let range = req.headers().typed_get::<headers::Range>(); - let range = bytes_range(range, metadata.len())?; + let not_modified = if self.last_modified { + request + .headers() + .typed_get::<headers::IfModifiedSince>() + .map(|if_modified_since| { + Ok::<_, ServiceError>(!if_modified_since.is_modified(modified)) + }) + .transpose()? + .unwrap_or_default() + } else { + false + }; - debug!("sending file {path:?}"); - let file = File::open(path.clone()).await?; + // let etag = ETag::from_str(&calc_etag(modified)).map_err(|_| ServiceError::Other)?; + // let etag_matches = if self.etag { + // request.headers() + // .typed_get::<headers::IfNoneMatch>() + // .map(|if_none_match| if_none_match.precondition_passes(&etag)) + // .unwrap_or_default() + // } else { + // false + // }; - // let skip_body = not_modified || etag_matches; - let skip_body = not_modified; - let mut r = if skip_body { - Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()) - } else { - Response::new(BoxBody::new(StreamBody::new( - StreamBody::new(file_stream(file, 4096, range.clone())) - .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), - ))) - }; + let range = request.headers().typed_get::<headers::Range>(); + let range = bytes_range(range, metadata.len())?; - if !skip_body { - if range.end - range.start != metadata.len() { - *r.status_mut() = StatusCode::PARTIAL_CONTENT; - r.headers_mut().typed_insert( - ContentRange::bytes(range.clone(), metadata.len()).expect("valid ContentRange"), - ); - } - } - // if not_modified || etag_matches { - if not_modified { - *r.status_mut() = StatusCode::NOT_MODIFIED; - } + debug!("sending file {path:?}"); + let file = File::open(path.clone()).await?; - r.headers_mut().typed_insert(AcceptRanges::bytes()); - r.headers_mut() - .typed_insert(ContentLength(range.end - range.start)); + // let skip_body = not_modified || etag_matches; + let skip_body = not_modified; + let mut r = if skip_body { + Response::new("".to_string()).map(|b| b.map_err(|e| match e {}).boxed()) + } else { + Response::new(BoxBody::new(StreamBody::new( + StreamBody::new(file_stream(file, 4096, range.clone())) + .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), + ))) + }; - let mime = mime_guess::from_path(path).first_or_octet_stream(); - r.headers_mut().typed_insert(ContentType::from(mime)); + if !skip_body { + if range.end - range.start != metadata.len() { + *r.status_mut() = StatusCode::PARTIAL_CONTENT; + r.headers_mut().typed_insert( + ContentRange::bytes(range.clone(), metadata.len()) + .expect("valid ContentRange"), + ); + } + } + // if not_modified || etag_matches { + if not_modified { + *r.status_mut() = StatusCode::NOT_MODIFIED; + } - r.headers_mut().typed_insert(match config.cache { - CacheConfig::Public => CacheControl::new().with_public(), - CacheConfig::Private => CacheControl::new().with_private(), - CacheConfig::NoStore => CacheControl::new().with_no_store(), - }); + r.headers_mut().typed_insert(AcceptRanges::bytes()); + r.headers_mut() + .typed_insert(ContentLength(range.end - range.start)); - // if config.etag { - // r.headers_mut().typed_insert(etag); - // } - if config.last_modified { - r.headers_mut().typed_insert(LastModified::from(modified)); - } + let mime = mime_guess::from_path(path).first_or_octet_stream(); + r.headers_mut().typed_insert(ContentType::from(mime)); - Ok(r) + r.headers_mut().typed_insert(match self.cache { + CacheMode::Public => CacheControl::new().with_public(), + CacheMode::Private => CacheControl::new().with_private(), + CacheMode::NoStore => CacheControl::new().with_no_store(), + }); + + // if self.etag { + // r.headers_mut().typed_insert(etag); + // } + if self.last_modified { + r.headers_mut().typed_insert(LastModified::from(modified)); + } + + Ok(r) + }) + } } // Adapted from warp (https://github.com/seanmonstar/warp/blob/master/src/filters/fs.rs). Thanks! diff --git a/src/filters/hosts.rs b/src/filters/hosts.rs new file mode 100644 index 0000000..286d478 --- /dev/null +++ b/src/filters/hosts.rs @@ -0,0 +1,45 @@ +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{config::DynNode, error::ServiceError}; +use futures::Future; +use hyper::header::HOST; +use serde::Deserialize; +use serde_yaml::Value; +use std::{collections::HashMap, pin::Pin, sync::Arc}; + +#[derive(Deserialize)] +#[serde(transparent)] +struct Hosts(HashMap<String, DynNode>); + +pub struct HostsKind; +impl NodeKind for HostsKind { + fn name(&self) -> &'static str { + "hosts" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Hosts>(config)?)) + } +} +impl Node for Hosts { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + let host = request + .headers() + .get(HOST) + .and_then(|e| e.to_str().ok()) + .ok_or(ServiceError::NoHost)?; + + let host = remove_port(&host); + let node = self.0.get(host).ok_or(ServiceError::UnknownHost)?; + + node.handle(context, request).await + }) + } +} + +pub fn remove_port(s: &str) -> &str { + s.split_once(":").map(|(s, _)| s).unwrap_or(s) +} diff --git a/src/filters/mod.rs b/src/filters/mod.rs index fdeed51..10520a3 100644 --- a/src/filters/mod.rs +++ b/src/filters/mod.rs @@ -1,5 +1,50 @@ +use crate::error::ServiceError; +use crate::State; +use accesslog::AccessLogKind; +use auth::HttpBasicAuthKind; +use bytes::Bytes; +use error::ErrorKind; +use files::FilesKind; +use futures::Future; +use hosts::HostsKind; +use http_body_util::combinators::BoxBody; +use hyper::{body::Incoming, Request, Response}; +use proxy::ProxyKind; +use serde_yaml::Value; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; +pub mod accesslog; pub mod auth; +pub mod error; pub mod files; +pub mod hosts; pub mod proxy; -pub mod accesslog;
\ No newline at end of file + +pub type NodeRequest = Request<Incoming>; +pub type NodeResponse = Response<BoxBody<Bytes, ServiceError>>; + +pub static MODULES: &'static [&'static dyn NodeKind] = &[ + &HttpBasicAuthKind, + &ProxyKind, + &HostsKind, + &FilesKind, + &AccessLogKind, + &ErrorKind, +]; + +pub struct NodeContext { + pub state: Arc<State>, + pub addr: SocketAddr, +} + +pub trait NodeKind: Send + Sync + 'static { + fn name(&self) -> &'static str; + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>>; +} +pub trait Node: Send + Sync + 'static { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>>; +} diff --git a/src/filters/proxy.rs b/src/filters/proxy.rs index ad959ad..ce72f65 100644 --- a/src/filters/proxy.rs +++ b/src/filters/proxy.rs @@ -1,87 +1,98 @@ -use crate::{helper::TokioIo, ServiceError, State}; -use http_body_util::{combinators::BoxBody, BodyExt}; -use hyper::{body::Incoming, http::HeaderValue, upgrade::OnUpgrade, Request, StatusCode}; +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{helper::TokioIo, ServiceError}; +use futures::Future; +use http_body_util::BodyExt; +use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode}; use log::{debug, warn}; -use std::{net::SocketAddr, sync::Arc}; +use serde::Deserialize; +use serde_yaml::Value; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; use tokio::net::TcpStream; -pub async fn proxy_request( - state: &Arc<State>, - mut req: Request<Incoming>, - addr: SocketAddr, - backend: &SocketAddr, -) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> { - #[cfg(feature = "mond")] - state.reporting.request_out.inc(); +#[derive(Default)] +pub struct ProxyKind; - //? Do we know what this did? - // *req.uri_mut() = Uri::builder() - // .path_and_query( - // req.uri() - // .clone() - // .path_and_query() - // .cloned() - // .unwrap_or(PathAndQuery::from_static("/")), - // ) - // .build() - // .unwrap(); - - req.headers_mut().insert( - "x-real-ip", - HeaderValue::from_str(&format!("{}", addr.ip())).unwrap(), - ); +#[derive(Debug, Deserialize)] +struct Proxy { + backend: SocketAddr, +} - let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>(); +impl NodeKind for ProxyKind { + fn name(&self) -> &'static str { + "proxy" + } + fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> { + Ok(Arc::new(serde_yaml::from_value::<Proxy>(config)?)) + } +} +impl Node for Proxy { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + mut request: NodeRequest, + ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> { + Box::pin(async move { + request.headers_mut().insert( + "x-real-ip", + HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), + ); - let _limit_guard = state.l_outgoing.try_acquire()?; - debug!("\tforwarding to {}", backend); - let mut resp = { - let client_stream = TokioIo( - TcpStream::connect(backend) - .await - .map_err(|_| ServiceError::CantConnect)?, - ); + let on_upgrade_downstream = request.extensions_mut().remove::<OnUpgrade>(); - let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) - .await - .map_err(ServiceError::Hyper)?; - tokio::task::spawn(async move { - if let Err(err) = conn.with_upgrades().await { - warn!("connection failed: {:?}", err); - } - }); - sender - .send_request(req) - .await - .map_err(ServiceError::Hyper)? - }; + let _limit_guard = context.state.l_outgoing.try_acquire()?; + debug!("\tforwarding to {}", self.backend); + let mut resp = { + let client_stream = TokioIo( + TcpStream::connect(self.backend) + .await + .map_err(|_| ServiceError::CantConnect)?, + ); - if resp.status() == StatusCode::SWITCHING_PROTOCOLS { - let on_upgrade_upstream = resp - .extensions_mut() - .remove::<OnUpgrade>() - .ok_or(ServiceError::UpgradeFailed)?; - let on_upgrade_downstream = on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; - tokio::task::spawn(async move { - debug!("about to upgrade connection"); - match (on_upgrade_upstream.await, on_upgrade_downstream.await) { - (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { - debug!("upgrade successful"); - match tokio::io::copy_bidirectional( - &mut TokioIo(upgraded_downstream), - &mut TokioIo(upgraded_upstream), - ) + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) + .await + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.with_upgrades().await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(request) .await - { - Ok((from_client, from_server)) => { - debug!("proxy socket terminated: {from_server} sent, {from_client} received") + .map_err(ServiceError::Hyper)? + }; + + if resp.status() == StatusCode::SWITCHING_PROTOCOLS { + let on_upgrade_upstream = resp + .extensions_mut() + .remove::<OnUpgrade>() + .ok_or(ServiceError::UpgradeFailed)?; + let on_upgrade_downstream = + on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; + tokio::task::spawn(async move { + debug!("about to upgrade connection"); + match (on_upgrade_upstream.await, on_upgrade_downstream.await) { + (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { + debug!("upgrade successful"); + match tokio::io::copy_bidirectional( + &mut TokioIo(upgraded_downstream), + &mut TokioIo(upgraded_upstream), + ) + .await + { + Ok((from_client, from_server)) => { + debug!("proxy socket terminated: {from_server} sent, {from_client} received") + } + Err(e) => warn!("proxy socket error: {e}"), + } } - Err(e) => warn!("proxy socket error: {e}"), + (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), } - } - (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), + }); } - }); + + let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed()); + Ok(resp) + }) } - Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) } |