aboutsummaryrefslogtreecommitdiff
path: root/src/state.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/state.rs')
-rw-r--r--src/state.rs105
1 files changed, 105 insertions, 0 deletions
diff --git a/src/state.rs b/src/state.rs
new file mode 100644
index 0000000..4de21b4
--- /dev/null
+++ b/src/state.rs
@@ -0,0 +1,105 @@
+use anyhow::Context;
+use redb::{Database, ReadableTable, TableDefinition};
+use rocket::tokio::sync::Mutex;
+use serde::Deserialize;
+use std::{collections::HashMap, net::IpAddr, path::PathBuf};
+
+#[derive(Deserialize)]
+pub struct AdInfo {
+ pub image: PathBuf,
+ pub target: String,
+}
+
+#[derive(Deserialize)]
+pub struct Config {
+ bloom_filter_size: usize,
+ impression_weight_falloff: f64,
+ pub image_base: PathBuf,
+ database_path: PathBuf,
+ pub ads: HashMap<String, AdInfo>,
+}
+
+pub struct Logic {
+ pub config: Config,
+ database: Database,
+ impressions_by_addr: Mutex<Vec<u16>>,
+ pub ad_keys: Vec<String>,
+}
+
+static T_TOTAL: TableDefinition<'static, (), u128> = TableDefinition::new("t");
+static T_IMPRESSIONS_RAW: TableDefinition<'static, &str, u128> = TableDefinition::new("ir");
+static T_IMPRESSIONS_WEIGHTED: TableDefinition<'static, &str, f64> = TableDefinition::new("iw");
+static T_ADS: TableDefinition<'static, &str, u128> = TableDefinition::new("a");
+
+impl Logic {
+ pub fn new(config: Config) -> Self {
+ Self {
+ impressions_by_addr: vec![0; config.bloom_filter_size].into(),
+ database: {
+ let db = Database::create(&config.database_path).expect("database open failed");
+ {
+ let txn = db.begin_write().unwrap();
+ txn.open_table(T_IMPRESSIONS_RAW).unwrap();
+ txn.open_table(T_IMPRESSIONS_WEIGHTED).unwrap();
+ txn.open_table(T_ADS).unwrap();
+ }
+ db
+ },
+ ad_keys: config.ads.keys().map(String::from).collect(),
+ config,
+ }
+ }
+
+ pub async fn register_impression(
+ &self,
+ site: &str,
+ adid: &str,
+ address: IpAddr,
+ ) -> anyhow::Result<()> {
+ let num_impressions = {
+ let mut bloom = self.impressions_by_addr.lock().await;
+ let ind = (xorshift(xorshift(xorshift(
+ match address {
+ IpAddr::V4(a) => a.to_ipv6_mapped(),
+ IpAddr::V6(a) => a,
+ }
+ .to_bits() as u64,
+ ))) % bloom.len() as u64) as usize;
+ bloom[ind] = bloom[ind].saturating_add(1);
+ bloom[ind]
+ } as f64;
+
+ let weight = self.config.impression_weight_falloff.powf(num_impressions);
+
+ let txn = self.database.begin_write().context("database failure")?;
+ {
+ let mut raw = txn.open_table(T_TOTAL)?;
+ let v = raw.get(())?.map(|g| g.value()).unwrap_or_default();
+ raw.insert((), v + 1)?;
+ }
+ {
+ let mut raw = txn.open_table(T_IMPRESSIONS_RAW)?;
+ let v = raw.get(site)?.map(|g| g.value()).unwrap_or_default();
+ raw.insert(site, v + 1)?;
+ }
+ {
+ let mut raw = txn.open_table(T_IMPRESSIONS_WEIGHTED)?;
+ let v = raw.get(site)?.map(|g| g.value()).unwrap_or_default();
+ raw.insert(site, v + weight)?;
+ }
+ {
+ let mut raw = txn.open_table(T_ADS)?;
+ let v = raw.get(site)?.map(|g| g.value()).unwrap_or_default();
+ raw.insert(adid, v + 1)?;
+ }
+ txn.commit().context("database failure")?;
+ Ok(())
+ }
+}
+
+fn xorshift(mut x: u64) -> u64 {
+ x ^= x << 13;
+ x ^= x >> 7;
+ x ^= x << 17;
+ x
+}