diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.rs | 78 | ||||
-rw-r--r-- | src/ui.rs | 0 | ||||
-rw-r--r-- | src/worker_ws.rs | 169 |
3 files changed, 247 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..a08ae6b --- /dev/null +++ b/src/main.rs @@ -0,0 +1,78 @@ +pub mod ui; +pub mod worker_ws; + +use anyhow::Result; +use axum::{Router, routing::get}; +use serde_json::{Map, Value}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; +use tokio::{ + fs::{File, read_to_string, rename}, + io::AsyncWriteExt, + net::TcpListener, + sync::{RwLock, mpsc::Sender}, +}; +use worker_ws::{WorkerID, WorkerResponse, worker_websocket}; + +struct Worker { + accept: usize, + name: String, + sources: Vec<String>, + send: Sender<WorkerResponse>, +} + +#[derive(Default)] +struct State { + worker_id_counter: WorkerID, + workers: HashMap<WorkerID, Worker>, + + metadata: HashMap<String, Map<String, Value>>, + queue: HashSet<String>, + loading: HashSet<String>, + complete: HashSet<String>, +} + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init_from_env("LOG"); + let mut state = State::default(); + state.load().await?; + let router = Router::new() + .route("/", get(async || "Hello world!")) + .route("/worker_ws", get(worker_websocket)) + .with_state(Arc::new(RwLock::new(state))); + let listener = TcpListener::bind("127.0.0.1:8080").await?; + axum::serve(listener, router).await?; + Ok(()) +} + +impl State { + pub async fn load(&mut self) -> Result<()> { + self.metadata = serde_json::from_str(&read_to_string("metadata.json").await?)?; + self.queue = serde_json::from_str(&read_to_string("queue.json").await?)?; + self.complete = serde_json::from_str(&read_to_string("complete.json").await?)?; + Ok(()) + } + pub async fn save(&mut self) -> Result<()> { + File::create("metadata.json~") + .await? + .write_all(&serde_json::to_vec(&self.metadata)?) + .await?; + File::create("queue.json~") + .await? + .write_all(&serde_json::to_vec(&self.queue)?) + .await?; + File::create("complete.json~") + .await? + .write_all(&serde_json::to_vec(&self.complete)?) + .await?; + + rename("metadata.json~", "metadata.json").await?; + rename("queue.json~", "queue.json").await?; + rename("complete.json~", "complete.json").await?; + + Ok(()) + } +} diff --git a/src/ui.rs b/src/ui.rs new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/ui.rs diff --git a/src/worker_ws.rs b/src/worker_ws.rs new file mode 100644 index 0000000..f645c11 --- /dev/null +++ b/src/worker_ws.rs @@ -0,0 +1,169 @@ +use crate::{State, Worker}; +use anyhow::{Result, bail}; +use axum::{ + extract::{ + State as S, WebSocketUpgrade, + ws::{Message, WebSocket}, + }, + response::IntoResponse, +}; +use futures::{SinkExt, StreamExt}; +use log::warn; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use std::sync::Arc; +use tokio::{ + spawn, + sync::{RwLock, mpsc::channel}, +}; + +pub type WorkerID = u64; + +#[derive(Debug, Deserialize)] +#[serde(tag = "t", rename_all = "snake_case")] +pub enum WorkerRequest { + Register { + name: String, + sources: Vec<String>, + }, + Metadata { + key: String, + data: Map<String, Value>, + }, + Enqueue { + key: String, + }, + Complete { + key: String, + }, + Accept, + + _EchoError { + message: String, + }, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "t", rename_all = "snake_case")] +pub enum WorkerResponse { + Ok, + Work { + key: String, + data: Map<String, Value>, + }, + Error { + message: String, + }, +} + +pub(crate) async fn worker_websocket( + ws: WebSocketUpgrade, + S(state): S<Arc<RwLock<State>>>, +) -> impl IntoResponse { + ws.on_upgrade(|ws| worker_websocket_inner(ws, state)) +} +async fn worker_websocket_inner(ws: WebSocket, state: Arc<RwLock<State>>) { + let (mut send, mut recv) = ws.split(); + let (tx, mut rx) = channel(16); + + let worker = { + let mut g = state.write().await; + let id = g.worker_id_counter; + g.worker_id_counter += 1; + g.workers.insert( + id, + Worker { + send: tx, + accept: 0, + name: "unknown".to_string(), + sources: vec![], + }, + ); + id + }; + + let mut send_task = spawn(async move { + while let Some(m) = rx.recv().await { + if let Err(e) = send + .send(Message::Text(serde_json::to_string(&m).unwrap().into())) + .await + { + warn!("error sending response: {e:?}"); + break; + } + } + }); + let mut recv_task = spawn(async move { + while let Some(message) = recv.next().await { + if let Ok(m) = message { + match m { + Message::Text(t) => { + let req = match serde_json::from_str::<WorkerRequest>(t.as_str()) { + Ok(v) => v, + Err(e) => { + warn!("error parsing request: {e:?}"); + WorkerRequest::_EchoError { + message: format!("{e:#}"), + } + } + }; + if let Err(e) = state.write().await.handle_worker_message(worker, req) { + warn!("error processing request: {e:?}"); + state.write().await.send_to_worker( + worker, + WorkerResponse::Error { + message: format!("{e:?}"), + }, + ); + } + } + _ => (), + } + } + } + }); + + tokio::select! { + _ = &mut send_task => recv_task.abort(), + _ = &mut recv_task => send_task.abort(), + }; +} + +impl State { + pub fn send_to_worker(&mut self, w: WorkerID, resp: WorkerResponse) { + if let Err(_) = self.workers[&w].send.try_send(resp) { + warn!("worker ws response overflow"); + } + } + pub fn handle_worker_message(&mut self, w: WorkerID, req: WorkerRequest) -> Result<()> { + let worker = self.workers.get_mut(&w).unwrap(); + match req { + WorkerRequest::Register { name, sources } => { + worker.name = name; + worker.sources = sources; + } + WorkerRequest::Metadata { key, data } => { + let m = self.metadata.entry(key).or_default(); + m.extend(data); + } + WorkerRequest::Enqueue { key } => { + self.queue.insert(key); + } + WorkerRequest::Complete { key } => { + if self.loading.remove(&key) { + self.complete.insert(key); + } else { + bail!("was not loading") + } + } + WorkerRequest::Accept => { + worker.accept += 1; + } + WorkerRequest::_EchoError { message } => { + self.send_to_worker(w, WorkerResponse::Error { message }); + } + } + + Ok(()) + } +} |