diff options
Diffstat (limited to 'src/worker_ws.rs')
-rw-r--r-- | src/worker_ws.rs | 58 |
1 files changed, 49 insertions, 9 deletions
diff --git a/src/worker_ws.rs b/src/worker_ws.rs index 2d82e4d..7b76476 100644 --- a/src/worker_ws.rs +++ b/src/worker_ws.rs @@ -8,10 +8,10 @@ use axum::{ response::IntoResponse, }; use futures::{SinkExt, StreamExt}; -use log::warn; +use log::{debug, warn}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; use tokio::{ spawn, sync::{RwLock, mpsc::channel}, @@ -24,7 +24,7 @@ pub type WorkerID = u64; pub enum WorkerRequest { Register { name: String, - sources: Vec<String>, + task_kinds: Vec<String>, }, Metadata { key: String, @@ -77,7 +77,8 @@ async fn worker_websocket_inner(ws: WebSocket, state: Arc<RwLock<State>>) { send: tx, accept: 0, name: "unknown".to_string(), - sources: vec![], + task_kinds: vec![], + assigned_tasks: HashSet::new(), }, ); id @@ -133,22 +134,29 @@ async fn worker_websocket_inner(ws: WebSocket, state: Arc<RwLock<State>>) { { let mut g = state2.write().await; - g.workers.remove(&worker).unwrap(); + let w = g.workers.remove(&worker).unwrap(); + // recycle incomplete tasks + for key in w.assigned_tasks { + g.loading.remove(&key); + g.queue.insert(key); + } } } impl State { pub fn send_to_worker(&mut self, w: WorkerID, resp: WorkerResponse) { + debug!("{w} -> {resp:?}"); if let Err(_) = self.workers[&w].send.try_send(resp) { warn!("worker ws response overflow"); } } pub async fn handle_worker_message(&mut self, w: WorkerID, req: WorkerRequest) -> Result<()> { + debug!("{w} <- {req:?}"); let worker = self.workers.get_mut(&w).unwrap(); match req { - WorkerRequest::Register { name, sources } => { + WorkerRequest::Register { name, task_kinds } => { worker.name = name; - worker.sources = sources; + worker.task_kinds = task_kinds; } WorkerRequest::Metadata { key, data } => { let m = self.metadata.entry(key).or_default(); @@ -156,16 +164,19 @@ impl State { } WorkerRequest::Enqueue { key } => { self.queue.insert(key); + self.dispatch_work(); } WorkerRequest::Complete { key } => { - if self.loading.remove(&key) { + if worker.assigned_tasks.remove(&key) { + self.loading.remove(&key); self.complete.insert(key); } else { - bail!("was not loading") + bail!("task was not assigned") } } WorkerRequest::Accept => { worker.accept += 1; + self.dispatch_work(); } WorkerRequest::Save => { self.save().await?; @@ -177,4 +188,33 @@ impl State { Ok(()) } + + pub fn dispatch_work(&mut self) { + let mut to_send = Vec::new(); + for (id, w) in &mut self.workers { + if w.accept >= 1 { + for kind in &w.task_kinds { + let prefix = format!("{kind}:"); + let Some(first) = self.queue.iter().find(|e| e.starts_with(&prefix)).cloned() + else { + continue; + }; + w.accept -= 1; + w.assigned_tasks.insert(first.clone()); + to_send.push((*id, first)); + } + } + } + for (w, key) in to_send { + self.queue.remove(&key); + self.loading.insert(key.clone()); + self.send_to_worker( + w, + WorkerResponse::Work { + data: self.metadata[&key].clone(), + key, + }, + ); + } + } } |