From 1c27a83409a7f51c5d07098cb6ca65bcee870d9c Mon Sep 17 00:00:00 2001 From: metamuffin Date: Sat, 17 May 2025 17:23:29 +0200 Subject: a --- src/worker_ws.rs | 169 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 src/worker_ws.rs (limited to 'src/worker_ws.rs') diff --git a/src/worker_ws.rs b/src/worker_ws.rs new file mode 100644 index 0000000..f645c11 --- /dev/null +++ b/src/worker_ws.rs @@ -0,0 +1,169 @@ +use crate::{State, Worker}; +use anyhow::{Result, bail}; +use axum::{ + extract::{ + State as S, WebSocketUpgrade, + ws::{Message, WebSocket}, + }, + response::IntoResponse, +}; +use futures::{SinkExt, StreamExt}; +use log::warn; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use std::sync::Arc; +use tokio::{ + spawn, + sync::{RwLock, mpsc::channel}, +}; + +pub type WorkerID = u64; + +#[derive(Debug, Deserialize)] +#[serde(tag = "t", rename_all = "snake_case")] +pub enum WorkerRequest { + Register { + name: String, + sources: Vec, + }, + Metadata { + key: String, + data: Map, + }, + Enqueue { + key: String, + }, + Complete { + key: String, + }, + Accept, + + _EchoError { + message: String, + }, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "t", rename_all = "snake_case")] +pub enum WorkerResponse { + Ok, + Work { + key: String, + data: Map, + }, + Error { + message: String, + }, +} + +pub(crate) async fn worker_websocket( + ws: WebSocketUpgrade, + S(state): S>>, +) -> impl IntoResponse { + ws.on_upgrade(|ws| worker_websocket_inner(ws, state)) +} +async fn worker_websocket_inner(ws: WebSocket, state: Arc>) { + let (mut send, mut recv) = ws.split(); + let (tx, mut rx) = channel(16); + + let worker = { + let mut g = state.write().await; + let id = g.worker_id_counter; + g.worker_id_counter += 1; + g.workers.insert( + id, + Worker { + send: tx, + accept: 0, + name: "unknown".to_string(), + sources: vec![], + }, + ); + id + }; + + let mut send_task = spawn(async move { + while let Some(m) = rx.recv().await { + if let Err(e) = send + .send(Message::Text(serde_json::to_string(&m).unwrap().into())) + .await + { + warn!("error sending response: {e:?}"); + break; + } + } + }); + let mut recv_task = spawn(async move { + while let Some(message) = recv.next().await { + if let Ok(m) = message { + match m { + Message::Text(t) => { + let req = match serde_json::from_str::(t.as_str()) { + Ok(v) => v, + Err(e) => { + warn!("error parsing request: {e:?}"); + WorkerRequest::_EchoError { + message: format!("{e:#}"), + } + } + }; + if let Err(e) = state.write().await.handle_worker_message(worker, req) { + warn!("error processing request: {e:?}"); + state.write().await.send_to_worker( + worker, + WorkerResponse::Error { + message: format!("{e:?}"), + }, + ); + } + } + _ => (), + } + } + } + }); + + tokio::select! { + _ = &mut send_task => recv_task.abort(), + _ = &mut recv_task => send_task.abort(), + }; +} + +impl State { + pub fn send_to_worker(&mut self, w: WorkerID, resp: WorkerResponse) { + if let Err(_) = self.workers[&w].send.try_send(resp) { + warn!("worker ws response overflow"); + } + } + pub fn handle_worker_message(&mut self, w: WorkerID, req: WorkerRequest) -> Result<()> { + let worker = self.workers.get_mut(&w).unwrap(); + match req { + WorkerRequest::Register { name, sources } => { + worker.name = name; + worker.sources = sources; + } + WorkerRequest::Metadata { key, data } => { + let m = self.metadata.entry(key).or_default(); + m.extend(data); + } + WorkerRequest::Enqueue { key } => { + self.queue.insert(key); + } + WorkerRequest::Complete { key } => { + if self.loading.remove(&key) { + self.complete.insert(key); + } else { + bail!("was not loading") + } + } + WorkerRequest::Accept => { + worker.accept += 1; + } + WorkerRequest::_EchoError { message } => { + self.send_to_worker(w, WorkerResponse::Error { message }); + } + } + + Ok(()) + } +} -- cgit v1.2.3-70-g09d2