Diffstat (limited to 'src/worker_ws.rs')
-rw-r--r--  src/worker_ws.rs | 58
1 file changed, 49 insertions(+), 9 deletions(-)
diff --git a/src/worker_ws.rs b/src/worker_ws.rs
index 2d82e4d..7b76476 100644
--- a/src/worker_ws.rs
+++ b/src/worker_ws.rs
@@ -8,10 +8,10 @@ use axum::{
     response::IntoResponse,
 };
 use futures::{SinkExt, StreamExt};
-use log::warn;
+use log::{debug, warn};
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
-use std::sync::Arc;
+use std::{collections::HashSet, sync::Arc};
 use tokio::{
     spawn,
     sync::{RwLock, mpsc::channel},
@@ -24,7 +24,7 @@ pub type WorkerID = u64;
 pub enum WorkerRequest {
     Register {
         name: String,
-        sources: Vec<String>,
+        task_kinds: Vec<String>,
     },
     Metadata {
         key: String,
@@ -77,7 +77,8 @@ async fn worker_websocket_inner(ws: WebSocket, state: Arc<RwLock<State>>) {
                 send: tx,
                 accept: 0,
                 name: "unknown".to_string(),
-                sources: vec![],
+                task_kinds: vec![],
+                assigned_tasks: HashSet::new(),
             },
         );
         id
@@ -133,22 +134,29 @@ async fn worker_websocket_inner(ws: WebSocket, state: Arc<RwLock<State>>) {
     {
         let mut g = state2.write().await;
-        g.workers.remove(&worker).unwrap();
+        let w = g.workers.remove(&worker).unwrap();
+        // recycle incomplete tasks
+        for key in w.assigned_tasks {
+            g.loading.remove(&key);
+            g.queue.insert(key);
+        }
     }
 }
 impl State {
     pub fn send_to_worker(&mut self, w: WorkerID, resp: WorkerResponse) {
+        debug!("{w} -> {resp:?}");
         if let Err(_) = self.workers[&w].send.try_send(resp) {
             warn!("worker ws response overflow");
         }
     }
     pub async fn handle_worker_message(&mut self, w: WorkerID, req: WorkerRequest) -> Result<()> {
+        debug!("{w} <- {req:?}");
         let worker = self.workers.get_mut(&w).unwrap();
         match req {
-            WorkerRequest::Register { name, sources } => {
+            WorkerRequest::Register { name, task_kinds } => {
                 worker.name = name;
-                worker.sources = sources;
+                worker.task_kinds = task_kinds;
             }
             WorkerRequest::Metadata { key, data } => {
                 let m = self.metadata.entry(key).or_default();
@@ -156,16 +164,19 @@ impl State {
             }
             WorkerRequest::Enqueue { key } => {
                 self.queue.insert(key);
+                self.dispatch_work();
             }
             WorkerRequest::Complete { key } => {
-                if self.loading.remove(&key) {
+                if worker.assigned_tasks.remove(&key) {
+                    self.loading.remove(&key);
                     self.complete.insert(key);
                 } else {
-                    bail!("was not loading")
+                    bail!("task was not assigned")
                 }
             }
             WorkerRequest::Accept => {
                 worker.accept += 1;
+                self.dispatch_work();
             }
             WorkerRequest::Save => {
                 self.save().await?;
@@ -177,4 +188,33 @@ impl State {
         Ok(())
     }
+
+    pub fn dispatch_work(&mut self) {
+        let mut to_send = Vec::new();
+        for (id, w) in &mut self.workers {
+            if w.accept >= 1 {
+                for kind in &w.task_kinds {
+                    let prefix = format!("{kind}:");
+                    let Some(first) = self.queue.iter().find(|e| e.starts_with(&prefix)).cloned()
+                    else {
+                        continue;
+                    };
+                    w.accept -= 1;
+                    w.assigned_tasks.insert(first.clone());
+                    to_send.push((*id, first));
+                }
+            }
+        }
+        for (w, key) in to_send {
+            self.queue.remove(&key);
+            self.loading.insert(key.clone());
+            self.send_to_worker(
+                w,
+                WorkerResponse::Work {
+                    data: self.metadata[&key].clone(),
+                    key,
+                },
+            );
+        }
+    }
 }
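
A note on the new dispatch path: queued task keys are expected to follow a "<kind>:<rest>" naming convention, and each Accept message grants the worker one assignment credit. Because the accept >= 1 check sits outside the per-kind loop and claimed keys only leave self.queue after the worker loop finishes, a single pass can in principle hand a worker more tasks than it accepted, or give the same key to two workers with overlapping task_kinds. Below is a minimal standalone sketch of a selection step that re-checks the credit and removes keys as they are claimed; pick_tasks, the BTreeSet queue type, and the example keys are illustrative assumptions, not code from this repository.

use std::collections::BTreeSet;

/// Pick at most `credits` queued keys for a worker handling `kinds`,
/// following the "<kind>:<rest>" key convention used by dispatch_work.
/// Keys are removed from `queue` as they are claimed, so the same key is
/// never handed out twice, and the credit is re-checked on every pick.
fn pick_tasks(queue: &mut BTreeSet<String>, kinds: &[String], mut credits: u64) -> Vec<String> {
    let mut picked = Vec::new();
    for kind in kinds {
        if credits == 0 {
            break;
        }
        let prefix = format!("{kind}:");
        if let Some(key) = queue.iter().find(|k| k.starts_with(&prefix)).cloned() {
            queue.remove(&key);
            credits -= 1;
            picked.push(key);
        }
    }
    picked
}

fn main() {
    let mut queue: BTreeSet<String> = ["scan:alpha", "scan:beta", "index:gamma"]
        .into_iter()
        .map(String::from)
        .collect();
    // A worker registered for both kinds but holding a single credit
    // receives exactly one task.
    let picked = pick_tasks(&mut queue, &["scan".to_string(), "index".to_string()], 1);
    assert_eq!(picked, vec!["scan:alpha".to_string()]);
    assert_eq!(queue.len(), 2);
}

In worker_ws.rs itself the equivalent guard would sit inside the worker loop, with candidate keys also checked against to_send before being pushed, so that queue removal can stay deferred until after the iteration over self.workers.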