From 532cc431d1c5ca1ffcf429a4ccb94edc7848fe7a Mon Sep 17 00:00:00 2001 From: metamuffin Date: Thu, 30 May 2024 00:09:11 +0200 Subject: rename filters dir --- src/modules/proxy.rs | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 src/modules/proxy.rs (limited to 'src/modules/proxy.rs') diff --git a/src/modules/proxy.rs b/src/modules/proxy.rs new file mode 100644 index 0000000..ce72f65 --- /dev/null +++ b/src/modules/proxy.rs @@ -0,0 +1,98 @@ +use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; +use crate::{helper::TokioIo, ServiceError}; +use futures::Future; +use http_body_util::BodyExt; +use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode}; +use log::{debug, warn}; +use serde::Deserialize; +use serde_yaml::Value; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; +use tokio::net::TcpStream; + +#[derive(Default)] +pub struct ProxyKind; + +#[derive(Debug, Deserialize)] +struct Proxy { + backend: SocketAddr, +} + +impl NodeKind for ProxyKind { + fn name(&self) -> &'static str { + "proxy" + } + fn instanciate(&self, config: Value) -> anyhow::Result> { + Ok(Arc::new(serde_yaml::from_value::(config)?)) + } +} +impl Node for Proxy { + fn handle<'a>( + &'a self, + context: &'a mut NodeContext, + mut request: NodeRequest, + ) -> Pin> + Send + Sync + 'a>> { + Box::pin(async move { + request.headers_mut().insert( + "x-real-ip", + HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), + ); + + let on_upgrade_downstream = request.extensions_mut().remove::(); + + let _limit_guard = context.state.l_outgoing.try_acquire()?; + debug!("\tforwarding to {}", self.backend); + let mut resp = { + let client_stream = TokioIo( + TcpStream::connect(self.backend) + .await + .map_err(|_| ServiceError::CantConnect)?, + ); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) + .await + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.with_upgrades().await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(request) + .await + .map_err(ServiceError::Hyper)? + }; + + if resp.status() == StatusCode::SWITCHING_PROTOCOLS { + let on_upgrade_upstream = resp + .extensions_mut() + .remove::() + .ok_or(ServiceError::UpgradeFailed)?; + let on_upgrade_downstream = + on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; + tokio::task::spawn(async move { + debug!("about to upgrade connection"); + match (on_upgrade_upstream.await, on_upgrade_downstream.await) { + (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { + debug!("upgrade successful"); + match tokio::io::copy_bidirectional( + &mut TokioIo(upgraded_downstream), + &mut TokioIo(upgraded_upstream), + ) + .await + { + Ok((from_client, from_server)) => { + debug!("proxy socket terminated: {from_server} sent, {from_client} received") + } + Err(e) => warn!("proxy socket error: {e}"), + } + } + (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), + } + }); + } + + let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed()); + Ok(resp) + }) + } +} -- cgit v1.2.3-70-g09d2