aboutsummaryrefslogtreecommitdiff
path: root/src/webui_ws.rs
blob: 6a19e224e47b50f515de84e24862bd23f830a1cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use crate::{State, webui, worker_ws::WorkerID};
use axum::{
    extract::{
        State as S, WebSocketUpgrade,
        ws::{Message, WebSocket},
    },
    response::IntoResponse,
};
use log::warn;
use serde::Serialize;
use serde_json::Map;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::sync::broadcast::error::RecvError;

/// An update broadcast to connected web UI clients over the websocket.
///
/// Serialized as an internally tagged JSON object. The snake_case variant name
/// goes in the `"t"` field and the variant's fields sit alongside it, e.g.
/// `{"t":"remove_task","key":"..."}`. Field order here is the order fields
/// appear in the emitted JSON.
#[derive(Debug, Serialize)]
#[serde(tag = "t", rename_all = "snake_case")]
pub enum WebuiEvent {
    /// Current sizes of the queue, loading and complete task collections.
    Counters {
        queue: usize,
        loading: usize,
        complete: usize,
    },
    /// Freshly rendered HTML for worker `id`.
    /// NOTE(review): the client presumably inserts or replaces its entry — confirm in the frontend.
    UpdateWorker {
        id: WorkerID,
        html: String,
    },
    /// Worker `id` should be removed from the UI.
    RemoveWorker {
        id: WorkerID,
    },
    /// Freshly rendered HTML for task `key`, together with the list (`bin`) it belongs in.
    UpdateTask {
        bin: TaskState,
        key: String,
        html: String,
    },
    /// Task `key` should be removed from the UI.
    RemoveTask {
        key: String,
    },
}

/// Which task list ("bin") a task is shown in on the web UI.
///
/// Serialized as the snake_case variant name: `"queue"`, `"loading"` or
/// `"complete"`. It is a plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskState {
    /// The task key is present in `State::queue`.
    Queue,
    /// The task key is present in `State::loading`.
    Loading,
    /// The task key is in neither of the above; it is shown as complete.
    Complete,
}

pub(crate) async fn webui_websocket(
    ws: WebSocketUpgrade,
    S(state): S<Arc<RwLock<State>>>,
) -> impl IntoResponse {
    ws.on_upgrade(|ws| webui_websocket_inner(ws, state))
}
async fn webui_websocket_inner(mut ws: WebSocket, state: Arc<RwLock<State>>) {
    let mut stream = state.read().await.webui_broadcast.subscribe();
    while let Ok(ev) = stream.recv().await {
        if let Err(e) = ws
            .send(Message::Text(serde_json::to_string(&*ev).unwrap().into()))
            .await
        {
            warn!("error sending update event: {e:?}");
            break;
        }
    }
}
impl State {
    pub fn send_webui_task_removal(&self, key: &str) {
        if self.webui_broadcast.receiver_count() > 0 {
            let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::RemoveTask {
                key: key.to_owned(),
            }));
        }
    }
    pub fn send_webui_task_update(&self, key: &str) {
        if self.webui_broadcast.receiver_count() > 0 {
            let state = if self.queue.contains(key) {
                TaskState::Queue
            } else if self.loading.contains(key) {
                TaskState::Loading
            } else {
                TaskState::Complete
            };

            let default = Map::new();
            let class = match state {
                TaskState::Queue => "task queue",
                TaskState::Loading => "task loading",
                TaskState::Complete => "task complete",
            };
            let data = self.metadata.get(key).unwrap_or(&default);

            let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::UpdateTask {
                bin: state,
                key: key.to_owned(),
                html: webui::Task { class, data, key }.to_string(),
            }));
            let _ = self.webui_broadcast.send(Arc::new(WebuiEvent::Counters {
                queue: self.queue.len(),
                loading: self.loading.len(),
                complete: self.complete.len(),
            }));
        }
    }
    pub fn send_webui_worker_removal(&self, id: WorkerID) {
        if self.webui_broadcast.receiver_count() > 0 {
            let _ = self
                .webui_broadcast
                .send(Arc::new(WebuiEvent::RemoveWorker { id }));
        }
    }
    pub fn send_webui_worker_update(&self, id: WorkerID) {
        if self.webui_broadcast.receiver_count() > 0 {
            let w = &self.workers[&id];
            let html = webui::Worker { id, w }.to_string();
            let _ = self
                .webui_broadcast
                .send(Arc::new(WebuiEvent::UpdateWorker { id, html }));
        }
    }
}