diff options
author | metamuffin <metamuffin@disroot.org> | 2025-03-20 11:57:37 +0100 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2025-03-20 11:57:37 +0100 |
commit | 1d5640c009dd993969c9aba749ecbbd2efe1b634 (patch) | |
tree | 3d41ad57e9f12629e10fdfc969023bc7967b4998 | |
parent | 69a77b36172503ada9047a756bb5972a4c7005dc (diff) | |
download | gnix-1d5640c009dd993969c9aba749ecbbd2efe1b634.tar gnix-1d5640c009dd993969c9aba749ecbbd2efe1b634.tar.bz2 gnix-1d5640c009dd993969c9aba749ecbbd2efe1b634.tar.zst |
proxy: unix socket support
-rw-r--r-- | Cargo.lock | 2 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | readme.md | 3 | ||||
-rw-r--r-- | src/modules/proxy.rs | 108 |
4 files changed, 92 insertions, 23 deletions
@@ -668,7 +668,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "gnix" -version = "2.3.0" +version = "2.4.0" dependencies = [ "aes-gcm-siv", "anyhow", @@ -1,6 +1,6 @@ [package] name = "gnix" -version = "2.3.0" +version = "2.4.0" edition = "2021" [dependencies] @@ -123,7 +123,8 @@ themselves; in that case the request is passed on. - Forwards the request as-is to some other server. the `x-real-ip` header is added to the request. Connection upgrades are handled by direct forwarding of network traffic. - - `backend`: socket address (string) to the backend server + - `backend`: IP socket address or absolute path to unix socket of the backend + server. (string) - `set_real_ip`: Sets the `X-Real-IP` header. (boolean) - `set_forwarded_for`: Sets the `X-Forwarded-For`, `X-Forwarded-Host`, `X-Forwaded-Proto`, `X-Forwarded-Scheme` and `X-Forwarded-Port` headers. diff --git a/src/modules/proxy.rs b/src/modules/proxy.rs index a0e8d85..c219647 100644 --- a/src/modules/proxy.rs +++ b/src/modules/proxy.rs @@ -1,15 +1,25 @@ use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; use crate::ServiceError; use futures::Future; -use http::Version; +use http::{Response, Version}; use http_body_util::BodyExt; -use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode}; +use hyper::{body::Incoming, http::HeaderValue, upgrade::OnUpgrade, StatusCode}; use hyper_util::rt::TokioIo; use log::{debug, warn}; use serde::Deserialize; use serde_yml::Value; -use std::{net::SocketAddr, pin::Pin, sync::Arc}; -use tokio::net::TcpStream; +use std::{ + fmt::Display, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + path::PathBuf, + pin::Pin, + str::FromStr, + sync::Arc, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::{TcpStream, UnixStream}, +}; #[derive(Default)] pub struct ProxyKind; @@ -20,7 +30,13 @@ struct Proxy { set_forwarded_for: bool, #[serde(default = "ret_true")] set_real_ip: bool, - backend: SocketAddr, + backend: SocketAddrOrPath, +} + +#[derive(Debug)] +enum SocketAddrOrPath { + Ip(SocketAddr), + Unix(PathBuf), } fn ret_true() -> bool { @@ -35,6 +51,36 @@ impl NodeKind for ProxyKind { Ok(Arc::new(serde_yml::from_value::<Proxy>(config)?)) } } + +impl<'de> Deserialize<'de> for SocketAddrOrPath { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + if s.starts_with("/") { + Ok(Self::Unix(PathBuf::from_str(&s).map_err(|e| match e {})?)) + } else if let Some(port) = s.strip_prefix(":") { + Ok(Self::Ip(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::LOCALHOST, + port.parse().map_err(serde::de::Error::custom)?, + )))) + } else { + Ok(Self::Ip( + SocketAddr::from_str(&s).map_err(serde::de::Error::custom)?, + )) + } + } +} +impl Display for SocketAddrOrPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SocketAddrOrPath::Ip(addr) => write!(f, "{addr}"), + SocketAddrOrPath::Unix(path) => write!(f, "unix:{}", path.to_string_lossy()), + } + } +} + impl Node for Proxy { fn handle<'a>( &'a self, @@ -74,24 +120,46 @@ impl Node for Proxy { let _limit_guard = context.state.l_outgoing.try_acquire()?; debug!("\tforwarding to {}", self.backend); let mut resp = { - let client_stream = TokioIo::new( - TcpStream::connect(self.backend) + #[inline] + async fn send_request<T: AsyncWrite + AsyncRead + Unpin + Send + 'static>( + stream: T, + request: NodeRequest, + ) -> Result<Response<Incoming>, ServiceError> { + let client_stream = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) .await - .map_err(|_| ServiceError::CantConnect)?, - ); + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.with_upgrades().await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(request) + .await + .map_err(ServiceError::Hyper) + } - let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) - .await - .map_err(ServiceError::Hyper)?; - tokio::task::spawn(async move { - if let Err(err) = conn.with_upgrades().await { - warn!("connection failed: {:?}", err); + match &self.backend { + SocketAddrOrPath::Ip(socket_addr) => { + send_request( + TcpStream::connect(socket_addr) + .await + .map_err(|_| ServiceError::CantConnect)?, + request, + ) + .await? } - }); - sender - .send_request(request) - .await - .map_err(ServiceError::Hyper)? + SocketAddrOrPath::Unix(path) => { + send_request( + UnixStream::connect(path) + .await + .map_err(|_| ServiceError::CantConnect)?, + request, + ) + .await? + } + } }; if resp.status() == StatusCode::SWITCHING_PROTOCOLS { |