use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; use crate::ServiceError; use futures::Future; use http_body_util::BodyExt; use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode}; use hyper_util::rt::TokioIo; use log::{debug, warn}; use serde::Deserialize; use serde_yml::Value; use std::{net::SocketAddr, pin::Pin, sync::Arc}; use tokio::net::TcpStream; #[derive(Default)] pub struct ProxyKind; #[derive(Debug, Deserialize)] struct Proxy { set_forwarded_for: bool, #[serde(default = "ret_true")] set_real_ip: bool, backend: SocketAddr, } fn ret_true() -> bool { true } impl NodeKind for ProxyKind { fn name(&self) -> &'static str { "proxy" } fn instanciate(&self, config: Value) -> anyhow::Result> { Ok(Arc::new(serde_yml::from_value::(config)?)) } } impl Node for Proxy { fn handle<'a>( &'a self, context: &'a mut NodeContext, mut request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { if self.set_real_ip { request.headers_mut().insert( "x-real-ip", HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), ); } if self.set_forwarded_for { request.headers_mut().insert( "x-forwarded-for", HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), ); request.headers_mut().insert( "x-forwarded-port", HeaderValue::from_str(&context.addr.port().to_string()).unwrap(), ); let scheme = HeaderValue::from_str(if context.secure { "https" } else { "http" }).unwrap(); request .headers_mut() .insert("x-forwarded-scheme", scheme.clone()); request.headers_mut().insert("x-forwarded-proto", scheme); if let Some(host) = request.headers().get("host").cloned() { request.headers_mut().insert("x-forwarded-host", host); } } let on_upgrade_downstream = request.extensions_mut().remove::(); let _limit_guard = context.state.l_outgoing.try_acquire()?; debug!("\tforwarding to {}", self.backend); let mut resp = { let client_stream = TokioIo::new( TcpStream::connect(self.backend) .await .map_err(|_| ServiceError::CantConnect)?, ); let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) .await .map_err(ServiceError::Hyper)?; tokio::task::spawn(async move { if let Err(err) = conn.with_upgrades().await { warn!("connection failed: {:?}", err); } }); sender .send_request(request) .await .map_err(ServiceError::Hyper)? }; if resp.status() == StatusCode::SWITCHING_PROTOCOLS { let on_upgrade_upstream = resp .extensions_mut() .remove::() .ok_or(ServiceError::UpgradeFailed)?; let on_upgrade_downstream = on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; tokio::task::spawn(async move { debug!("about to upgrade connection"); match (on_upgrade_upstream.await, on_upgrade_downstream.await) { (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { debug!("upgrade successful"); match tokio::io::copy_bidirectional( &mut TokioIo::new(upgraded_downstream), &mut TokioIo::new(upgraded_upstream), ) .await { Ok((from_client, from_server)) => { debug!("proxy socket terminated: {from_server} sent, {from_client} received") } Err(e) => warn!("proxy socket error: {e}"), } } (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), } }); } let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed()); Ok(resp) }) } }