use crate::{State, Worker}; use anyhow::{Result, bail}; use axum::{ extract::{ State as S, WebSocketUpgrade, ws::{Message, WebSocket}, }, response::IntoResponse, }; use futures::{SinkExt, StreamExt}; use log::{debug, warn}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::{collections::HashSet, sync::Arc}; use tokio::{ spawn, sync::{RwLock, mpsc::channel}, }; pub type WorkerID = u64; #[derive(Debug, Deserialize)] #[serde(tag = "t", rename_all = "snake_case")] pub enum WorkerRequest { Register { name: String, task_kinds: Vec, }, Metadata { key: String, data: Map, }, Enqueue { key: String, #[serde(default)] ignore_complete: bool, }, Complete { key: String, #[serde(default)] force: bool, }, Accept, Save, _EchoError { message: String, }, } #[derive(Debug, Serialize)] #[serde(tag = "t", rename_all = "snake_case")] pub enum WorkerResponse { Ok, Work { key: String, data: Map, }, Config { config: Value, }, Error { message: String, }, } pub(crate) async fn worker_websocket( ws: WebSocketUpgrade, S(state): S>>, ) -> impl IntoResponse { ws.on_upgrade(|ws| worker_websocket_inner(ws, state)) } async fn worker_websocket_inner(ws: WebSocket, state: Arc>) { let (mut send, mut recv) = ws.split(); let (tx, mut rx) = channel(16); let worker = { let mut g = state.write().await; tx.send(WorkerResponse::Config { config: g.config.clone(), }) .await .unwrap(); let id = g.worker_id_counter; g.worker_id_counter += 1; g.workers.insert( id, Worker { send: tx, accept: 0, name: "unknown".to_string(), task_kinds: vec![], assigned_tasks: HashSet::new(), }, ); g.send_webui_worker_update(id); id }; let mut send_task = spawn(async move { while let Some(m) = rx.recv().await { if let Err(e) = send .send(Message::Text(serde_json::to_string(&m).unwrap().into())) .await { warn!("error sending response: {e:?}"); break; } } }); let state2 = state.clone(); let mut recv_task = spawn(async move { while let Some(message) = recv.next().await { if let Ok(m) = message { match m { Message::Text(t) => { let req = match serde_json::from_str::(t.as_str()) { Ok(v) => v, Err(e) => { warn!("error parsing request: {e:?}"); WorkerRequest::_EchoError { message: format!("{e:#}"), } } }; if let Err(e) = state.write().await.handle_worker_message(worker, req).await { warn!("error processing request: {e:?}"); state.write().await.send_to_worker( worker, WorkerResponse::Error { message: format!("{e:?}"), }, ); } } _ => (), } } } }); tokio::select! { _ = &mut send_task => recv_task.abort(), _ = &mut recv_task => send_task.abort(), }; { let mut g = state2.write().await; let w = g.workers.remove(&worker).unwrap(); // recycle incomplete tasks for key in w.assigned_tasks { g.loading.remove(&key); g.queue.insert(key.clone()); g.send_webui_task_update(&key); } g.send_webui_worker_removal(worker); } } impl State { pub fn send_to_worker(&mut self, w: WorkerID, resp: WorkerResponse) { debug!("{w} -> {resp:?}"); if let Err(_) = self.workers[&w].send.try_send(resp) { warn!("worker ws response overflow"); } } pub async fn handle_worker_message(&mut self, w: WorkerID, req: WorkerRequest) -> Result<()> { debug!("{w} <- {req:?}"); let worker = self.workers.get_mut(&w).unwrap(); match req { WorkerRequest::Register { name, task_kinds } => { worker.name = name; worker.task_kinds = task_kinds; self.send_webui_worker_update(w); } WorkerRequest::Metadata { key, data } => { let m = self.metadata.entry(key.clone()).or_default(); m.extend(data); self.send_webui_task_update(&key); } WorkerRequest::Enqueue { key, ignore_complete, } => { if ignore_complete { if !self.loading.contains(&key) { self.complete.remove(&key); self.queue.insert(key.clone()); self.send_webui_task_update(&key); self.dispatch_work(); } } else { if !(self.complete.contains(&key) || self.loading.contains(&key)) { self.queue.insert(key.clone()); self.send_webui_task_update(&key); self.dispatch_work(); } } } WorkerRequest::Complete { key, force } => { if force { self.queue.remove(&key); self.loading.remove(&key); self.complete.insert(key.clone()); self.send_webui_task_update(&key); } else { if worker.assigned_tasks.remove(&key) { self.loading.remove(&key); self.complete.insert(key.clone()); self.send_webui_task_update(&key); self.send_webui_worker_update(w); } else { bail!("task was not assigned") } } } WorkerRequest::Accept => { worker.accept += 1; self.dispatch_work(); } WorkerRequest::Save => { self.save().await?; } WorkerRequest::_EchoError { message } => { self.send_to_worker(w, WorkerResponse::Error { message }); } } Ok(()) } pub fn dispatch_work(&mut self) { let mut to_send = Vec::new(); for (id, w) in &mut self.workers { if w.accept >= 1 { for kind in &w.task_kinds { let prefix = format!("{kind}:"); let Some(first) = self.queue.iter().find(|e| e.starts_with(&prefix)).cloned() else { continue; }; w.accept -= 1; w.assigned_tasks.insert(first.clone()); to_send.push((*id, first)); } } } for (w, key) in to_send { self.queue.remove(&key); self.loading.insert(key.clone()); self.send_webui_task_update(&key); self.send_webui_worker_update(w); self.send_to_worker( w, WorkerResponse::Work { data: self.metadata[&key].clone(), key, }, ); } } }