use crate::{State, Worker}; use anyhow::{Result, bail}; use axum::{ extract::{ State as S, WebSocketUpgrade, ws::{Message, WebSocket}, }, response::IntoResponse, }; use futures::{SinkExt, StreamExt}; use log::warn; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::sync::Arc; use tokio::{ spawn, sync::{RwLock, mpsc::channel}, }; pub type WorkerID = u64; #[derive(Debug, Deserialize)] #[serde(tag = "t", rename_all = "snake_case")] pub enum WorkerRequest { Register { name: String, sources: Vec, }, Metadata { key: String, data: Map, }, Enqueue { key: String, }, Complete { key: String, }, Accept, Save, _EchoError { message: String, }, } #[derive(Debug, Serialize)] #[serde(tag = "t", rename_all = "snake_case")] pub enum WorkerResponse { Ok, Work { key: String, data: Map, }, Error { message: String, }, } pub(crate) async fn worker_websocket( ws: WebSocketUpgrade, S(state): S>>, ) -> impl IntoResponse { ws.on_upgrade(|ws| worker_websocket_inner(ws, state)) } async fn worker_websocket_inner(ws: WebSocket, state: Arc>) { let (mut send, mut recv) = ws.split(); let (tx, mut rx) = channel(16); let worker = { let mut g = state.write().await; let id = g.worker_id_counter; g.worker_id_counter += 1; g.workers.insert( id, Worker { send: tx, accept: 0, name: "unknown".to_string(), sources: vec![], }, ); id }; let mut send_task = spawn(async move { while let Some(m) = rx.recv().await { if let Err(e) = send .send(Message::Text(serde_json::to_string(&m).unwrap().into())) .await { warn!("error sending response: {e:?}"); break; } } }); let state2 = state.clone(); let mut recv_task = spawn(async move { while let Some(message) = recv.next().await { if let Ok(m) = message { match m { Message::Text(t) => { let req = match serde_json::from_str::(t.as_str()) { Ok(v) => v, Err(e) => { warn!("error parsing request: {e:?}"); WorkerRequest::_EchoError { message: format!("{e:#}"), } } }; if let Err(e) = state.write().await.handle_worker_message(worker, req).await { warn!("error processing request: {e:?}"); state.write().await.send_to_worker( worker, WorkerResponse::Error { message: format!("{e:?}"), }, ); } } _ => (), } } } }); tokio::select! { _ = &mut send_task => recv_task.abort(), _ = &mut recv_task => send_task.abort(), }; { let mut g = state2.write().await; g.workers.remove(&worker).unwrap(); } } impl State { pub fn send_to_worker(&mut self, w: WorkerID, resp: WorkerResponse) { if let Err(_) = self.workers[&w].send.try_send(resp) { warn!("worker ws response overflow"); } } pub async fn handle_worker_message(&mut self, w: WorkerID, req: WorkerRequest) -> Result<()> { let worker = self.workers.get_mut(&w).unwrap(); match req { WorkerRequest::Register { name, sources } => { worker.name = name; worker.sources = sources; } WorkerRequest::Metadata { key, data } => { let m = self.metadata.entry(key).or_default(); m.extend(data); } WorkerRequest::Enqueue { key } => { self.queue.insert(key); } WorkerRequest::Complete { key } => { if self.loading.remove(&key) { self.complete.insert(key); } else { bail!("was not loading") } } WorkerRequest::Accept => { worker.accept += 1; } WorkerRequest::Save => { self.save().await?; } WorkerRequest::_EchoError { message } => { self.send_to_worker(w, WorkerResponse::Error { message }); } } Ok(()) } }