summaryrefslogtreecommitdiff
path: root/src/modules/proxy.rs
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2024-05-30 00:09:11 +0200
committermetamuffin <metamuffin@disroot.org>2024-05-30 00:09:11 +0200
commit532cc431d1c5ca1ffcf429a4ccb94edc7848fe7a (patch)
treec4422c4d54e01f63bae391cd95788cad74f59fbb /src/modules/proxy.rs
parent8b39940a58c28bc1bbe291eb5229e9ce1444e33c (diff)
downloadgnix-532cc431d1c5ca1ffcf429a4ccb94edc7848fe7a.tar
gnix-532cc431d1c5ca1ffcf429a4ccb94edc7848fe7a.tar.bz2
gnix-532cc431d1c5ca1ffcf429a4ccb94edc7848fe7a.tar.zst
rename filters dir
Diffstat (limited to 'src/modules/proxy.rs')
-rw-r--r--src/modules/proxy.rs98
1 files changed, 98 insertions, 0 deletions
diff --git a/src/modules/proxy.rs b/src/modules/proxy.rs
new file mode 100644
index 0000000..ce72f65
--- /dev/null
+++ b/src/modules/proxy.rs
@@ -0,0 +1,98 @@
+use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
+use crate::{helper::TokioIo, ServiceError};
+use futures::Future;
+use http_body_util::BodyExt;
+use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode};
+use log::{debug, warn};
+use serde::Deserialize;
+use serde_yaml::Value;
+use std::{net::SocketAddr, pin::Pin, sync::Arc};
+use tokio::net::TcpStream;
+
+#[derive(Default)]
+pub struct ProxyKind;
+
+#[derive(Debug, Deserialize)]
+struct Proxy {
+ backend: SocketAddr,
+}
+
+impl NodeKind for ProxyKind {
+ fn name(&self) -> &'static str {
+ "proxy"
+ }
+ fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> {
+ Ok(Arc::new(serde_yaml::from_value::<Proxy>(config)?))
+ }
+}
+impl Node for Proxy {
+ fn handle<'a>(
+ &'a self,
+ context: &'a mut NodeContext,
+ mut request: NodeRequest,
+ ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
+ Box::pin(async move {
+ request.headers_mut().insert(
+ "x-real-ip",
+ HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(),
+ );
+
+ let on_upgrade_downstream = request.extensions_mut().remove::<OnUpgrade>();
+
+ let _limit_guard = context.state.l_outgoing.try_acquire()?;
+ debug!("\tforwarding to {}", self.backend);
+ let mut resp = {
+ let client_stream = TokioIo(
+ TcpStream::connect(self.backend)
+ .await
+ .map_err(|_| ServiceError::CantConnect)?,
+ );
+
+ let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
+ .await
+ .map_err(ServiceError::Hyper)?;
+ tokio::task::spawn(async move {
+ if let Err(err) = conn.with_upgrades().await {
+ warn!("connection failed: {:?}", err);
+ }
+ });
+ sender
+ .send_request(request)
+ .await
+ .map_err(ServiceError::Hyper)?
+ };
+
+ if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
+ let on_upgrade_upstream = resp
+ .extensions_mut()
+ .remove::<OnUpgrade>()
+ .ok_or(ServiceError::UpgradeFailed)?;
+ let on_upgrade_downstream =
+ on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?;
+ tokio::task::spawn(async move {
+ debug!("about to upgrade connection");
+ match (on_upgrade_upstream.await, on_upgrade_downstream.await) {
+ (Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
+ debug!("upgrade successful");
+ match tokio::io::copy_bidirectional(
+ &mut TokioIo(upgraded_downstream),
+ &mut TokioIo(upgraded_upstream),
+ )
+ .await
+ {
+ Ok((from_client, from_server)) => {
+ debug!("proxy socket terminated: {from_server} sent, {from_client} received")
+ }
+ Err(e) => warn!("proxy socket error: {e}"),
+ }
+ }
+ (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"),
+ }
+ });
+ }
+
+ let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed());
+ Ok(resp)
+ })
+ }
+}