/* This file is part of gnix (https://codeberg.org/metamuffin/gnix) which is licensed under the GNU Affero General Public License (version 3); see /COPYING. Copyright (C) 2025 metamuffin */ use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; use crate::ServiceError; use futures::Future; use http::{Response, Version}; use http_body_util::BodyExt; use hyper::{body::Incoming, http::HeaderValue, upgrade::OnUpgrade, StatusCode}; use hyper_util::rt::TokioIo; use log::{debug, warn}; use serde::Deserialize; use serde_yml::Value; use std::{ fmt::Display, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, path::PathBuf, pin::Pin, str::FromStr, sync::Arc, }; use tokio::{ io::{AsyncRead, AsyncWrite}, net::{TcpStream, UnixStream}, }; #[derive(Default)] pub struct ProxyKind; #[derive(Debug, Deserialize)] struct Proxy { #[serde(default)] set_forwarded_for: bool, #[serde(default = "ret_true")] set_real_ip: bool, backend: SocketAddrOrPath, } #[derive(Debug)] enum SocketAddrOrPath { Ip(SocketAddr), Unix(PathBuf), } fn ret_true() -> bool { true } impl NodeKind for ProxyKind { fn name(&self) -> &'static str { "proxy" } fn instanciate(&self, config: Value) -> anyhow::Result> { Ok(Arc::new(serde_yml::from_value::(config)?)) } } impl<'de> Deserialize<'de> for SocketAddrOrPath { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let s = String::deserialize(deserializer)?; if s.starts_with("/") { Ok(Self::Unix(PathBuf::from_str(&s).map_err(|e| match e {})?)) } else if let Some(port) = s.strip_prefix(":") { Ok(Self::Ip(SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::LOCALHOST, port.parse().map_err(serde::de::Error::custom)?, )))) } else { Ok(Self::Ip( SocketAddr::from_str(&s).map_err(serde::de::Error::custom)?, )) } } } impl Display for SocketAddrOrPath { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { SocketAddrOrPath::Ip(addr) => write!(f, "{addr}"), SocketAddrOrPath::Unix(path) => write!(f, "unix:{}", path.to_string_lossy()), } } } impl Node for Proxy { fn handle<'a>( &'a self, context: &'a mut NodeContext, mut request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { *request.version_mut() = Version::HTTP_11; // totally not a lie if self.set_real_ip { request.headers_mut().insert( "x-real-ip", HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), ); } if self.set_forwarded_for { request.headers_mut().insert( "x-forwarded-for", HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(), ); request.headers_mut().insert( "x-forwarded-port", HeaderValue::from_str(&context.listen_addr.port().to_string()).unwrap(), ); let scheme = HeaderValue::from_str(if context.secure { "https" } else { "http" }).unwrap(); request .headers_mut() .insert("x-forwarded-scheme", scheme.clone()); request.headers_mut().insert("x-forwarded-proto", scheme); if let Some(host) = request.headers().get("host").cloned() { request.headers_mut().insert("x-forwarded-host", host); } } let on_upgrade_downstream = request.extensions_mut().remove::(); let _limit_guard = context.state.l_outgoing.try_acquire()?; debug!("\tforwarding to {}", self.backend); let mut resp = { #[inline] async fn send_request( stream: T, request: NodeRequest, ) -> Result, ServiceError> { let client_stream = TokioIo::new(stream); let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) .await .map_err(ServiceError::Hyper)?; tokio::task::spawn(async move { if let Err(err) = conn.with_upgrades().await { warn!("connection failed: {:?}", err); } }); sender .send_request(request) .await .map_err(ServiceError::Hyper) } match &self.backend { SocketAddrOrPath::Ip(socket_addr) => { send_request( TcpStream::connect(socket_addr) .await .map_err(|_| ServiceError::CantConnect)?, request, ) .await? } SocketAddrOrPath::Unix(path) => { send_request( UnixStream::connect(path) .await .map_err(|_| ServiceError::CantConnect)?, request, ) .await? } } }; if resp.status() == StatusCode::SWITCHING_PROTOCOLS { let on_upgrade_upstream = resp .extensions_mut() .remove::() .ok_or(ServiceError::UpgradeFailed)?; let on_upgrade_downstream = on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?; tokio::task::spawn(async move { debug!("about to upgrade connection"); match (on_upgrade_upstream.await, on_upgrade_downstream.await) { (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { debug!("upgrade successful"); match tokio::io::copy_bidirectional( &mut TokioIo::new(upgraded_downstream), &mut TokioIo::new(upgraded_upstream), ) .await { Ok((from_client, from_server)) => { debug!("proxy socket terminated: {from_server} sent, {from_client} received") } Err(e) => warn!("proxy socket error: {e}"), } } (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"), } }); } let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed()); Ok(resp) }) } }