diff options
Diffstat (limited to 'src/state.rs')
-rw-r--r-- | src/state.rs | 178 |
1 files changed, 131 insertions, 47 deletions
diff --git a/src/state.rs b/src/state.rs
index 368e76b..ff5509b 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,8 +1,15 @@
-use anyhow::Context;
+use anyhow::{bail, Context};
+use log::error;
 use redb::{Database, ReadableTable, TableDefinition};
-use rocket::tokio::sync::Mutex;
+use rocket::tokio::{
+    sync::mpsc::{self, Receiver, Sender},
+    task,
+    time::{timeout_at, Instant},
+};
 use serde::Deserialize;
+use std::time::Duration;
 use std::{collections::HashMap, net::IpAddr, path::PathBuf};
+use std::{process::exit, sync::Arc};
 
 #[derive(Deserialize)]
 pub struct AdInfo {
@@ -23,31 +30,131 @@ pub struct Config {
 pub struct Logic {
     pub config: Config,
     database: Database,
-    impressions_by_addr: Mutex<Vec<u16>>,
     pub ad_keys: Vec<String>,
+    event_dispatch: Sender<ImpressionEvent>,
 }
 
-static T_TOTAL: TableDefinition<'static, (), u128> = TableDefinition::new("t");
-static T_IMPRESSIONS_RAW: TableDefinition<'static, &str, u128> = TableDefinition::new("ir");
+/// One ad impression, queued for the background database writer.
+struct ImpressionEvent {
+    site: String,
+    adid: String,
+    address_hash: u64,
+}
+
+static T_TOTAL: TableDefinition<'static, (), u64> = TableDefinition::new("t");
+static T_IMPRESSIONS_RAW: TableDefinition<'static, &str, u64> = TableDefinition::new("ir");
 static T_IMPRESSIONS_WEIGHTED: TableDefinition<'static, &str, f64> = TableDefinition::new("iw");
-static T_ADS: TableDefinition<'static, &str, u128> = TableDefinition::new("a");
+static T_IMPRESSIONS_ADS: TableDefinition<'static, &str, u64> = TableDefinition::new("ia");
 
 impl Logic {
-    pub fn new(config: Config) -> Self {
-        Self {
-            impressions_by_addr: vec![0; config.bloom_filter_size].into(),
+    pub fn new(config: Config) -> Arc<Self> {
+        let (tx, rx) = mpsc::channel(4096);
+        let state = Arc::new(Self {
             database: {
                 let db = Database::create(&config.database_path).expect("database open failed");
                 {
                     let txn = db.begin_write().unwrap();
                     txn.open_table(T_IMPRESSIONS_RAW).unwrap();
                     txn.open_table(T_IMPRESSIONS_WEIGHTED).unwrap();
-                    txn.open_table(T_ADS).unwrap();
+                    txn.open_table(T_IMPRESSIONS_ADS).unwrap();
+                    txn.commit().unwrap(); // an uncommitted transaction is rolled back on drop
                 }
                 db
             },
             ad_keys: config.ads.keys().map(String::from).collect(),
             config,
+            event_dispatch: tx,
+        });
+
+        {
+            let state = state.clone();
+            task::spawn(async move {
+                if let Err(e) = state.commit_db(rx).await {
+                    error!("{e:?}");
+                    exit(1)
+                }
+            });
+        }
+        state
+    }
+
+    /// Drains `rx` and flushes the aggregated impression counts to the database in batches.
+    /// Only returns with an error: when the channel closes or a database operation fails.
+    async fn commit_db(&self, mut rx: Receiver<ImpressionEvent>) -> anyhow::Result<()> {
+        let mut deadline = None;
+        let mut impressions_by_addr = vec![0u16; self.config.bloom_filter_size];
+        let mut imp_raw = HashMap::<String, u64>::new();
+        let mut imp_weighted = HashMap::<String, f64>::new();
+        let mut imp_ads = HashMap::<String, u64>::new();
+        let mut total = 0;
+        loop {
+            // Batch events until 10 s after the first; when idle, wait a day at a time.
+            while Instant::now() < deadline.unwrap_or(Instant::now() + Duration::from_secs(86400)) {
+                match timeout_at(
+                    deadline.unwrap_or(Instant::now() + Duration::from_secs(86400)),
+                    rx.recv(),
+                )
+                .await
+                {
+                    Ok(Some(ImpressionEvent {
+                        site,
+                        adid,
+                        address_hash,
+                    })) => {
+                        let num_impressions = {
+                            let ind = (address_hash % impressions_by_addr.len() as u64) as usize;
+                            impressions_by_addr[ind] = impressions_by_addr[ind].saturating_add(1);
+                            impressions_by_addr[ind]
+                        } as f64;
+                        let weight = self.config.impression_weight_falloff.powf(num_impressions);
+
+                        *imp_ads.entry(adid).or_default() += 1;
+                        *imp_raw.entry(site.clone()).or_default() += 1;
+                        *imp_weighted.entry(site).or_default() += weight;
+                        total += 1;
+                        if deadline.is_none() {
+                            deadline = Some(Instant::now() + Duration::from_secs(10));
+                        }
+                    }
+                    Ok(None) => bail!("receiver end?!"),
+                    Err(_) => {}
+                }
+            }
+
+            let txn = self.database.begin_write().context("database failure")?;
+            // Add this batch's event count to the running total.
+            {
+                let mut raw = txn.open_table(T_TOTAL)?;
+                let v = raw.get(())?.map(|g| g.value()).unwrap_or_default();
+                raw.insert((), v + total)?;
+                total = 0;
+            }
+            for (site, amount) in imp_raw.drain() {
+                let mut raw = txn.open_table(T_IMPRESSIONS_RAW)?;
+                let v = raw
+                    .get(site.as_str())?
+                    .map(|g| g.value())
+                    .unwrap_or_default();
+                raw.insert(site.as_str(), v + amount)?;
+            }
+            for (site, amount) in imp_weighted.drain() {
+                let mut raw = txn.open_table(T_IMPRESSIONS_WEIGHTED)?;
+                let v = raw
+                    .get(site.as_str())?
+                    .map(|g| g.value())
+                    .unwrap_or_default();
+                raw.insert(site.as_str(), v + amount)?;
+            }
+            for (adid, amount) in imp_ads.drain() {
+                let mut raw = txn.open_table(T_IMPRESSIONS_ADS)?;
+                let v = raw
+                    .get(adid.as_str())?
+                    .map(|g| g.value())
+                    .unwrap_or_default();
+                raw.insert(adid.as_str(), v + amount)?;
+            }
+            txn.commit().context("database failure")?;
+            deadline = None;
         }
     }
 
@@ -57,43 +164,20 @@ impl Logic {
         adid: &str,
         address: IpAddr,
     ) -> anyhow::Result<()> {
-        let num_impressions = {
-            let mut bloom = self.impressions_by_addr.lock().await;
-            let ind = (xorshift(xorshift(xorshift(
-                match address {
-                    IpAddr::V4(a) => a.to_ipv6_mapped(),
-                    IpAddr::V6(a) => a,
-                }
-                .to_bits() as u64,
-            ))) % bloom.len() as u64) as usize;
-            bloom[ind] = bloom[ind].saturating_add(1);
-            bloom[ind]
-        } as f64;
-
-        let weight = self.config.impression_weight_falloff.powf(num_impressions);
-
-        let txn = self.database.begin_write().context("database failure")?;
-        {
-            let mut raw = txn.open_table(T_TOTAL)?;
-            let v = raw.get(())?.map(|g| g.value()).unwrap_or_default();
-            raw.insert((), v + 1)?;
-        }
-        {
-            let mut raw = txn.open_table(T_IMPRESSIONS_RAW)?;
-            let v = raw.get(site)?.map(|g| g.value()).unwrap_or_default();
-            raw.insert(site, v + 1)?;
-        }
-        {
-            let mut raw = txn.open_table(T_IMPRESSIONS_WEIGHTED)?;
-            let v = raw.get(site)?.map(|g| g.value()).unwrap_or_default();
-            raw.insert(site, v + weight)?;
-        }
-        {
-            let mut raw = txn.open_table(T_ADS)?;
-            let v = raw.get(site)?.map(|g| g.value()).unwrap_or_default();
-            raw.insert(adid, v + 1)?;
-        }
-        txn.commit().context("database failure")?;
+        let address_hash = xorshift(xorshift(xorshift(
+            match address {
+                IpAddr::V4(a) => a.to_ipv6_mapped(),
+                IpAddr::V6(a) => a,
+            }
+            .to_bits() as u64,
+        )));
+        self.event_dispatch
+            .send(ImpressionEvent {
+                address_hash,
+                adid: adid.to_owned(),
+                site: site.to_owned(),
+            })
+            .await?;
         Ok(())
     }
 }