Diffstat (limited to 'src/filters')
-rw-r--r--   src/filters/accesslog.rs |  48
-rw-r--r--   src/filters/auth.rs      |  41
-rw-r--r--   src/filters/files.rs     | 278
-rw-r--r--   src/filters/mod.rs       |   5
-rw-r--r--   src/filters/proxy.rs     | 104
5 files changed, 476 insertions, 0 deletions
diff --git a/src/filters/accesslog.rs b/src/filters/accesslog.rs
new file mode 100644
index 0000000..ff5a8d5
--- /dev/null
+++ b/src/filters/accesslog.rs
@@ -0,0 +1,48 @@
+use crate::{config::AccessLogConfig, error::ServiceError, FilterRequest, State};
+use futures::executor::block_on;
+use log::error;
+use std::{net::SocketAddr, ops::ControlFlow};
+use tokio::{
+    fs::OpenOptions,
+    io::{AsyncWriteExt, BufWriter},
+};
+
+pub async fn access_log(
+    state: &State,
+    host: &str,
+    addr: SocketAddr,
+    config: &AccessLogConfig,
+    req: &FilterRequest,
+) -> Result<ControlFlow<()>, ServiceError> {
+    let mut g = state.access_logs.write().await;
+
+    let log = g.entry(host.to_owned()).or_insert_with(|| {
+        BufWriter::new(
+            // TODO aaahh dont block the runtime and dont panic in any case....
+            block_on(
+                OpenOptions::new()
+                    .append(true)
+                    .create(true)
+                    .open(&config.file),
+            )
+            .unwrap(),
+        )
+    });
+
+    let method = req.method().as_str();
+    let mut res = log
+        .write_all(format!("{addr}\t{method}\t{:?}\n", req.uri()).as_bytes())
+        .await;
+
+    if config.flush && res.is_ok() {
+        res = log.flush().await;
+    }
+
+    if config.reject_on_fail {
+        res?
+    } else if let Err(e) = res {
+        error!("failed to write log: {e:?}")
+    }
+
+    Ok(ControlFlow::Continue(()))
+}
diff --git a/src/filters/auth.rs b/src/filters/auth.rs
new file mode 100644
index 0000000..92a9ba3
--- /dev/null
+++ b/src/filters/auth.rs
@@ -0,0 +1,41 @@
+use crate::{config::HttpBasicAuthConfig, error::ServiceError, FilterRequest, FilterResponseOut};
+use base64::Engine;
+use http_body_util::{combinators::BoxBody, BodyExt};
+use hyper::{
+    header::{HeaderValue, AUTHORIZATION, WWW_AUTHENTICATE},
+    Response, StatusCode,
+};
+use log::debug;
+use std::ops::ControlFlow;
+
+pub fn http_basic(
+    config: &HttpBasicAuthConfig,
+    req: &FilterRequest,
+    resp: &mut FilterResponseOut,
+) -> Result<ControlFlow<()>, ServiceError> {
+    if let Some(auth) = req.headers().get(AUTHORIZATION) {
+        let k = auth
+            .as_bytes()
+            .strip_prefix(b"Basic ")
+            .ok_or(ServiceError::BadAuth)?;
+        let k = base64::engine::general_purpose::STANDARD.decode(k)?;
+        let k = String::from_utf8(k)?;
+        if config.valid.contains(&k) {
+            debug!("valid auth");
+            return Ok(ControlFlow::Continue(()));
+        } else {
+            debug!("invalid auth");
+        }
+    }
+    debug!("unauthorized; sending auth challenge");
+    let mut r = Response::new(BoxBody::<_, ServiceError>::new(
+        String::new().map_err(|_| unreachable!()),
+    ));
+    *r.status_mut() = StatusCode::UNAUTHORIZED;
+    r.headers_mut().insert(
+        WWW_AUTHENTICATE,
+        HeaderValue::from_str(&format!("Basic realm=\"{}\"", config.realm)).unwrap(),
+    );
+    *resp = Some(r);
+    Ok(ControlFlow::Break(()))
+}
diff --git a/src/filters/files.rs b/src/filters/files.rs
new file mode 100644
index 0000000..733d045
--- /dev/null
+++ b/src/filters/files.rs
@@ -0,0 +1,278 @@
+use crate::{config::FileserverConfig, ServiceError};
+use bytes::{Bytes, BytesMut};
+use futures_util::{future, future::Either, ready, stream, FutureExt, Stream, StreamExt};
+use headers::{AcceptRanges, ContentLength, ContentRange, ContentType, HeaderMapExt};
+use http_body_util::{combinators::BoxBody, BodyExt, StreamBody};
+use humansize::FormatSizeOptions;
+use hyper::{
+    body::{Frame, Incoming},
+    header::{CONTENT_TYPE, LOCATION},
+    http::HeaderValue,
+    Request, Response, StatusCode,
+};
+use log::debug;
+use markup::Render;
+use percent_encoding::percent_decode_str;
+use std::{fs::Metadata, io, ops::Range, path::Path, pin::Pin, task::Poll};
+use tokio::{
+    fs::{read_to_string, File},
+    io::AsyncSeekExt,
+};
+use tokio_util::io::poll_read_buf;
+
+pub async fn serve_files(
+    req: &Request<Incoming>,
+    config: &FileserverConfig,
+) -> Result<hyper::Response<BoxBody<Bytes, ServiceError>>, ServiceError> {
+    let rpath = req.uri().path();
+
+    let mut path = config.root.clone();
+    let mut user_path_depth = 0;
+    for seg in rpath.split("/") {
+        let seg = percent_decode_str(seg).decode_utf8()?;
+
+        if seg == "" || seg == "." {
+            continue;
+        }
+
+        if seg == ".." {
+            if user_path_depth <= 0 {
+                return Err(ServiceError::BadPath);
+            }
+            path.pop();
+            user_path_depth -= 1;
+        } else {
+            path.push(seg.as_ref());
+            user_path_depth += 1;
+        }
+    }
+    if !path.exists() {
+        return Err(ServiceError::NotFound);
+    }
+
+    let metadata = path.metadata()?;
+
+    if metadata.file_type().is_dir() {
+        debug!("sending index for {path:?}");
+        if let Ok(indexhtml) = read_to_string(path.join("index.html")).await {
+            return Ok(html_string_response(indexhtml));
+        }
+
+        if config.index {
+            if !rpath.ends_with("/") {
+                let mut r = Response::new(String::new());
+                *r.status_mut() = StatusCode::FOUND;
+                r.headers_mut().insert(
+                    LOCATION,
+                    HeaderValue::from_str(&format!("{}/", rpath))
+                        .map_err(|_| ServiceError::Other)?,
+                );
+                return Ok(r.map(|b| b.map_err(|e| match e {}).boxed()));
+            }
+
+            return index(&path, rpath.to_string())
+                .await
+                .map(html_string_response);
+        } else {
+            return Err(ServiceError::NotFound);
+        }
+    }
+
+    let range = req.headers().typed_get::<headers::Range>();
+    let range = bytes_range(range, metadata.len())?;
+
+    debug!("sending file {path:?}");
+    let file = File::open(path.clone()).await?;
+
+    let mut r = Response::new(BoxBody::new(StreamBody::new(
+        StreamBody::new(file_stream(file, 4096, range.clone()))
+            .map(|e| e.map(|e| Frame::data(e)).map_err(ServiceError::Io)),
+    )));
+
+    if range.end - range.start != metadata.len() {
+        *r.status_mut() = StatusCode::PARTIAL_CONTENT;
+        r.headers_mut().typed_insert(
+            ContentRange::bytes(range.clone(), metadata.len()).expect("valid ContentRange"),
+        );
+    }
+
+    let mime = mime_guess::from_path(path).first_or_octet_stream();
+
+    r.headers_mut()
+        .typed_insert(ContentLength(range.end - range.start));
+    r.headers_mut().typed_insert(ContentType::from(mime));
+    r.headers_mut().typed_insert(AcceptRanges::bytes());
+
+    Ok(r)
+}
+
+// Adapted from warp (https://github.com/seanmonstar/warp/blob/master/src/filters/fs.rs). Thanks!
+fn file_stream(
+    mut file: File,
+    buf_size: usize,
+    range: Range<u64>,
+) -> impl Stream<Item = Result<Bytes, io::Error>> + Send {
+    use std::io::SeekFrom;
+
+    let seek = async move {
+        if range.start != 0 {
+            file.seek(SeekFrom::Start(range.start)).await?;
+        }
+        Ok(file)
+    };
+
+    seek.into_stream()
+        .map(move |result| {
+            let mut buf = BytesMut::new();
+            let mut len = range.end - range.start;
+            let mut f = match result {
+                Ok(f) => f,
+                Err(f) => return Either::Left(stream::once(future::err(f))),
+            };
+
+            Either::Right(stream::poll_fn(move |cx| {
+                if len == 0 {
+                    return Poll::Ready(None);
+                }
+                reserve_at_least(&mut buf, buf_size);
+
+                let n = match ready!(poll_read_buf(Pin::new(&mut f), cx, &mut buf)) {
+                    Ok(n) => n as u64,
+                    Err(err) => {
+                        debug!("file read error: {}", err);
+                        return Poll::Ready(Some(Err(err)));
+                    }
+                };
+
+                if n == 0 {
+                    debug!("file read found EOF before expected length");
+                    return Poll::Ready(None);
+                }
+
+                let mut chunk = buf.split().freeze();
+                if n > len {
+                    chunk = chunk.split_to(len as usize);
+                    len = 0;
+                } else {
+                    len -= n;
+                }
+
+                Poll::Ready(Some(Ok(chunk)))
+            }))
+        })
+        .flatten()
+}
+
+// Also adapted from warp
+fn bytes_range(range: Option<headers::Range>, max_len: u64) -> Result<Range<u64>, ServiceError> {
+    use std::ops::Bound;
+
+    let range = if let Some(range) = range {
+        range
+    } else {
+        return Ok(0..max_len);
+    };
+
+    let ret = range
+        .iter()
+        .map(|(start, end)| {
+            let start = match start {
+                Bound::Unbounded => 0,
+                Bound::Included(s) => s,
+                Bound::Excluded(s) => s + 1,
+            };
+
+            let end = match end {
+                Bound::Unbounded => max_len,
+                Bound::Included(s) => {
+                    // For the special case where s == the file size
+                    if s == max_len {
+                        s
+                    } else {
+                        s + 1
+                    }
+                }
+                Bound::Excluded(s) => s,
+            };
+
+            if start < end && end <= max_len {
+                Ok(start..end)
+            } else {
+                Err(ServiceError::BadRange)
+            }
+        })
+        .next()
+        .unwrap_or(Ok(0..max_len));
+    ret
+}
+
+fn reserve_at_least(buf: &mut BytesMut, cap: usize) {
+    if buf.capacity() - buf.len() < cap {
+        buf.reserve(cap);
+    }
+}
+
+async fn index(path: &Path, rpath: String) -> Result<String, ServiceError> {
+    let files = path
+        .read_dir()?
+        .map(|e| e.and_then(|e| Ok((e.file_name().into_string().unwrap(), e.metadata()?))))
+        .filter(|e| e.as_ref().map(|(e, _)| !e.starts_with(".")).unwrap_or(true))
+        .collect::<Result<Vec<_>, _>>()?;
+    let banner = read_to_string(path.join("index.banner.html")).await.ok();
+    let mut s = String::new();
+    IndexTemplate {
+        files,
+        banner,
+        path: rpath,
+    }
+    .render(&mut s)
+    .unwrap();
+    Ok(s)
+}
+
+fn html_string_response(s: String) -> hyper::Response<BoxBody<Bytes, ServiceError>> {
+    let mut r = Response::new(s);
+    r.headers_mut()
+        .insert(CONTENT_TYPE, HeaderValue::from_static("text/html"));
+    r.map(|b| b.map_err(|e| match e {}).boxed())
+}
+
+markup::define! {
+    IndexTemplate(path: String, banner: Option<String>, files: Vec<(String, Metadata)>) {
+        @markup::doctype()
+        html {
+            head {
+                meta[charset="UTF-8"];
+                title { "Index of " @path }
+            }
+            body {
+                @if let Some(banner) = banner {
+                    @markup::raw(banner)
+                } else {
+                    h1 { "Index of " @path }
+                }
+                hr;
+                table {
+                    @if path != "/" {
+                        tr { td { b { a[href=".."] { "../" } } } }
+                    }
+                    @for (name, meta) in files { tr {
+                        td { a[href=name] {
+                            @name
+                            @if meta.file_type().is_dir() { "/" }
+                        } }
+                        td {
+                            @if meta.file_type().is_dir() {
+                                i { "directory" }
+                            } else {
+                                @humansize::format_size(meta.len(), FormatSizeOptions::default())
+                            }
+                        }
+                    } }
+                }
+                hr;
+                footer { sub { "served by " a[href="https://codeberg.org/metamuffin/gnix"] { "gnix" } } }
+            }
+        }
+    }
+}
diff --git a/src/filters/mod.rs b/src/filters/mod.rs
new file mode 100644
index 0000000..fdeed51
--- /dev/null
+++ b/src/filters/mod.rs
@@ -0,0 +1,5 @@
+
+pub mod auth;
+pub mod files;
+pub mod proxy;
+pub mod accesslog;
\ No newline at end of file
diff --git a/src/filters/proxy.rs b/src/filters/proxy.rs
new file mode 100644
index 0000000..40ebf17
--- /dev/null
+++ b/src/filters/proxy.rs
@@ -0,0 +1,104 @@
+use crate::{helper::TokioIo, ServiceError, State};
+use http_body_util::{combinators::BoxBody, BodyExt};
+use hyper::{
+    body::Incoming,
+    header::UPGRADE,
+    http::{
+        uri::{PathAndQuery, Scheme},
+        HeaderValue,
+    },
+    upgrade::OnUpgrade,
+    Request, Uri,
+};
+use log::{debug, error, warn};
+use std::{net::SocketAddr, sync::Arc};
+use tokio::net::TcpStream;
+
+pub async fn proxy_request(
+    state: &Arc<State>,
+    mut req: Request<Incoming>,
+    addr: SocketAddr,
+    backend: &SocketAddr,
+) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> {
+    #[cfg(feature = "mond")]
+    state.reporting.request_out.inc();
+
+    let scheme_secure = req.uri().scheme() == Some(&Scheme::HTTPS);
+    *req.uri_mut() = Uri::builder()
+        .path_and_query(
+            req.uri()
+                .clone()
+                .path_and_query()
+                .cloned()
+                .unwrap_or(PathAndQuery::from_static("/")),
+        )
+        .build()
+        .unwrap();
+
+    req.headers_mut().insert(
+        "x-forwarded-for",
+        HeaderValue::from_str(&format!("{addr}")).unwrap(),
+    );
+    req.headers_mut().insert(
+        "x-forwarded-proto",
+        if scheme_secure {
+            HeaderValue::from_static("https")
+        } else {
+            HeaderValue::from_static("http")
+        },
+    );
+
+    let do_upgrade = req.headers().contains_key(UPGRADE);
+    let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>();
+
+    let _limit_guard = state.l_outgoing.try_acquire()?;
+    debug!("\tforwarding to {}", backend);
+    let mut resp = {
+        let client_stream = TokioIo(
+            TcpStream::connect(backend)
+                .await
+                .map_err(|_| ServiceError::CantConnect)?,
+        );
+
+        let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
+            .await
+            .map_err(ServiceError::Hyper)?;
+        tokio::task::spawn(async move {
+            if let Err(err) = conn.await {
+                warn!("connection failed: {:?}", err);
+            }
+        });
+        sender
+            .send_request(req)
+            .await
+            .map_err(ServiceError::Hyper)?
+    };
+
+    if do_upgrade {
+        let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
+        tokio::task::spawn(async move {
+            debug!("about to upgrade connection, sending empty response");
+            match (
+                on_upgrade_upstream.unwrap().await,
+                on_upgrade_downstream.unwrap().await,
+            ) {
+                (Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
+                    debug!("upgrade successful");
+                    match tokio::io::copy_bidirectional(
+                        &mut TokioIo(upgraded_downstream),
+                        &mut TokioIo(upgraded_upstream),
+                    )
+                    .await
+                    {
+                        Ok((from_client, from_server)) => {
+                            debug!("proxy socket terminated: {from_server} sent, {from_client} received")
+                        }
+                        Err(e) => warn!("proxy socket error: {e}"),
+                    }
+                }
+                (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"),
+            }
+        });
+    }
+    Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed()))
+}
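
Reviewer note: all five filters share the Result<ControlFlow<()>, ServiceError> convention, with Break meaning "a response has been produced, stop processing". The sketch below is only an illustration of how a per-host handler might chain them; the actual dispatch code is not part of this diff, and the HostConfig/Backend types, field names, and the assumption that FilterRequest is an alias for Request<Incoming> and FilterResponseOut for Option<Response<BoxBody<Bytes, ServiceError>>> are mine, not gnix's.

// Hypothetical handler sketch, not part of this commit.
async fn handle(
    state: &Arc<State>,
    host: &str,
    addr: SocketAddr,
    req: Request<Incoming>,
    cfg: &HostConfig, // hypothetical per-host config aggregating the filter configs
) -> Result<Response<BoxBody<Bytes, ServiceError>>, ServiceError> {
    let mut out: FilterResponseOut = None;

    if let Some(auth_cfg) = &cfg.http_basic_auth {
        if filters::auth::http_basic(auth_cfg, &req, &mut out)?.is_break() {
            // the filter stored a response (the 401 challenge) and asked us to stop
            return Ok(out.expect("auth filter sets a response on Break"));
        }
    }

    if let Some(log_cfg) = &cfg.access_log {
        // access_log only ever continues in the code above, so the result is not matched on
        filters::accesslog::access_log(state, host, addr, log_cfg, &req).await?;
    }

    // hypothetical Backend enum: either serve files from disk or proxy to an upstream
    match &cfg.backend {
        Backend::Files(files_cfg) => filters::files::serve_files(&req, files_cfg).await,
        Backend::Proxy(backend_addr) => {
            filters::proxy::proxy_request(state, req, addr, backend_addr).await
        }
    }
}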
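
Reviewer note on the TODO in access_log: or_insert_with takes a synchronous closure, which is why the diff reaches for futures::executor::block_on and unwrap. A possible rework, sketched below under the assumption that state.access_logs is an async RwLock over a HashMap<String, BufWriter<tokio::fs::File>> (as its use above suggests), opens the file with .await before inserting, so the open neither blocks a runtime worker nor panics on I/O errors.

// Sketch only: a replacement for the or_insert_with/block_on/unwrap block in access_log.
// `g`, `host` and `config` are the variables from the surrounding function;
// ServiceError::Io is the variant already used in files.rs.
if !g.contains_key(host) {
    let file = OpenOptions::new()
        .append(true)
        .create(true)
        .open(&config.file)
        .await
        .map_err(ServiceError::Io)?; // propagate instead of panicking
    g.insert(host.to_owned(), BufWriter::new(file));
}
let log = g.get_mut(host).expect("entry inserted above");

Whether an open failure should reject the request or only be logged (mirroring reject_on_fail) is a policy choice left open here.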