From 951c4e90b573f3d14a137bade0853fb3f0f21a5d Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Mon, 28 Aug 2023 12:42:43 +0200 Subject: supporting listening on a list of addresses Signed-off-by: metamuffin --- src/config.rs | 36 ++++++++++++++++++++++++++++++++---- src/main.rs | 59 +++++++++++++++++++++++++++++++++++++++-------------------- 2 files changed, 71 insertions(+), 24 deletions(-) (limited to 'src') diff --git a/src/config.rs b/src/config.rs index 9e542d7..89a2b54 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,6 @@ use anyhow::Context; -use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fs::read_to_string, net::SocketAddr, path::PathBuf}; +use serde::{Deserialize, Serialize, Deserializer, de::{Visitor, Error, SeqAccess, value}}; +use std::{collections::HashMap, fmt, fs::read_to_string, net::SocketAddr, path::PathBuf}; #[derive(Debug, Serialize, Deserialize)] pub struct Config { @@ -21,12 +21,14 @@ pub struct Limits { #[derive(Debug, Serialize, Deserialize)] pub struct HttpConfig { - pub bind: SocketAddr, + #[serde(deserialize_with = "string_or_seq")] + pub bind: Vec, } #[derive(Debug, Serialize, Deserialize)] pub struct HttpsConfig { - pub bind: SocketAddr, + #[serde(deserialize_with = "string_or_seq")] + pub bind: Vec, pub tls_cert: PathBuf, pub tls_key: PathBuf, } @@ -45,6 +47,32 @@ pub struct FileserverConfig { pub index: bool, } +// fall back to expecting a single string and putting that in a 1-length vector +fn string_or_seq<'de, D>(des: D) -> Result, D::Error> +where D: Deserializer<'de> { + struct StringOrList; + impl<'de> Visitor<'de> for StringOrList { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("sequence or list") + } + + fn visit_str(self, val: &str) -> Result, E> + where E: Error { + let addr = SocketAddr::deserialize(value::StrDeserializer::new(val))?; + Ok(vec![addr]) + } + + fn visit_seq(self, val: A) -> Result, A::Error> + where A: SeqAccess<'de> { + Vec::::deserialize(value::SeqAccessDeserializer::new(val)) + } + } + + des.deserialize_any(StringOrList) +} + impl Config { pub fn load(path: &str) -> anyhow::Result { let raw = read_to_string(path).context("reading config file")?; diff --git a/src/main.rs b/src/main.rs index f51cd67..aa13609 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,6 +14,7 @@ use crate::{ }; use anyhow::{anyhow, bail, Context, Result}; use error::ServiceError; +use futures::future::try_join_all; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ body::Incoming, @@ -88,14 +89,24 @@ async fn serve_http(state: Arc) -> Result<()> { Some(n) => n, None => return Ok(()), }; - let listener = TcpListener::bind(http_config.bind).await?; + + let listen_futs: Result> = try_join_all(http_config.bind + .iter() + .map(|e| async { + let l = TcpListener::bind(e.clone()).await?; + loop { + let (stream, addr) = l.accept().await.context("accepting connection")?; + debug!("connection from {addr}"); + let config = state.clone(); + tokio::spawn(async move { serve_stream(config, stream, addr).await }); + } + })) + .await; + info!("serving http"); - loop { - let (stream, addr) = listener.accept().await.context("accepting connection")?; - debug!("connection from {addr}"); - let config = state.clone(); - tokio::spawn(async move { serve_stream(config, stream, addr).await }); - } + + listen_futs?; + Ok(()) } async fn serve_https(state: Arc) -> Result<()> { @@ -116,22 +127,30 @@ async fn serve_https(state: Arc) -> Result<()> { ]; Arc::new(cfg) }; - let listener = TcpListener::bind(https_config.bind).await?; let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config)); + let listen_futs: Result> = try_join_all(https_config.bind + .iter() + .map(|e| async { + let l = TcpListener::bind(e.clone()).await?; + loop { + let (stream, addr) = l.accept().await.context("accepting connection")?; + let state = state.clone(); + let tls_acceptor = tls_acceptor.clone(); + tokio::task::spawn(async move { + debug!("connection from {addr}"); + match tls_acceptor.accept(stream).await { + Ok(stream) => serve_stream(state, stream, addr).await, + Err(e) => warn!("error accepting tls: {e}"), + }; + }); + } + })) + .await; + info!("serving https"); - loop { - let (stream, addr) = listener.accept().await.context("accepting connection")?; - let state = state.clone(); - let tls_acceptor = tls_acceptor.clone(); - tokio::task::spawn(async move { - debug!("connection from {addr}"); - match tls_acceptor.accept(stream).await { - Ok(stream) => serve_stream(state, stream, addr).await, - Err(e) => warn!("error accepting tls: {e}"), - }; - }); - } + listen_futs?; + Ok(()) } pub async fn serve_stream( -- cgit v1.2.3-70-g09d2