aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/api.rs14
-rw-r--r--src/main.rs31
-rw-r--r--src/webui_ws.rs109
-rw-r--r--src/worker_ws.rs16
4 files changed, 162 insertions, 8 deletions
diff --git a/src/api.rs b/src/api.rs
new file mode 100644
index 0000000..4fd3888
--- /dev/null
+++ b/src/api.rs
@@ -0,0 +1,14 @@
+use crate::State;
+use axum::{Json, extract::State as S};
+use std::{collections::HashSet, sync::Arc};
+use tokio::sync::RwLock;
+
+pub async fn api_queue_json(S(state): S<Arc<RwLock<State>>>) -> Json<HashSet<String>> {
+ Json(state.read().await.queue.clone())
+}
+pub async fn api_loading_json(S(state): S<Arc<RwLock<State>>>) -> Json<HashSet<String>> {
+ Json(state.read().await.loading.clone())
+}
+pub async fn api_complete_json(S(state): S<Arc<RwLock<State>>>) -> Json<HashSet<String>> {
+ Json(state.read().await.complete.clone())
+}
diff --git a/src/main.rs b/src/main.rs
index 0bba960..f8f4fba 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,8 +1,11 @@
+pub mod api;
pub mod helper;
pub mod webui;
+pub mod webui_ws;
pub mod worker_ws;
use anyhow::Result;
+use api::{api_complete_json, api_loading_json, api_queue_json};
use axum::{Router, routing::get};
use log::{debug, info};
use serde_json::{Map, Value};
@@ -15,9 +18,13 @@ use tokio::{
fs::{File, read_to_string, rename},
io::AsyncWriteExt,
net::TcpListener,
- sync::{RwLock, mpsc::Sender},
+ sync::{
+ RwLock, broadcast,
+ mpsc::{self},
+ },
};
use webui::{webui, webui_style};
+use webui_ws::{WebuiEvent, webui_websocket};
use worker_ws::{WorkerID, WorkerResponse, worker_websocket};
pub struct Worker {
@@ -25,13 +32,13 @@ pub struct Worker {
name: String,
task_kinds: Vec<String>,
assigned_tasks: HashSet<String>,
- send: Sender<WorkerResponse>,
+ send: mpsc::Sender<WorkerResponse>,
}
-#[derive(Default)]
pub struct State {
worker_id_counter: WorkerID,
workers: HashMap<WorkerID, Worker>,
+ webui_broadcast: broadcast::Sender<Arc<WebuiEvent>>,
metadata: HashMap<String, Map<String, Value>>,
queue: HashSet<String>,
@@ -47,13 +54,31 @@ async fn main() -> Result<()> {
let router = Router::new()
.route("/", get(webui))
.route("/style.css", get(webui_style))
+ .route("/webui_ws", get(webui_websocket))
.route("/worker_ws", get(worker_websocket))
+ .route("/api/queue.json", get(api_queue_json))
+ .route("/api/complete.json", get(api_complete_json))
+ .route("/api/loading.json", get(api_loading_json))
.with_state(Arc::new(RwLock::new(state)));
let listener = TcpListener::bind("127.0.0.1:8080").await?;
axum::serve(listener, router).await?;
Ok(())
}
+impl Default for State {
+ fn default() -> Self {
+ Self {
+ worker_id_counter: Default::default(),
+ workers: Default::default(),
+ webui_broadcast: broadcast::channel(1024).0,
+ metadata: Default::default(),
+ queue: Default::default(),
+ loading: Default::default(),
+ complete: Default::default(),
+ }
+ }
+}
+
impl State {
pub async fn load(&mut self) -> Result<()> {
debug!("loading state");
diff --git a/src/webui_ws.rs b/src/webui_ws.rs
new file mode 100644
index 0000000..a1fd348
--- /dev/null
+++ b/src/webui_ws.rs
@@ -0,0 +1,109 @@
+use crate::{State, webui, worker_ws::WorkerID};
+use axum::{
+ extract::{
+ State as S, WebSocketUpgrade,
+ ws::{Message, WebSocket},
+ },
+ response::IntoResponse,
+};
+use log::warn;
+use serde::Serialize;
+use serde_json::Map;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+
+#[derive(Debug, Serialize)]
+#[serde(tag = "t", rename_all = "snake_case")]
+pub enum WebuiEvent {
+ UpdateWorker {
+ id: WorkerID,
+ html: String,
+ },
+ RemoveWorker {
+ id: WorkerID,
+ },
+ UpdateTask {
+ bin: TaskState,
+ key: String,
+ html: String,
+ },
+ RemoveTask {
+ key: String,
+ },
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "snake_case")]
+pub enum TaskState {
+ Queue,
+ Loading,
+ Complete,
+}
+
+pub(crate) async fn webui_websocket(
+ ws: WebSocketUpgrade,
+ S(state): S<Arc<RwLock<State>>>,
+) -> impl IntoResponse {
+ ws.on_upgrade(|ws| webui_websocket_inner(ws, state))
+}
+async fn webui_websocket_inner(mut ws: WebSocket, state: Arc<RwLock<State>>) {
+ let mut stream = state.read().await.webui_broadcast.subscribe();
+ while let Ok(ev) = stream.recv().await {
+ if let Err(e) = ws
+ .send(Message::Text(serde_json::to_string(&*ev).unwrap().into()))
+ .await
+ {
+ warn!("error sending update event: {e:?}");
+ }
+ }
+}
+impl State {
+ pub fn send_webui_task_removal(&self, key: &str) {
+ if self.webui_broadcast.receiver_count() > 0 {
+ let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::RemoveTask {
+ key: key.to_owned(),
+ }));
+ }
+ }
+ pub fn send_webui_task_update(&self, key: &str) {
+ if self.webui_broadcast.receiver_count() > 0 {
+ let state = if self.queue.contains(key) {
+ TaskState::Queue
+ } else if self.loading.contains(key) {
+ TaskState::Loading
+ } else {
+ TaskState::Complete
+ };
+
+ let default = Map::new();
+ let class = match state {
+ TaskState::Queue => "task queue",
+ TaskState::Loading => "task loading",
+ TaskState::Complete => "task complete",
+ };
+ let data = self.metadata.get(key).unwrap_or(&default);
+
+ let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::UpdateTask {
+ bin: state,
+ key: key.to_owned(),
+ html: webui::Task { class, data, key }.to_string(),
+ }));
+ }
+ }
+ pub fn send_webui_worker_removal(&self, id: WorkerID) {
+ if self.webui_broadcast.receiver_count() > 0 {
+ let _ = self
+ .webui_broadcast
+ .send(Arc::new(WebuiEvent::RemoveWorker { id }));
+ }
+ }
+ pub fn send_webui_worker_update(&self, id: WorkerID) {
+ if self.webui_broadcast.receiver_count() > 0 {
+ let w = &self.workers[&id];
+ let html = webui::Worker { id, w }.to_string();
+ let _ = self
+ .webui_broadcast
+ .send(Arc::new(WebuiEvent::UpdateWorker { id, html }));
+ }
+ }
+}
diff --git a/src/worker_ws.rs b/src/worker_ws.rs
index 5f0e47c..526de82 100644
--- a/src/worker_ws.rs
+++ b/src/worker_ws.rs
@@ -163,8 +163,9 @@ impl State {
worker.task_kinds = task_kinds;
}
WorkerRequest::Metadata { key, data } => {
- let m = self.metadata.entry(key).or_default();
+ let m = self.metadata.entry(key.clone()).or_default();
m.extend(data);
+ self.send_webui_task_update(&key);
}
WorkerRequest::Enqueue {
key,
@@ -173,12 +174,14 @@ impl State {
if ignore_complete {
if !self.loading.contains(&key) {
self.complete.remove(&key);
- self.queue.insert(key);
+ self.queue.insert(key.clone());
+ self.send_webui_task_update(&key);
self.dispatch_work();
}
} else {
if !(self.complete.contains(&key) || self.loading.contains(&key)) {
- self.queue.insert(key);
+ self.queue.insert(key.clone());
+ self.send_webui_task_update(&key);
self.dispatch_work();
}
}
@@ -187,11 +190,13 @@ impl State {
if force {
self.queue.remove(&key);
self.loading.remove(&key);
- self.complete.insert(key);
+ self.complete.insert(key.clone());
+ self.send_webui_task_update(&key);
} else {
if worker.assigned_tasks.remove(&key) {
self.loading.remove(&key);
- self.complete.insert(key);
+ self.complete.insert(key.clone());
+ self.send_webui_task_update(&key);
} else {
bail!("task was not assigned")
}
@@ -231,6 +236,7 @@ impl State {
for (w, key) in to_send {
self.queue.remove(&key);
self.loading.insert(key.clone());
+ self.send_webui_task_update(&key);
self.send_to_worker(
w,
WorkerResponse::Work {