aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-08-28 14:52:24 +0200
committermetamuffin <metamuffin@disroot.org>2023-08-28 14:52:24 +0200
commit2bc557bbddb01b535dd8512fe3aadb0d4091a42d (patch)
treeab376b39df3ef1719c77842f4caf4583e48dfcb6 /src
parent951c4e90b573f3d14a137bade0853fb3f0f21a5d (diff)
downloadgnix-2bc557bbddb01b535dd8512fe3aadb0d4091a42d.tar
gnix-2bc557bbddb01b535dd8512fe3aadb0d4091a42d.tar.bz2
gnix-2bc557bbddb01b535dd8512fe3aadb0d4091a42d.tar.zst
update all deps, including a new horrible version hyper
Diffstat (limited to 'src')
-rw-r--r--src/helper.rs139
-rw-r--r--src/main.rs67
-rw-r--r--src/proxy.rs16
3 files changed, 179 insertions, 43 deletions
diff --git a/src/helper.rs b/src/helper.rs
new file mode 100644
index 0000000..5914be3
--- /dev/null
+++ b/src/helper.rs
@@ -0,0 +1,139 @@
+// From https://github.com/hyperium/hyper/blob/master/benches/support/tokiort.rs
+
+use pin_project::pin_project;
+use std::{
+ pin::Pin,
+ task::{Context, Poll},
+};
+
+#[pin_project]
+#[derive(Debug)]
+pub struct TokioIo<T>(#[pin] pub T);
+
+impl<T> hyper::rt::Read for TokioIo<T>
+where
+ T: tokio::io::AsyncRead,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ mut buf: hyper::rt::ReadBufCursor<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ let n = unsafe {
+ let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
+ match tokio::io::AsyncRead::poll_read(self.project().0, cx, &mut tbuf) {
+ Poll::Ready(Ok(())) => tbuf.filled().len(),
+ other => return other,
+ }
+ };
+
+ unsafe {
+ buf.advance(n);
+ }
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<T> hyper::rt::Write for TokioIo<T>
+where
+ T: tokio::io::AsyncWrite,
+{
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ tokio::io::AsyncWrite::poll_write(self.project().0, cx, buf)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
+ tokio::io::AsyncWrite::poll_flush(self.project().0, cx)
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ tokio::io::AsyncWrite::poll_shutdown(self.project().0, cx)
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ tokio::io::AsyncWrite::is_write_vectored(&self.0)
+ }
+
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[std::io::IoSlice<'_>],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ tokio::io::AsyncWrite::poll_write_vectored(self.project().0, cx, bufs)
+ }
+}
+
+impl<T> tokio::io::AsyncRead for TokioIo<T>
+where
+ T: hyper::rt::Read,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ tbuf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ //let init = tbuf.initialized().len();
+ let filled = tbuf.filled().len();
+ let sub_filled = unsafe {
+ let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
+
+ match hyper::rt::Read::poll_read(self.project().0, cx, buf.unfilled()) {
+ Poll::Ready(Ok(())) => buf.filled().len(),
+ other => return other,
+ }
+ };
+
+ let n_filled = filled + sub_filled;
+ // At least sub_filled bytes had to have been initialized.
+ let n_init = sub_filled;
+ unsafe {
+ tbuf.assume_init(n_init);
+ tbuf.set_filled(n_filled);
+ }
+
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<T> tokio::io::AsyncWrite for TokioIo<T>
+where
+ T: hyper::rt::Write,
+{
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ hyper::rt::Write::poll_write(self.project().0, cx, buf)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
+ hyper::rt::Write::poll_flush(self.project().0, cx)
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ hyper::rt::Write::poll_shutdown(self.project().0, cx)
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ hyper::rt::Write::is_write_vectored(&self.0)
+ }
+
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[std::io::IoSlice<'_>],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ hyper::rt::Write::poll_write_vectored(self.project().0, cx, bufs)
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index aa13609..5230434 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,6 +4,7 @@
pub mod config;
pub mod error;
pub mod files;
+pub mod helper;
pub mod limiter;
pub mod proxy;
@@ -15,6 +16,7 @@ use crate::{
use anyhow::{anyhow, bail, Context, Result};
use error::ServiceError;
use futures::future::try_join_all;
+use helper::TokioIo;
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{
body::Incoming,
@@ -27,11 +29,7 @@ use hyper::{
use limiter::Limiter;
use log::{debug, error, info, warn};
use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc};
-use tokio::{
- io::{AsyncRead, AsyncWrite},
- net::TcpListener,
- signal::ctrl_c,
-};
+use tokio::{net::TcpListener, signal::ctrl_c};
use tokio_rustls::TlsAcceptor;
pub struct State {
@@ -90,18 +88,17 @@ async fn serve_http(state: Arc<State>) -> Result<()> {
None => return Ok(()),
};
- let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind
- .iter()
- .map(|e| async {
- let l = TcpListener::bind(e.clone()).await?;
- loop {
- let (stream, addr) = l.accept().await.context("accepting connection")?;
- debug!("connection from {addr}");
- let config = state.clone();
- tokio::spawn(async move { serve_stream(config, stream, addr).await });
- }
- }))
- .await;
+ let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind.iter().map(|e| async {
+ let l = TcpListener::bind(e.clone()).await?;
+ loop {
+ let (stream, addr) = l.accept().await.context("accepting connection")?;
+ debug!("connection from {addr}");
+ let stream = TokioIo(stream);
+ let config = state.clone();
+ tokio::spawn(async move { serve_stream(config, stream, addr).await });
+ }
+ }))
+ .await;
info!("serving http");
@@ -129,31 +126,29 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
};
let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config));
- let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind
- .iter()
- .map(|e| async {
- let l = TcpListener::bind(e.clone()).await?;
- loop {
- let (stream, addr) = l.accept().await.context("accepting connection")?;
- let state = state.clone();
- let tls_acceptor = tls_acceptor.clone();
- tokio::task::spawn(async move {
- debug!("connection from {addr}");
- match tls_acceptor.accept(stream).await {
- Ok(stream) => serve_stream(state, stream, addr).await,
- Err(e) => warn!("error accepting tls: {e}"),
- };
- });
- }
- }))
- .await;
+ let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind.iter().map(|e| async {
+ let l = TcpListener::bind(e.clone()).await?;
+ loop {
+ let (stream, addr) = l.accept().await.context("accepting connection")?;
+ let state = state.clone();
+ let tls_acceptor = tls_acceptor.clone();
+ tokio::task::spawn(async move {
+ debug!("connection from {addr}");
+ match tls_acceptor.accept(stream).await {
+ Ok(stream) => serve_stream(state, TokioIo(stream), addr).await,
+ Err(e) => warn!("error accepting tls: {e}"),
+ };
+ });
+ }
+ }))
+ .await;
info!("serving https");
listen_futs?;
Ok(())
}
-pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
+pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::rt::Write>(
state: Arc<State>,
stream: T,
addr: SocketAddr,
diff --git a/src/proxy.rs b/src/proxy.rs
index 7070500..10d7e3d 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -1,4 +1,4 @@
-use crate::{ServiceError, State};
+use crate::{helper::TokioIo, ServiceError, State};
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{
body::Incoming,
@@ -51,9 +51,11 @@ pub async fn proxy_request(
if let Some(_limit_guard) = state.l_outgoing.obtain() {
debug!("\tforwarding to {}", backend);
let mut resp = {
- let client_stream = TcpStream::connect(backend)
- .await
- .map_err(|_| ServiceError::CantConnect)?;
+ let client_stream = TokioIo(
+ TcpStream::connect(backend)
+ .await
+ .map_err(|_| ServiceError::CantConnect)?,
+ );
let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
.await
@@ -77,11 +79,11 @@ pub async fn proxy_request(
on_upgrade_upstream.unwrap().await,
on_upgrade_downstream.unwrap().await,
) {
- (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => {
+ (Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
debug!("upgrade successful");
match tokio::io::copy_bidirectional(
- &mut upgraded_downstream,
- &mut upgraded_upstream,
+ &mut TokioIo(upgraded_downstream),
+ &mut TokioIo(upgraded_upstream),
)
.await
{