summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2025-03-04 23:27:30 +0100
committermetamuffin <metamuffin@disroot.org>2025-03-04 23:27:30 +0100
commit10dba4e7c79e622249cf4539cff3347619e15059 (patch)
tree18851e4100b6b906d99f06b102ad0cd8cde6815f
parent322e6b32aa313b99861bd4f03ebed68bcb2fd031 (diff)
downloadgnix-10dba4e7c79e622249cf4539cff3347619e15059.tar
gnix-10dba4e7c79e622249cf4539cff3347619e15059.tar.bz2
gnix-10dba4e7c79e622249cf4539cff3347619e15059.tar.zst
set correct port on forwarded headers
-rw-r--r--src/main.rs15
-rw-r--r--src/modules/mod.rs1
-rw-r--r--src/modules/proxy.rs2
3 files changed, 14 insertions, 4 deletions
diff --git a/src/main.rs b/src/main.rs
index fe7856e..b42d868 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -121,13 +121,16 @@ async fn serve_http(state: Arc<State>) -> Result<()> {
let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind.iter().map(|e| async {
let l = TcpListener::bind(*e).await?;
+ let listen_addr = l.local_addr()?;
info!("HTTP listener bound to {}", l.local_addr().unwrap());
loop {
let (stream, addr) = l.accept().await.context("accepting connection")?;
debug!("connection from {addr}");
let stream = TokioIo::new(stream);
let state = state.clone();
- tokio::spawn(async move { serve_stream(state, stream, addr, false).await });
+ tokio::spawn(
+ async move { serve_stream(state, stream, addr, false, listen_addr).await },
+ );
}
}))
.await;
@@ -155,6 +158,7 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config));
let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind.iter().map(|e| async {
let l = TcpListener::bind(*e).await?;
+ let listen_addr = l.local_addr()?;
info!("HTTPS listener bound to {}", l.local_addr().unwrap());
loop {
let (stream, addr) = l.accept().await.context("accepting connection")?;
@@ -163,7 +167,9 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
tokio::task::spawn(async move {
debug!("connection from {addr}");
match tls_acceptor.accept(stream).await {
- Ok(stream) => serve_stream(state, TokioIo::new(stream), addr, true).await,
+ Ok(stream) => {
+ serve_stream(state, TokioIo::new(stream), addr, true, listen_addr).await
+ }
Err(e) => warn!("error accepting tls: {e}"),
};
});
@@ -179,6 +185,7 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r
stream: T,
addr: SocketAddr,
secure: bool,
+ listen_addr: SocketAddr,
) {
if let Ok(_semaphore) = state.l_incoming.try_acquire() {
let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
@@ -188,7 +195,7 @@ pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::r
let state = state.clone();
async move {
let config = state.config.read().await.clone();
- match service(state, config, req, addr, secure).await {
+ match service(state, config, req, addr, secure, listen_addr).await {
Ok(r) => Ok(r),
Err(ServiceError::Hyper(e)) => Err(e),
Err(error) => Ok({
@@ -222,6 +229,7 @@ async fn service(
mut request: Request<Incoming>,
addr: SocketAddr,
secure: bool,
+ listen_addr: SocketAddr,
) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> {
// move uri authority used in HTTP/2 to Host header field
{
@@ -247,6 +255,7 @@ async fn service(
addr,
state,
secure,
+ listen_addr,
};
let mut resp = config
.handler
diff --git a/src/modules/mod.rs b/src/modules/mod.rs
index d3e0dd3..61f4b74 100644
--- a/src/modules/mod.rs
+++ b/src/modules/mod.rs
@@ -56,6 +56,7 @@ pub struct NodeContext {
pub state: Arc<State>,
pub addr: SocketAddr,
pub secure: bool,
+ pub listen_addr: SocketAddr,
}
pub trait NodeKind: Send + Sync + 'static {
diff --git a/src/modules/proxy.rs b/src/modules/proxy.rs
index e92b8fc..2fa3538 100644
--- a/src/modules/proxy.rs
+++ b/src/modules/proxy.rs
@@ -54,7 +54,7 @@ impl Node for Proxy {
);
request.headers_mut().insert(
"x-forwarded-port",
- HeaderValue::from_str(&context.addr.port().to_string()).unwrap(),
+ HeaderValue::from_str(&context.listen_addr.port().to_string()).unwrap(),
);
let scheme =
HeaderValue::from_str(if context.secure { "https" } else { "http" }).unwrap();