use anyhow::{anyhow, Context, Result}; use log::debug; use rustls::{ crypto::CryptoProvider, pki_types::{CertificateDer, PrivateKeyDer}, server::{ClientHello, ResolvesServerCert}, sign::CertifiedKey, }; use std::{ collections::HashMap, fs::File, io::BufReader, path::{Path, PathBuf}, sync::Arc, }; use webpki::EndEntityCert; #[derive(Debug)] pub struct CertPool { provider: &'static Arc, domains: HashMap>, } impl Default for CertPool { fn default() -> Self { Self { provider: CryptoProvider::get_default().unwrap(), domains: Default::default(), } } } impl CertPool { pub fn load(roots: &[PathBuf]) -> Result { let mut s = Self::default(); for r in roots { s.load_recursive(&r)?; } Ok(s) } fn load_recursive(&mut self, path: &Path) -> Result<()> { if !path.is_dir() { return Ok(()); } let keypath = path.join("privkey.pem"); let certpath = if path.join("fullchain.pem").exists() { path.join("fullchain.pem") } else { path.join("cert.pem") }; if keypath.exists() && certpath.exists() { let certs = load_certs(&certpath)?; let key = load_private_key(&keypath)?; let skey = self.provider.key_provider.load_private_key(key)?; for c in &certs { let eec = EndEntityCert::try_from(c).unwrap(); for name in eec.valid_dns_names() { debug!("loaded key for {name:?}"); let ck = CertifiedKey::new(certs.clone(), skey.clone()); self.domains.insert(name.to_owned(), Arc::new(ck)); } } } Ok(()) } } impl ResolvesServerCert for CertPool { fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { Some(self.domains.get(client_hello.server_name()?)?.clone()) } } fn load_certs(path: &Path) -> Result>> { let mut reader = BufReader::new(File::open(path).context("reading tls certs")?); let certs = rustls_pemfile::certs(&mut reader) .try_collect::>() .context("parsing tls certs")?; Ok(certs) } fn load_private_key(path: &Path) -> Result> { let mut reader = BufReader::new(File::open(path).context("reading tls private key")?); let keys = rustls_pemfile::private_key(&mut reader).context("parsing tls private key")?; keys.ok_or(anyhow!("no private key found")) }