diff options
author | metamuffin <metamuffin@disroot.org> | 2023-04-06 14:43:59 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-04-06 14:43:59 +0200 |
commit | 6d06d9a0eaf8d7c3f03df1501b9acd0a71cb53ae (patch) | |
tree | 55f049ae28e87ea46679e53f5385be529ece32cf | |
parent | 7bcbf526f357aca014e6bff8e833572ba38f721f (diff) | |
download | gnix-6d06d9a0eaf8d7c3f03df1501b9acd0a71cb53ae.tar gnix-6d06d9a0eaf8d7c3f03df1501b9acd0a71cb53ae.tar.bz2 gnix-6d06d9a0eaf8d7c3f03df1501b9acd0a71cb53ae.tar.zst |
state struct
-rw-r--r-- | src/main.rs | 60 |
1 files changed, 36 insertions, 24 deletions
diff --git a/src/main.rs b/src/main.rs index e39cf58..ef62af7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,6 +31,10 @@ use tokio::{ }; use tokio_rustls::TlsAcceptor; +pub struct State { + pub config: Config, +} + #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init_from_env("LOG"); @@ -38,26 +42,33 @@ async fn main() -> anyhow::Result<()> { let config_path = std::env::args().skip(1).next().ok_or(anyhow!( "first argument is expected to be the configuration file" ))?; - let config = Arc::new(Config::load(&config_path)?); - let config2 = config.clone(); + + let config = Config::load(&config_path)?; + let state = Arc::new(State { config }); - tokio::spawn(async move { - if let Err(e) = serve_http(config).await { - error!("{e}") - } - }); - tokio::spawn(async move { - if let Err(e) = serve_https(config2).await { - error!("{e}") - } - }); + { + let state = state.clone(); + tokio::spawn(async move { + if let Err(e) = serve_http(state).await { + error!("{e}") + } + }); + } + { + let state = state.clone(); + tokio::spawn(async move { + if let Err(e) = serve_https(state).await { + error!("{e}") + } + }); + } ctrl_c().await.unwrap(); Ok(()) } -async fn serve_http(config: Arc<Config>) -> Result<()> { - let http_config = match &config.http { +async fn serve_http(state: Arc<State>) -> Result<()> { + let http_config = match &state.config.http { Some(n) => n, None => return Ok(()), }; @@ -66,13 +77,13 @@ async fn serve_http(config: Arc<Config>) -> Result<()> { loop { let (stream, addr) = listener.accept().await.context("accepting connection")?; debug!("connection from {addr}"); - let config = config.clone(); + let config = state.clone(); tokio::spawn(async move { serve_stream(config, stream, addr).await }); } } -async fn serve_https(config: Arc<Config>) -> Result<()> { - let https_config = match &config.https { +async fn serve_https(state: Arc<State>) -> Result<()> { + let https_config = match &state.config.https { Some(n) => n, None => return Ok(()), }; @@ -95,12 +106,12 @@ async fn serve_https(config: Arc<Config>) -> Result<()> { info!("serving https"); loop { let (stream, addr) = listener.accept().await.context("accepting connection")?; - let config = config.clone(); + let state = state.clone(); let tls_acceptor = tls_acceptor.clone(); tokio::task::spawn(async move { debug!("connection from {addr}"); match tls_acceptor.accept(stream).await { - Ok(stream) => serve_stream(config, stream, addr).await, + Ok(stream) => serve_stream(state, stream, addr).await, Err(e) => warn!("error accepting tls: {e}"), }; }); @@ -108,7 +119,7 @@ async fn serve_https(config: Arc<Config>) -> Result<()> { } pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>( - config: Arc<Config>, + state: Arc<State>, stream: T, addr: SocketAddr, ) { @@ -116,9 +127,9 @@ pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>( .serve_connection( stream, service_fn(move |req| { - let config = config.clone(); + let state = state.clone(); async move { - match service(config, req, addr).await { + match service(state, req, addr).await { Ok(r) => Ok(r), Err(ServiceError::Hyper(e)) => Err(e), Err(error) => Ok({ @@ -156,13 +167,14 @@ fn load_private_key(path: &Path) -> anyhow::Result<rustls::PrivateKey> { } async fn service( - config: Arc<Config>, + state: Arc<State>, req: Request<Incoming>, addr: SocketAddr, ) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> { debug!("{addr} ~> {:?} {}", req.headers().get(HOST), req.uri()); - let route = config + let route = state + .config .hosts .get(remove_port( &req.headers() |