aboutsummaryrefslogtreecommitdiff
path: root/src/worker_ws.rs
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2025-05-17 17:23:29 +0200
committermetamuffin <metamuffin@disroot.org>2025-05-17 17:23:29 +0200
commit1c27a83409a7f51c5d07098cb6ca65bcee870d9c (patch)
treeed166844290443932c4fbf078d1cf65a1d79f833 /src/worker_ws.rs
downloadisda-1c27a83409a7f51c5d07098cb6ca65bcee870d9c.tar
isda-1c27a83409a7f51c5d07098cb6ca65bcee870d9c.tar.bz2
isda-1c27a83409a7f51c5d07098cb6ca65bcee870d9c.tar.zst
a
Diffstat (limited to 'src/worker_ws.rs')
-rw-r--r--src/worker_ws.rs169
1 files changed, 169 insertions, 0 deletions
diff --git a/src/worker_ws.rs b/src/worker_ws.rs
new file mode 100644
index 0000000..f645c11
--- /dev/null
+++ b/src/worker_ws.rs
@@ -0,0 +1,169 @@
+use crate::{State, Worker};
+use anyhow::{Result, bail};
+use axum::{
+ extract::{
+ State as S, WebSocketUpgrade,
+ ws::{Message, WebSocket},
+ },
+ response::IntoResponse,
+};
+use futures::{SinkExt, StreamExt};
+use log::warn;
+use serde::{Deserialize, Serialize};
+use serde_json::{Map, Value};
+use std::sync::Arc;
+use tokio::{
+ spawn,
+ sync::{RwLock, mpsc::channel},
+};
+
/// Identifier assigned to each connected worker from the monotonically
/// increasing `State::worker_id_counter`.
pub type WorkerID = u64;
+
/// A message received from a worker over the websocket.
///
/// Wire format: JSON object tagged by field `"t"` with the variant name in
/// snake_case, e.g. `{"t":"register","name":...,"sources":[...]}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "t", rename_all = "snake_case")]
pub enum WorkerRequest {
    /// Announce this worker's display name and the sources it serves.
    Register {
        name: String,
        sources: Vec<String>,
    },
    /// Merge `data` into the server-side metadata map stored under `key`
    /// (existing entries for other fields are kept; matching fields are
    /// overwritten).
    Metadata {
        key: String,
        data: Map<String, Value>,
    },
    /// Insert `key` into the work queue.
    Enqueue {
        key: String,
    },
    /// Mark `key` as done; only valid for a key currently in `loading`.
    Complete {
        key: String,
    },
    /// Signal willingness to accept one more unit of work
    /// (increments the worker's `accept` counter).
    Accept,

