From 6566cbb3f25aa8b1247c259b5e546910b6044f93 Mon Sep 17 00:00:00 2001 From: metamuffin Date: Thu, 7 Dec 2023 14:35:48 +0100 Subject: move some files around and add horrible access log --- src/auth.rs | 41 ------- src/config.rs | 13 +++ src/files.rs | 278 ----------------------------------------------- src/filters/accesslog.rs | 48 ++++++++ src/filters/auth.rs | 41 +++++++ src/filters/files.rs | 278 +++++++++++++++++++++++++++++++++++++++++++++++ src/filters/mod.rs | 5 + src/filters/proxy.rs | 104 ++++++++++++++++++ src/main.rs | 31 ++++-- src/proxy.rs | 104 ------------------ 10 files changed, 511 insertions(+), 432 deletions(-) delete mode 100644 src/auth.rs delete mode 100644 src/files.rs create mode 100644 src/filters/accesslog.rs create mode 100644 src/filters/auth.rs create mode 100644 src/filters/files.rs create mode 100644 src/filters/mod.rs create mode 100644 src/filters/proxy.rs delete mode 100644 src/proxy.rs (limited to 'src') diff --git a/src/auth.rs b/src/auth.rs deleted file mode 100644 index 92a9ba3..0000000 --- a/src/auth.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; -use base64::Engine; -use http_body_util::{combinators::BoxBody, BodyExt}; -use hyper::{ - header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, - Response, StatusCode, -}; -use log::debug; -use std::ops::ControlFlow; - -pub fn http_basic( - config: &HttpBasicAuthConfig, - req: &FilterRequest, - resp: &mut FilterResponseOut, -) -> Result, ServiceError> { - if let Some(auth) = req.headers().get(AUTHORIZATION) { - let k = auth - .as_bytes() - .strip_prefix(b"Basic ") - .ok_or(ServiceError::BadAuth)?; - let k = base64::engine::general_purpose::STANDARD.decode(k)?; - let k = String::from_utf8(k)?; - if config.valid.contains(&k) { - debug!("valid auth"); - return Ok(ControlFlow::Continue(())); - } else { - debug!("invalid auth"); - } - } - debug!("unauthorized; sending auth challenge"); - let mut r = Response::new(BoxBody::<_, ServiceError>::new( - String::new().map_err(|_| unreachable!()), - )); - *r.status_mut() = StatusCode::UNAUTHORIZED; - r.headers_mut().insert( - WWW_AUTHENTICATE, - HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), - ); - *resp = Some(r); - Ok(ControlFlow::Break(())) -} diff --git a/src/config.rs b/src/config.rs index b60ac8c..bc4369d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -70,6 +70,19 @@ pub enum RouteFilter { #[serde(flatten)] config: FileserverConfig, }, + AccessLog { + #[serde(flatten)] + config: AccessLogConfig, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AccessLogConfig { + pub file: PathBuf, + #[serde(default)] + pub flush: bool, + #[serde(default)] + pub reject_on_fail: bool, } #[derive(Debug, Serialize, Deserialize)] diff --git a/src/files.rs b/src/files.rs deleted file mode 100644 index 733d045..0000000 --- a/src/files.rs +++ /dev/null @@ -1,278 +0,0 @@ -use crate::{config::FileserverConfig, ServiceError}; -use bytes::{Bytes, BytesMut}; -use futures_util::{future, future::Either, ready, stream, FutureExt, Stream, StreamExt}; -use headers::{AcceptRanges, ContentLength, ContentRange, ContentType, HeaderMapExt}; -use http_body_util::{combinators::BoxBody, BodyExt, StreamBody}; -use humansize::FormatSizeOptions; -use hyper::{ - body::{Frame, Incoming}, - header::{CONTENT_TYPE, LOCATION}, - http::HeaderValue, - Request, Response, StatusCode, -}; -use log::debug; -use markup::Render; -use percent_encoding::percent_decode_str; -use std::{fs::Metadata, io, ops::Range, path::Path, pin::Pin, task::Poll}; -use tokio::{ - fs::{read_to_string, File}, - io::AsyncSeekExt, -}; -use tokio_util::io::poll_read_buf; - -pub async fn serve_files( - req: &Request, - config: &FileserverConfig, -) -> Result>, ServiceError> { - let rpath = req.uri().path(); - - let mut path = config.root.clone(); - let mut user_path_depth = 0; - for seg in rpath.split("/") { - let seg = percent_decode_str(seg).decode_utf8()?; - - if seg == "" || seg == "." { - continue; - } - - if seg == ".." { - if user_path_depth <= 0 { - return Err(ServiceError::BadPath); - } - path.pop(); - user_path_depth -= 1; - } else { - path.push(seg.as_ref()); - user_path_depth += 1; - } - } - if !path.exists() { - return Err(ServiceError::NotFound); - } - - let metadata = path.metadata()?; - - if metadata.file_type().is_dir() { - debug!("sending index for {path:?}"); - if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { - return Ok(html_string_response(indexhtml)); - } - - if config.index { - if !rpath.ends_with("/") { - let mut r = Response::new(String::new()); - *r.status_mut() = StatusCode::FOUND; - r.headers_mut().insert( - LOCATION, - HeaderValue::from_str(&format!("{}/", rpath)) - .map_err(|_| ServiceError::Other)?, - ); - return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); - } - - return index(&path, rpath.to_string()) - .await - .map(html_string_response); - } else { - return Err(ServiceError::NotFound); - } - } - - let range = req.headers().typed_get::(); - let range = bytes_range(range, metadata.len())?; - - debug!("sending file {path:?}"); - let file = File::open(path.clone()).await?; - - let mut r = Response::new(BoxBody::new(StreamBody::new( - StreamBody::new(file_stream(file, 4096, range.clone())) - .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), - ))); - - if range.end - range.start != metadata.len() { - *r.status_mut() = StatusCode::PARTIAL_CONTENT; - r.headers_mut().typed_insert( - ContentRange::bytes(range.clone(), metadata.len()).expect("valid ContentRange"), - ); - } - - let mime = mime_guess::from_path(path).first_or_octet_stream(); - - r.headers_mut() - .typed_insert(ContentLength(range.end - range.start)); - r.headers_mut().typed_insert(ContentType::from(mime)); - r.headers_mut().typed_insert(AcceptRanges::bytes()); - - Ok(r) -} - -// Adapted from warp (https://github.com/seanmonstar/warp/blob/master/src/filters/fs.rs). Thanks! -fn file_stream( - mut file: File, - buf_size: usize, - range: Range, -) -> impl Stream> + Send { - use std::io::SeekFrom; - - let seek = async move { - if range.start != 0 { - file.seek(SeekFrom::Start(range.start)).await?; - } - Ok(file) - }; - - seek.into_stream() - .map(move |result| { - let mut buf = BytesMut::new(); - let mut len = range.end - range.start; - let mut f = match result { - Ok(f) => f, - Err(f) => return Either::Left(stream::once(future::err(f))), - }; - - Either::Right(stream::poll_fn(move |cx| { - if len == 0 { - return Poll::Ready(None); - } - reserve_at_least(&mut buf, buf_size); - - let n = match ready!(poll_read_buf(Pin::new(&mut f), cx, &mut buf)) { - Ok(n) => n as u64, - Err(err) => { - debug!("file read error: {}", err); - return Poll::Ready(Some(Err(err))); - } - }; - - if n == 0 { - debug!("file read found EOF before expected length"); - return Poll::Ready(None); - } - - let mut chunk = buf.split().freeze(); - if n > len { - chunk = chunk.split_to(len as usize); - len = 0; - } else { - len -= n; - } - - Poll::Ready(Some(Ok(chunk))) - })) - }) - .flatten() -} - -// Also adapted from warp -fn bytes_range(range: Option, max_len: u64) -> Result, ServiceError> { - use std::ops::Bound; - - let range = if let Some(range) = range { - range - } else { - return Ok(0..max_len); - }; - - let ret = range - .iter() - .map(|(start, end)| { - let start = match start { - Bound::Unbounded => 0, - Bound::Included(s) => s, - Bound::Excluded(s) => s + 1, - }; - - let end = match end { - Bound::Unbounded => max_len, - Bound::Included(s) => { - // For the special case where s == the file size - if s == max_len { - s - } else { - s + 1 - } - } - Bound::Excluded(s) => s, - }; - - if start < end && end <= max_len { - Ok(start..end) - } else { - Err(ServiceError::BadRange) - } - }) - .next() - .unwrap_or(Ok(0..max_len)); - ret -} - -fn reserve_at_least(buf: &mut BytesMut, cap: usize) { - if buf.capacity() - buf.len() < cap { - buf.reserve(cap); - } -} - -async fn index(path: &Path, rpath: String) -> Result { - let files = path - .read_dir()? - .map(|e| e.and_then(|e| Ok((e.file_name().into_string().unwrap(), e.metadata()?)))) - .filter(|e| e.as_ref().map(|(e, _)| !e.starts_with(".")).unwrap_or(true)) - .collect::, _>>()?; - let banner = read_to_string(path.join("index.banner.html")).await.ok(); - let mut s = String::new(); - IndexTemplate { - files, - banner, - path: rpath, - } - .render(&mut s) - .unwrap(); - Ok(s) -} - -fn html_string_response(s: String) -> hyper::Response> { - let mut r = Response::new(s); - r.headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/html")); - r.map(|b| b.map_err(|e| match e {}).boxed()) -} - -markup::define! { - IndexTemplate(path: String, banner: Option, files: Vec<(String, Metadata)>) { - @markup::doctype() - html { - head { - meta[charset="UTF-8"]; - title { "Index of " @path } - } - body { - @if let Some(banner) = banner { - @markup::raw(banner) - } else { - h1 { "Index of " @path } - } - hr; - table { - @if path != "/" { - tr { td { b { a[href=".."] { "../" } } } } - } - @for (name, meta) in files { tr { - td { a[href=name] { - @name - @if meta.file_type().is_dir() { "/" } - } } - td { - @if meta.file_type().is_dir() { - i { "directory" } - } else { - @humansize::format_size(meta.len(), FormatSizeOptions::default()) - } - } - } } - } - hr; - footer { sub { "served by " a[href="https://codeberg.org/metamuffin/gnix"] { "gnix" } } } - } - } - } -} diff --git a/src/filters/accesslog.rs b/src/filters/accesslog.rs new file mode 100644 index 0000000..ff5a8d5 --- /dev/null +++ b/src/filters/accesslog.rs @@ -0,0 +1,48 @@ +use crate::{config::AccessLogConfig, error::ServiceError, FilterRequest, State}; +use futures::executor::block_on; +use log::error; +use std::{net::SocketAddr, ops::ControlFlow}; +use tokio::{ + fs::OpenOptions, + io::{AsyncWriteExt, BufWriter}, +}; + +pub async fn access_log( + state: &State, + host: &str, + addr: SocketAddr, + config: &AccessLogConfig, + req: &FilterRequest, +) -> Result, ServiceError> { + let mut g = state.access_logs.write().await; + + let log = g.entry(host.to_owned()).or_insert_with(|| { + BufWriter::new( + // TODO aaahh dont block the runtime and dont panic in any case.... + block_on( + OpenOptions::new() + .append(true) + .create(true) + .open(&config.file), + ) + .unwrap(), + ) + }); + + let method = req.method().as_str(); + let mut res = log + .write_all(format!("{addr}\t{method}\t{:?}\n", req.uri()).as_bytes()) + .await; + + if config.flush && res.is_ok() { + res = log.flush().await; + } + + if config.reject_on_fail { + res? + } else if let Err(e) = res { + error!("failed to write log: {e:?}") + } + + Ok(ControlFlow::Continue(())) +} diff --git a/src/filters/auth.rs b/src/filters/auth.rs new file mode 100644 index 0000000..92a9ba3 --- /dev/null +++ b/src/filters/auth.rs @@ -0,0 +1,41 @@ +use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut}; +use base64::Engine; +use http_body_util::{combinators::BoxBody, BodyExt}; +use hyper::{ + header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE}, + Response, StatusCode, +}; +use log::debug; +use std::ops::ControlFlow; + +pub fn http_basic( + config: &HttpBasicAuthConfig, + req: &FilterRequest, + resp: &mut FilterResponseOut, +) -> Result, ServiceError> { + if let Some(auth) = req.headers().get(AUTHORIZATION) { + let k = auth + .as_bytes() + .strip_prefix(b"Basic ") + .ok_or(ServiceError::BadAuth)?; + let k = base64::engine::general_purpose::STANDARD.decode(k)?; + let k = String::from_utf8(k)?; + if config.valid.contains(&k) { + debug!("valid auth"); + return Ok(ControlFlow::Continue(())); + } else { + debug!("invalid auth"); + } + } + debug!("unauthorized; sending auth challenge"); + let mut r = Response::new(BoxBody::<_, ServiceError>::new( + String::new().map_err(|_| unreachable!()), + )); + *r.status_mut() = StatusCode::UNAUTHORIZED; + r.headers_mut().insert( + WWW_AUTHENTICATE, + HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(), + ); + *resp = Some(r); + Ok(ControlFlow::Break(())) +} diff --git a/src/filters/files.rs b/src/filters/files.rs new file mode 100644 index 0000000..733d045 --- /dev/null +++ b/src/filters/files.rs @@ -0,0 +1,278 @@ +use crate::{config::FileserverConfig, ServiceError}; +use bytes::{Bytes, BytesMut}; +use futures_util::{future, future::Either, ready, stream, FutureExt, Stream, StreamExt}; +use headers::{AcceptRanges, ContentLength, ContentRange, ContentType, HeaderMapExt}; +use http_body_util::{combinators::BoxBody, BodyExt, StreamBody}; +use humansize::FormatSizeOptions; +use hyper::{ + body::{Frame, Incoming}, + header::{CONTENT_TYPE, LOCATION}, + http::HeaderValue, + Request, Response, StatusCode, +}; +use log::debug; +use markup::Render; +use percent_encoding::percent_decode_str; +use std::{fs::Metadata, io, ops::Range, path::Path, pin::Pin, task::Poll}; +use tokio::{ + fs::{read_to_string, File}, + io::AsyncSeekExt, +}; +use tokio_util::io::poll_read_buf; + +pub async fn serve_files( + req: &Request, + config: &FileserverConfig, +) -> Result>, ServiceError> { + let rpath = req.uri().path(); + + let mut path = config.root.clone(); + let mut user_path_depth = 0; + for seg in rpath.split("/") { + let seg = percent_decode_str(seg).decode_utf8()?; + + if seg == "" || seg == "." { + continue; + } + + if seg == ".." { + if user_path_depth <= 0 { + return Err(ServiceError::BadPath); + } + path.pop(); + user_path_depth -= 1; + } else { + path.push(seg.as_ref()); + user_path_depth += 1; + } + } + if !path.exists() { + return Err(ServiceError::NotFound); + } + + let metadata = path.metadata()?; + + if metadata.file_type().is_dir() { + debug!("sending index for {path:?}"); + if let Ok(indexhtml) = read_to_string(path.join("index.html")).await { + return Ok(html_string_response(indexhtml)); + } + + if config.index { + if !rpath.ends_with("/") { + let mut r = Response::new(String::new()); + *r.status_mut() = StatusCode::FOUND; + r.headers_mut().insert( + LOCATION, + HeaderValue::from_str(&format!("{}/", rpath)) + .map_err(|_| ServiceError::Other)?, + ); + return Ok(r.map(|b| b.map_err(|e| match e {}).boxed())); + } + + return index(&path, rpath.to_string()) + .await + .map(html_string_response); + } else { + return Err(ServiceError::NotFound); + } + } + + let range = req.headers().typed_get::(); + let range = bytes_range(range, metadata.len())?; + + debug!("sending file {path:?}"); + let file = File::open(path.clone()).await?; + + let mut r = Response::new(BoxBody::new(StreamBody::new( + StreamBody::new(file_stream(file, 4096, range.clone())) + .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)), + ))); + + if range.end - range.start != metadata.len() { + *r.status_mut() = StatusCode::PARTIAL_CONTENT; + r.headers_mut().typed_insert( + ContentRange::bytes(range.clone(), metadata.len()).expect("valid ContentRange"), + ); + } + + let mime = mime_guess::from_path(path).first_or_octet_stream(); + + r.headers_mut() + .typed_insert(ContentLength(range.end - range.start)); + r.headers_mut().typed_insert(ContentType::from(mime)); + r.headers_mut().typed_insert(AcceptRanges::bytes()); + + Ok(r) +} + +// Adapted from warp (https://github.com/seanmonstar/warp/blob/master/src/filters/fs.rs). Thanks! +fn file_stream( + mut file: File, + buf_size: usize, + range: Range, +) -> impl Stream> + Send { + use std::io::SeekFrom; + + let seek = async move { + if range.start != 0 { + file.seek(SeekFrom::Start(range.start)).await?; + } + Ok(file) + }; + + seek.into_stream() + .map(move |result| { + let mut buf = BytesMut::new(); + let mut len = range.end - range.start; + let mut f = match result { + Ok(f) => f, + Err(f) => return Either::Left(stream::once(future::err(f))), + }; + + Either::Right(stream::poll_fn(move |cx| { + if len == 0 { + return Poll::Ready(None); + } + reserve_at_least(&mut buf, buf_size); + + let n = match ready!(poll_read_buf(Pin::new(&mut f), cx, &mut buf)) { + Ok(n) => n as u64, + Err(err) => { + debug!("file read error: {}", err); + return Poll::Ready(Some(Err(err))); + } + }; + + if n == 0 { + debug!("file read found EOF before expected length"); + return Poll::Ready(None); + } + + let mut chunk = buf.split().freeze(); + if n > len { + chunk = chunk.split_to(len as usize); + len = 0; + } else { + len -= n; + } + + Poll::Ready(Some(Ok(chunk))) + })) + }) + .flatten() +} + +// Also adapted from warp +fn bytes_range(range: Option, max_len: u64) -> Result, ServiceError> { + use std::ops::Bound; + + let range = if let Some(range) = range { + range + } else { + return Ok(0..max_len); + }; + + let ret = range + .iter() + .map(|(start, end)| { + let start = match start { + Bound::Unbounded => 0, + Bound::Included(s) => s, + Bound::Excluded(s) => s + 1, + }; + + let end = match end { + Bound::Unbounded => max_len, + Bound::Included(s) => { + // For the special case where s == the file size + if s == max_len { + s + } else { + s + 1 + } + } + Bound::Excluded(s) => s, + }; + + if start < end && end <= max_len { + Ok(start..end) + } else { + Err(ServiceError::BadRange) + } + }) + .next() + .unwrap_or(Ok(0..max_len)); + ret +} + +fn reserve_at_least(buf: &mut BytesMut, cap: usize) { + if buf.capacity() - buf.len() < cap { + buf.reserve(cap); + } +} + +async fn index(path: &Path, rpath: String) -> Result { + let files = path + .read_dir()? + .map(|e| e.and_then(|e| Ok((e.file_name().into_string().unwrap(), e.metadata()?)))) + .filter(|e| e.as_ref().map(|(e, _)| !e.starts_with(".")).unwrap_or(true)) + .collect::, _>>()?; + let banner = read_to_string(path.join("index.banner.html")).await.ok(); + let mut s = String::new(); + IndexTemplate { + files, + banner, + path: rpath, + } + .render(&mut s) + .unwrap(); + Ok(s) +} + +fn html_string_response(s: String) -> hyper::Response> { + let mut r = Response::new(s); + r.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/html")); + r.map(|b| b.map_err(|e| match e {}).boxed()) +} + +markup::define! { + IndexTemplate(path: String, banner: Option, files: Vec<(String, Metadata)>) { + @markup::doctype() + html { + head { + meta[charset="UTF-8"]; + title { "Index of " @path } + } + body { + @if let Some(banner) = banner { + @markup::raw(banner) + } else { + h1 { "Index of " @path } + } + hr; + table { + @if path != "/" { + tr { td { b { a[href=".."] { "../" } } } } + } + @for (name, meta) in files { tr { + td { a[href=name] { + @name + @if meta.file_type().is_dir() { "/" } + } } + td { + @if meta.file_type().is_dir() { + i { "directory" } + } else { + @humansize::format_size(meta.len(), FormatSizeOptions::default()) + } + } + } } + } + hr; + footer { sub { "served by " a[href="https://codeberg.org/metamuffin/gnix"] { "gnix" } } } + } + } + } +} diff --git a/src/filters/mod.rs b/src/filters/mod.rs new file mode 100644 index 0000000..fdeed51 --- /dev/null +++ b/src/filters/mod.rs @@ -0,0 +1,5 @@ + +pub mod auth; +pub mod files; +pub mod proxy; +pub mod accesslog; \ No newline at end of file diff --git a/src/filters/proxy.rs b/src/filters/proxy.rs new file mode 100644 index 0000000..40ebf17 --- /dev/null +++ b/src/filters/proxy.rs @@ -0,0 +1,104 @@ +use crate::{helper::TokioIo, ServiceError, State}; +use http_body_util::{combinators::BoxBody, BodyExt}; +use hyper::{ + body::Incoming, + header::UPGRADE, + http::{ + uri::{PathAndQuery, Scheme}, + HeaderValue, + }, + upgrade::OnUpgrade, + Request, Uri, +}; +use log::{debug, error, warn}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::net::TcpStream; + +pub async fn proxy_request( + state: &Arc, + mut req: Request, + addr: SocketAddr, + backend: &SocketAddr, +) -> Result>, ServiceError> { + #[cfg(feature = "mond")] + state.reporting.request_out.inc(); + + let scheme_secure = req.uri().scheme() == Some(&Scheme::HTTPS); + *req.uri_mut() = Uri::builder() + .path_and_query( + req.uri() + .clone() + .path_and_query() + .cloned() + .unwrap_or(PathAndQuery::from_static("/")), + ) + .build() + .unwrap(); + + req.headers_mut().insert( + "x-forwarded-for", + HeaderValue::from_str(&format!("{addr}")).unwrap(), + ); + req.headers_mut().insert( + "x-forwarded-proto", + if scheme_secure { + HeaderValue::from_static("https") + } else { + HeaderValue::from_static("http") + }, + ); + + let do_upgrade = req.headers().contains_key(UPGRADE); + let on_upgrade_downstream = req.extensions_mut().remove::(); + + let _limit_guard = state.l_outgoing.try_acquire()?; + debug!("\tforwarding to {}", backend); + let mut resp = { + let client_stream = TokioIo( + TcpStream::connect(backend) + .await + .map_err(|_| ServiceError::CantConnect)?, + ); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) + .await + .map_err(ServiceError::Hyper)?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + warn!("connection failed: {:?}", err); + } + }); + sender + .send_request(req) + .await + .map_err(ServiceError::Hyper)? + }; + + if do_upgrade { + let on_upgrade_upstream = resp.extensions_mut().remove::(); + tokio::task::spawn(async move { + debug!("about to upgrade connection, sending empty response"); + match ( + on_upgrade_upstream.unwrap().await, + on_upgrade_downstream.unwrap().await, + ) { + (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { + debug!("upgrade successful"); + match tokio::io::copy_bidirectional( + &mut TokioIo(upgraded_downstream), + &mut TokioIo(upgraded_upstream), + ) + .await + { + Ok((from_client, from_server)) => { + debug!("proxy socket terminated: {from_server} sent, {from_client} received") + } + Err(e) => warn!("proxy socket error: {e}"), + } + } + (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"), + } + }); + } + Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) +} diff --git a/src/main.rs b/src/main.rs index 807d0ab..07f3d5c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,19 +2,16 @@ #![feature(exclusive_range_pattern)] #![feature(slice_split_once)] -pub mod auth; pub mod config; pub mod error; -pub mod files; +pub mod filters; pub mod helper; -pub mod proxy; #[cfg(feature = "mond")] pub mod reporting; use crate::{ config::{Config, RouteFilter}, - files::serve_files, - proxy::proxy_request, + filters::{files::serve_files, proxy::proxy_request}, }; use anyhow::{bail, Context, Result}; use bytes::Bytes; @@ -35,7 +32,7 @@ use log::{debug, error, info, warn}; #[cfg(feature = "mond")] use reporting::Reporting; use std::{ - fs::File, + collections::HashMap, io::BufReader, net::SocketAddr, ops::ControlFlow, @@ -45,6 +42,8 @@ use std::{ sync::Arc, }; use tokio::{ + fs::File, + io::BufWriter, net::TcpListener, signal::ctrl_c, sync::{RwLock, Semaphore}, @@ -53,11 +52,13 @@ use tokio_rustls::TlsAcceptor; pub struct State { pub config: RwLock>, + pub access_logs: RwLock>>, pub l_incoming: Semaphore, pub l_outgoing: Semaphore, #[cfg(feature = "mond")] pub reporting: Reporting, } +pub struct HostState {} pub type FilterRequest = Request; pub type FilterResponseOut = Option>>; @@ -89,6 +90,7 @@ async fn main() -> anyhow::Result<()> { #[cfg(feature = "mond")] reporting: Reporting::new(&config), config: RwLock::new(Arc::new(config)), + access_logs: Default::default(), }); if state.config.read().await.watch_config { @@ -228,12 +230,12 @@ pub async fn serve_stream anyhow::Result> { - let mut reader = BufReader::new(File::open(path).context("reading tls certs")?); + let mut reader = BufReader::new(std::fs::File::open(path).context("reading tls certs")?); let certs = rustls_pemfile::certs(&mut reader).context("parsing tls certs")?; Ok(certs.into_iter().map(rustls::Certificate).collect()) } fn load_private_key(path: &Path) -> anyhow::Result { - let mut reader = BufReader::new(File::open(path).context("reading tls private key")?); + let mut reader = BufReader::new(std::fs::File::open(path).context("reading tls private key")?); let keys = rustls_pemfile::pkcs8_private_keys(&mut reader).context("parsing tls private key")?; if keys.len() != 1 { @@ -263,6 +265,7 @@ async fn service( #[cfg(feature = "mond")] state.reporting.hosts.get(host).unwrap().requests_in.inc(); + // TODO this code is horrible let mut req = Some(req); let mut resp = None; for filter in &route.0 { @@ -285,11 +288,21 @@ async fn service( ); ControlFlow::Continue(()) } - RouteFilter::HttpBasicAuth { config } => auth::http_basic( + RouteFilter::HttpBasicAuth { config } => filters::auth::http_basic( config, req.as_ref().ok_or(ServiceError::RequestTaken)?, &mut resp, )?, + RouteFilter::AccessLog { config } => { + filters::accesslog::access_log( + &state, + host, + addr, + config, + req.as_ref().ok_or(ServiceError::RequestTaken)?, + ) + .await? + } }; match cf { ControlFlow::Continue(_) => continue, diff --git a/src/proxy.rs b/src/proxy.rs deleted file mode 100644 index 40ebf17..0000000 --- a/src/proxy.rs +++ /dev/null @@ -1,104 +0,0 @@ -use crate::{helper::TokioIo, ServiceError, State}; -use http_body_util::{combinators::BoxBody, BodyExt}; -use hyper::{ - body::Incoming, - header::UPGRADE, - http::{ - uri::{PathAndQuery, Scheme}, - HeaderValue, - }, - upgrade::OnUpgrade, - Request, Uri, -}; -use log::{debug, error, warn}; -use std::{net::SocketAddr, sync::Arc}; -use tokio::net::TcpStream; - -pub async fn proxy_request( - state: &Arc, - mut req: Request, - addr: SocketAddr, - backend: &SocketAddr, -) -> Result>, ServiceError> { - #[cfg(feature = "mond")] - state.reporting.request_out.inc(); - - let scheme_secure = req.uri().scheme() == Some(&Scheme::HTTPS); - *req.uri_mut() = Uri::builder() - .path_and_query( - req.uri() - .clone() - .path_and_query() - .cloned() - .unwrap_or(PathAndQuery::from_static("/")), - ) - .build() - .unwrap(); - - req.headers_mut().insert( - "x-forwarded-for", - HeaderValue::from_str(&format!("{addr}")).unwrap(), - ); - req.headers_mut().insert( - "x-forwarded-proto", - if scheme_secure { - HeaderValue::from_static("https") - } else { - HeaderValue::from_static("http") - }, - ); - - let do_upgrade = req.headers().contains_key(UPGRADE); - let on_upgrade_downstream = req.extensions_mut().remove::(); - - let _limit_guard = state.l_outgoing.try_acquire()?; - debug!("\tforwarding to {}", backend); - let mut resp = { - let client_stream = TokioIo( - TcpStream::connect(backend) - .await - .map_err(|_| ServiceError::CantConnect)?, - ); - - let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream) - .await - .map_err(ServiceError::Hyper)?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - warn!("connection failed: {:?}", err); - } - }); - sender - .send_request(req) - .await - .map_err(ServiceError::Hyper)? - }; - - if do_upgrade { - let on_upgrade_upstream = resp.extensions_mut().remove::(); - tokio::task::spawn(async move { - debug!("about to upgrade connection, sending empty response"); - match ( - on_upgrade_upstream.unwrap().await, - on_upgrade_downstream.unwrap().await, - ) { - (Ok(upgraded_upstream), Ok(upgraded_downstream)) => { - debug!("upgrade successful"); - match tokio::io::copy_bidirectional( - &mut TokioIo(upgraded_downstream), - &mut TokioIo(upgraded_upstream), - ) - .await - { - Ok((from_client, from_server)) => { - debug!("proxy socket terminated: {from_server} sent, {from_client} received") - } - Err(e) => warn!("proxy socket error: {e}"), - } - } - (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"), - } - }); - } - Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed())) -} -- cgit v1.2.3-70-g09d2