use crate::{State, webui, worker_ws::WorkerID}; use axum::{ extract::{ State as S, WebSocketUpgrade, ws::{Message, WebSocket}, }, response::IntoResponse, }; use log::warn; use serde::Serialize; use serde_json::Map; use std::sync::Arc; use tokio::sync::RwLock; #[derive(Debug, Serialize)] #[serde(tag = "t", rename_all = "snake_case")] pub enum WebuiEvent { Counters { queue: usize, loading: usize, complete: usize, }, UpdateWorker { id: WorkerID, html: String, }, RemoveWorker { id: WorkerID, }, UpdateTask { bin: TaskState, key: String, html: String, }, RemoveTask { key: String, }, } #[derive(Debug, Serialize)] #[serde(rename_all = "snake_case")] pub enum TaskState { Queue, Loading, Complete, } pub(crate) async fn webui_websocket( ws: WebSocketUpgrade, S(state): S>>, ) -> impl IntoResponse { ws.on_upgrade(|ws| webui_websocket_inner(ws, state)) } async fn webui_websocket_inner(mut ws: WebSocket, state: Arc>) { let mut stream = state.read().await.webui_broadcast.subscribe(); while let Ok(ev) = stream.recv().await { if let Err(e) = ws .send(Message::Text(serde_json::to_string(&*ev).unwrap().into())) .await { warn!("error sending update event: {e:?}"); } } } impl State { pub fn send_webui_task_removal(&self, key: &str) { if self.webui_broadcast.receiver_count() > 0 { let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::RemoveTask { key: key.to_owned(), })); } } pub fn send_webui_task_update(&self, key: &str) { if self.webui_broadcast.receiver_count() > 0 { let state = if self.queue.contains(key) { TaskState::Queue } else if self.loading.contains(key) { TaskState::Loading } else { TaskState::Complete }; let default = Map::new(); let class = match state { TaskState::Queue => "task queue", TaskState::Loading => "task loading", TaskState::Complete => "task complete", }; let data = self.metadata.get(key).unwrap_or(&default); let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::UpdateTask { bin: state, key: key.to_owned(), html: webui::Task { class, data, key }.to_string(), })); let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::Counters { queue: self.queue.len(), loading: self.loading.len(), complete: self.complete.len(), })); } } pub fn send_webui_worker_removal(&self, id: WorkerID) { if self.webui_broadcast.receiver_count() > 0 { let _ = self .webui_broadcast .send(Arc::new(WebuiEvent::RemoveWorker { id })); } } pub fn send_webui_worker_update(&self, id: WorkerID) { if self.webui_broadcast.receiver_count() > 0 { let w = &self.workers[&id]; let html = webui::Worker { id, w }.to_string(); let _ = self .webui_broadcast .send(Arc::new(WebuiEvent::UpdateWorker { id, html })); } } }