diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/helper.rs | 139 | ||||
-rw-r--r-- | src/main.rs | 67 | ||||
-rw-r--r-- | src/proxy.rs | 16 |
3 files changed, 179 insertions, 43 deletions
diff --git a/src/helper.rs b/src/helper.rs new file mode 100644 index 0000000..5914be3 --- /dev/null +++ b/src/helper.rs @@ -0,0 +1,139 @@ +// From https://github.com/hyperium/hyper/blob/master/benches/support/tokiort.rs + +use pin_project::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +#[pin_project] +#[derive(Debug)] +pub struct TokioIo<T>(#[pin] pub T); + +impl<T> hyper::rt::Read for TokioIo<T> +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll<Result<(), std::io::Error>> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().0, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl<T> hyper::rt::Write for TokioIo<T> +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + tokio::io::AsyncWrite::poll_write(self.project().0, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + tokio::io::AsyncWrite::poll_flush(self.project().0, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + tokio::io::AsyncWrite::poll_shutdown(self.project().0, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().0, cx, bufs) + } +} + +impl<T> tokio::io::AsyncRead for TokioIo<T> +where + T: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<Result<(), std::io::Error>> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().0, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl<T> tokio::io::AsyncWrite for TokioIo<T> +where + T: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + hyper::rt::Write::poll_write(self.project().0, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { + hyper::rt::Write::poll_flush(self.project().0, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + hyper::rt::Write::poll_shutdown(self.project().0, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll<Result<usize, std::io::Error>> { + hyper::rt::Write::poll_write_vectored(self.project().0, cx, bufs) + } +} diff --git a/src/main.rs b/src/main.rs index aa13609..5230434 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ pub mod config; pub mod error; pub mod files; +pub mod helper; pub mod limiter; pub mod proxy; @@ -15,6 +16,7 @@ use crate::{ use anyhow::{anyhow, bail, Context, Result}; use error::ServiceError; use futures::future::try_join_all; +use helper::TokioIo; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, @@ -27,11 +29,7 @@ use hyper::{ use limiter::Limiter; use log::{debug, error, info, warn}; use std::{fs::File, io::BufReader, net::SocketAddr, path::Path, process::exit, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - net::TcpListener, - signal::ctrl_c, -}; +use tokio::{net::TcpListener, signal::ctrl_c}; use tokio_rustls::TlsAcceptor; pub struct State { @@ -90,18 +88,17 @@ async fn serve_http(state: Arc<State>) -> Result<()> { None => return Ok(()), }; - let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind - .iter() - .map(|e| async { - let l = TcpListener::bind(e.clone()).await?; - loop { - let (stream, addr) = l.accept().await.context("accepting connection")?; - debug!("connection from {addr}"); - let config = state.clone(); - tokio::spawn(async move { serve_stream(config, stream, addr).await }); - } - })) - .await; + let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind.iter().map(|e| async { + let l = TcpListener::bind(e.clone()).await?; + loop { + let (stream, addr) = l.accept().await.context("accepting connection")?; + debug!("connection from {addr}"); + let stream = TokioIo(stream); + let config = state.clone(); + tokio::spawn(async move { serve_stream(config, stream, addr).await }); + } + })) + .await; info!("serving http"); @@ -129,31 +126,29 @@ async fn serve_https(state: Arc<State>) -> Result<()> { }; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); - let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind - .iter() - .map(|e| async { - let l = TcpListener::bind(e.clone()).await?; - loop { - let (stream, addr) = l.accept().await.context("accepting connection")?; - let state = state.clone(); - let tls_acceptor = tls_acceptor.clone(); - tokio::task::spawn(async move { - debug!("connection from {addr}"); - match tls_acceptor.accept(stream).await { - Ok(stream) => serve_stream(state, stream, addr).await, - Err(e) => warn!("error accepting tls: {e}"), - }; - }); - } - })) - .await; + let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind.iter().map(|e| async { + let l = TcpListener::bind(e.clone()).await?; + loop { + let (stream, addr) = l.accept().await.context("accepting connection")?; + let state = state.clone(); + let tls_acceptor = tls_acceptor.clone(); + tokio::task::spawn(async move { + debug!("connection from {addr}"); + match tls_acceptor.accept(stream).await { + Ok(stream) => serve_stream(state, TokioIo(stream), addr).await, + Err(e) => warn!("error accepting tls: {e}"), + }; + }); + } + })) + .await; info!("serving https"); listen_futs?; Ok(()) } -pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>( +pub async fn serve_stream<T: Unpin + Send + 'static + hyper::rt::Read + hyper::rt::Write>( state: Arc<State>, stream: T, addr: SocketAddr, diff --git a/src/proxy.rs b/src/proxy.rs index 7070500..10d7e3d 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,4 +1,4 @@ -use crate::{ServiceError, State}; +use crate::{helper::TokioIo, ServiceError, State}; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, @@ -51,9 +51,11 @@ pub async fn proxy_request( if let Some(_limit_guard) = state.l_outgoing.obtain() { debug!("\tforwarding to {}", backend); let mut resp = { - let client_stream = TcpStream::connect(backend) - .await - .map_err(|_| ServiceError::CantConnect)?; + let client_stream = TokioIo( + TcpStream::connect(backend) + .await + .map_err(|_| ServiceError::CantConnect)?, + ); let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) .await @@ -77,11 +79,11 @@ pub async fn proxy_request( on_upgrade_upstream.unwrap().await, on_upgrade_downstream.unwrap().await, ) { - (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => { + (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { debug!("upgrade successful"); match tokio::io::copy_bidirectional( - &mut upgraded_downstream, - &mut upgraded_upstream, + &mut TokioIo(upgraded_downstream), + &mut TokioIo(upgraded_upstream), ) .await { |