/* This file is part of gnix (https://codeberg.org/metamuffin/gnix) which is licensed under the GNU Affero General Public License (version 3); see /COPYING. Copyright (C) 2025 metamuffin */ use anyhow::{anyhow, Context, Result}; use log::debug; use rustls::{ crypto::CryptoProvider, server::{ClientHello, ResolvesServerCert}, sign::CertifiedKey, }; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, fs::read_to_string, io::Cursor, path::{Path, PathBuf}, sync::Arc, }; use webpki::EndEntityCert; #[derive(Debug)] pub struct CertPool { provider: &'static Arc, domains: HashMap>, wildcards: HashMap>, fallback: Option>, } #[derive(Serialize, Deserialize)] pub struct CertPackage { pub name: String, pub cert: String, pub key: String, } impl Default for CertPool { fn default() -> Self { Self { provider: CryptoProvider::get_default().unwrap(), domains: Default::default(), wildcards: Default::default(), fallback: None, } } } impl CertPackage { fn create_from(path: &Path) -> Result> { let keypath = path.join("privkey.pem"); let certpath = if path.join("fullchain.pem").exists() { path.join("fullchain.pem") } else { path.join("cert.pem") }; Ok(if keypath.exists() && certpath.exists() { debug!("creating cert package at {path:?}"); Some(CertPackage { name: path.to_string_lossy().to_string(), cert: read_to_string(certpath)?, key: read_to_string(keypath)?, }) } else { None }) } pub fn create_from_recursive(roots: &[PathBuf]) -> Result> { let mut out = Vec::new(); for r in roots { Self::create_package_recursion(r, &mut out)?; } Ok(out) } fn create_package_recursion(path: &Path, out: &mut Vec) -> Result<()> { if !path.is_dir() { return Ok(()); } for e in path.read_dir()? { let p = e?.path(); if p.is_dir() { Self::create_package_recursion(&p, out)?; } } out.extend(Self::create_from(path)?); Ok(()) } } impl CertPool { pub fn add_package(&mut self, cp: CertPackage) -> Result<()> { let certs = rustls_pemfile::certs(&mut Cursor::new(cp.cert)) .try_collect::>() .context("parsing tls certs")?; let key = rustls_pemfile::private_key(&mut Cursor::new(cp.key)) .context("parsing tls private key")? .ok_or(anyhow!("private key missing"))?; let skey = self.provider.key_provider.load_private_key(key)?; for c in &certs { let eec = EndEntityCert::try_from(c).unwrap(); for name in eec.valid_dns_names() { let ck = CertifiedKey::new(certs.clone(), skey.clone()); if let Some(name) = name.strip_prefix("*.") { debug!("loaded wildcard key for {name:?}"); self.wildcards.insert(name.to_owned(), Arc::new(ck)); } else { debug!("loaded key for {name:?}"); self.domains.insert(name.to_owned(), Arc::new(ck)); } } } Ok(()) } } impl ResolvesServerCert for CertPool { fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { let sname = client_hello.server_name()?; Some( self.domains .get(sname) .or_else(|| { // Removing first label seems fine since wildcards are not recursive. sname .split_once(".") .and_then(|(_, sname)| self.wildcards.get(sname)) }) .or(self.fallback.as_ref())? .clone(), ) } }