aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/config.rs36
-rw-r--r--src/main.rs59
2 files changed, 71 insertions, 24 deletions
diff --git a/src/config.rs b/src/config.rs
index 9e542d7..89a2b54 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,6 +1,6 @@
use anyhow::Context;
-use serde::{Deserialize, Serialize};
-use std::{collections::HashMap, fs::read_to_string, net::SocketAddr, path::PathBuf};
+use serde::{Deserialize, Serialize, Deserializer, de::{Visitor, Error, SeqAccess, value}};
+use std::{collections::HashMap, fmt, fs::read_to_string, net::SocketAddr, path::PathBuf};
#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
@@ -21,12 +21,14 @@ pub struct Limits {
#[derive(Debug, Serialize, Deserialize)]
pub struct HttpConfig {
- pub bind: SocketAddr,
+ #[serde(deserialize_with = "string_or_seq")]
+ pub bind: Vec<SocketAddr>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct HttpsConfig {
- pub bind: SocketAddr,
+ #[serde(deserialize_with = "string_or_seq")]
+ pub bind: Vec<SocketAddr>,
pub tls_cert: PathBuf,
pub tls_key: PathBuf,
}
@@ -45,6 +47,32 @@ pub struct FileserverConfig {
pub index: bool,
}
+// fall back to expecting a single string and putting that in a 1-length vector
+fn string_or_seq<'de, D>(des: D) -> Result<Vec<SocketAddr>, D::Error>
+where D: Deserializer<'de> {
+ struct StringOrList;
+ impl<'de> Visitor<'de> for StringOrList {
+ type Value = Vec<SocketAddr>;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("sequence or list")
+ }
+
+ fn visit_str<E>(self, val: &str) -> Result<Vec<SocketAddr>, E>
+ where E: Error {
+ let addr = SocketAddr::deserialize(value::StrDeserializer::new(val))?;
+ Ok(vec![addr])
+ }
+
+ fn visit_seq<A>(self, val: A) -> Result<Vec<SocketAddr>, A::Error>
+ where A: SeqAccess<'de> {
+ Vec::<SocketAddr>::deserialize(value::SeqAccessDeserializer::new(val))
+ }
+ }
+
+ des.deserialize_any(StringOrList)
+}
+
impl Config {
pub fn load(path: &str) -> anyhow::Result<Config> {
let raw = read_to_string(path).context("reading config file")?;
diff --git a/src/main.rs b/src/main.rs
index f51cd67..aa13609 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -14,6 +14,7 @@ use crate::{
};
use anyhow::{anyhow, bail, Context, Result};
use error::ServiceError;
+use futures::future::try_join_all;
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{
body::Incoming,
@@ -88,14 +89,24 @@ async fn serve_http(state: Arc<State>) -> Result<()> {
Some(n) => n,
None => return Ok(()),
};
- let listener = TcpListener::bind(http_config.bind).await?;
+
+ let listen_futs: Result<Vec<()>> = try_join_all(http_config.bind
+ .iter()
+ .map(|e| async {
+ let l = TcpListener::bind(e.clone()).await?;
+ loop {
+ let (stream, addr) = l.accept().await.context("accepting connection")?;
+ debug!("connection from {addr}");
+ let config = state.clone();
+ tokio::spawn(async move { serve_stream(config, stream, addr).await });
+ }
+ }))
+ .await;
+
info!("serving http");
- loop {
- let (stream, addr) = listener.accept().await.context("accepting connection")?;
- debug!("connection from {addr}");
- let config = state.clone();
- tokio::spawn(async move { serve_stream(config, stream, addr).await });
- }
+
+ listen_futs?;
+ Ok(())
}
async fn serve_https(state: Arc<State>) -> Result<()> {
@@ -116,22 +127,30 @@ async fn serve_https(state: Arc<State>) -> Result<()> {
];
Arc::new(cfg)
};
- let listener = TcpListener::bind(https_config.bind).await?;
let tls_acceptor = Arc::new(TlsAcceptor::from(tls_config));
+ let listen_futs: Result<Vec<()>> = try_join_all(https_config.bind
+ .iter()
+ .map(|e| async {
+ let l = TcpListener::bind(e.clone()).await?;
+ loop {
+ let (stream, addr) = l.accept().await.context("accepting connection")?;
+ let state = state.clone();
+ let tls_acceptor = tls_acceptor.clone();
+ tokio::task::spawn(async move {
+ debug!("connection from {addr}");
+ match tls_acceptor.accept(stream).await {
+ Ok(stream) => serve_stream(state, stream, addr).await,
+ Err(e) => warn!("error accepting tls: {e}"),
+ };
+ });
+ }
+ }))
+ .await;
+
info!("serving https");
- loop {
- let (stream, addr) = listener.accept().await.context("accepting connection")?;
- let state = state.clone();
- let tls_acceptor = tls_acceptor.clone();
- tokio::task::spawn(async move {
- debug!("connection from {addr}");
- match tls_acceptor.accept(stream).await {
- Ok(stream) => serve_stream(state, stream, addr).await,
- Err(e) => warn!("error accepting tls: {e}"),
- };
- });
- }
+ listen_futs?;
+ Ok(())
}
pub async fn serve_stream<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(