    /// Internal pseudo-request — never sent intentionally by well-behaved
    /// clients. Constructed locally when an incoming frame fails to parse,
    /// so the parse error is echoed back as a `WorkerResponse::Error`.
    _EchoError {
        message: String,
    },
}
+
/// A message sent to a worker over the websocket.
///
/// Wire format mirrors [`WorkerRequest`]: JSON tagged by `"t"` in snake_case.
#[derive(Debug, Serialize)]
#[serde(tag = "t", rename_all = "snake_case")]
pub enum WorkerResponse {
    /// Generic acknowledgement.
    Ok,
    /// A work assignment for `key`; `data` is presumably the metadata
    /// accumulated for that key — confirm at the send site (not in this file).
    Work {
        key: String,
        data: Map<String, Value>,
    },
    /// Failure report, including echoed request-parse errors.
    Error {
        message: String,
    },
}
+
+pub(crate) async fn worker_websocket(
+ ws: WebSocketUpgrade,
+ S(state): S<Arc<RwLock<State>>>,
+) -> impl IntoResponse {
+ ws.on_upgrade(|ws| worker_websocket_inner(ws, state))
+}
/// Drives a single worker websocket connection for its whole lifetime.
///
/// Registers a placeholder `Worker` under a fresh id, then runs two spawned
/// tasks: an outbound pump serializing queued [`WorkerResponse`]s to the
/// socket, and an inbound pump parsing text frames into [`WorkerRequest`]s
/// and applying them to `State`. When either task finishes, the other is
/// aborted and the connection ends.
///
/// NOTE(review): the `Worker` entry is never removed from `state.workers`
/// in this function — presumably cleanup happens elsewhere; confirm,
/// otherwise disconnected workers accumulate.
async fn worker_websocket_inner(ws: WebSocket, state: Arc<RwLock<State>>) {
    let (mut send, mut recv) = ws.split();
    // Bounded (16) channel feeding the send task. State::send_to_worker uses
    // try_send, so a full channel drops the response instead of blocking
    // while the state lock is held.
    let (tx, mut rx) = channel(16);

    // Allocate a fresh id and register the worker with placeholder identity;
    // a later `Register` request fills in the real name and sources.
    let worker = {
        let mut g = state.write().await;
        let id = g.worker_id_counter;
        g.worker_id_counter += 1;
        g.workers.insert(
            id,
            Worker {
                send: tx,
                accept: 0,
                name: "unknown".to_string(),
                sources: vec![],
            },
        );
        id
    };

    // Outbound pump: each queued response becomes one JSON text frame.
    // Serialization of WorkerResponse cannot fail, hence the unwrap.
    let mut send_task = spawn(async move {
        while let Some(m) = rx.recv().await {
            if let Err(e) = send
                .send(Message::Text(serde_json::to_string(&m).unwrap().into()))
                .await
            {
                warn!("error sending response: {e:?}");
                break;
            }
        }
    });
    // Inbound pump: only text frames are processed; binary/ping/pong/close
    // frames and transport errors are ignored (loop ends when the stream
    // yields None).
    let mut recv_task = spawn(async move {
        while let Some(message) = recv.next().await {
            if let Ok(m) = message {
                match m {
                    Message::Text(t) => {
                        // Malformed JSON is converted into an _EchoError
                        // request so the parse error travels back to the
                        // client through the normal response path.
                        let req = match serde_json::from_str::<WorkerRequest>(t.as_str()) {
                            Ok(v) => v,
                            Err(e) => {
                                warn!("error parsing request: {e:?}");
                                WorkerRequest::_EchoError {
                                    message: format!("{e:#}"),
                                }
                            }
                        };
                        if let Err(e) = state.write().await.handle_worker_message(worker, req) {
                            warn!("error processing request: {e:?}");
                            // Processing failures are also reported back.
                            state.write().await.send_to_worker(
                                worker,
                                WorkerResponse::Error {
                                    message: format!("{e:?}"),
                                },
                            );
                        }
                    }
                    _ => (),
                }
            }
        }
    });

    // Whichever pump finishes first (socket closed, send failure), cancel
    // the other so neither task outlives the connection.
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    };
}
+
+impl State {
+ pub fn send_to_worker(&mut self, w: WorkerID, resp: WorkerResponse) {
+ if let Err(_) = self.workers[&w].send.try_send(resp) {
+ warn!("worker ws response overflow");
+ }
+ }
+ pub fn handle_worker_message(&mut self, w: WorkerID, req: WorkerRequest) -> Result<()> {
+ let worker = self.workers.get_mut(&w).unwrap();
+ match req {
+ WorkerRequest::Register { name, sources } => {
+ worker.name = name;
+ worker.sources = sources;
+ }
+ WorkerRequest::Metadata { key, data } => {
+ let m = self.metadata.entry(key).or_default();
+ m.extend(data);
+ }
+ WorkerRequest::Enqueue { key } => {
+ self.queue.insert(key);
+ }
+ WorkerRequest::Complete { key } => {
+ if self.loading.remove(&key) {
+ self.complete.insert(key);
+ } else {
+ bail!("was not loading")
+ }
+ }
+ WorkerRequest::Accept => {
+ worker.accept += 1;
+ }
+ WorkerRequest::_EchoError { message } => {
+ self.send_to_worker(w, WorkerResponse::Error { message });
+ }
+ }
+
+ Ok(())
+ }
+}