diff options
Diffstat (limited to 'server/src')
-rw-r--r-- | server/src/main.rs | 19 | ||||
-rw-r--r-- | server/src/room.rs | 66 |
2 files changed, 62 insertions, 23 deletions
diff --git a/server/src/main.rs b/server/src/main.rs index 6a8f11d..268e2a4 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,18 +1,19 @@ pub mod protocol; pub mod room; -use chashmap::CHashMap; use hyper::StatusCode; use listenfd::ListenFd; use log::error; use room::Room; +use std::collections::HashMap; use std::convert::Infallible; use std::sync::Arc; +use tokio::sync::RwLock; use warp::hyper::Server; use warp::ws::WebSocket; use warp::{Filter, Rejection, Reply}; -type Rooms = Arc<CHashMap<String, Arc<Room>>>; +type Rooms = Arc<RwLock<HashMap<String, Arc<Room>>>>; fn main() { tokio::runtime::Builder::new_multi_thread() @@ -82,16 +83,20 @@ async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> { fn signaling_connect(rname: String, rooms: Rooms, ws: warp::ws::Ws) -> impl Reply { async fn inner(sock: WebSocket, rname: String, rooms: Rooms) { - let room = match rooms.get(&rname) { - Some(r) => r, + let guard = rooms.read().await; + let room = match guard.get(&rname) { + Some(r) => r.to_owned(), None => { - rooms.insert(rname.to_owned(), Default::default()); - rooms.get(&rname).unwrap() // TODO never expect this to always work!! + drop(guard); // make sure read-lock is dropped to avoid deadlock + let mut guard = rooms.write().await; + guard.insert(rname.to_owned(), Default::default()); + guard.get(&rname).unwrap().to_owned() // TODO never expect this to always work!! } }; + room.client_connect(sock).await; if room.should_remove().await { - rooms.remove(&rname); + rooms.write().await.remove(&rname); } } ws.on_upgrade(move |sock| inner(sock, rname, rooms)) diff --git a/server/src/room.rs b/server/src/room.rs index a47d2e5..db545a6 100644 --- a/server/src/room.rs +++ b/server/src/room.rs @@ -1,7 +1,7 @@ use crate::protocol::{ClientboundPacket, ServerboundPacket}; use futures_util::{SinkExt, StreamExt, TryFutureExt}; -use log::error; -use std::collections::HashMap; +use log::{debug, error}; +use std::{collections::HashMap, sync::atomic::AtomicUsize}; use tokio::sync::{mpsc, RwLock}; use warp::ws::{Message, WebSocket}; @@ -13,17 +13,36 @@ pub struct Client { #[derive(Debug, Default)] pub struct Room { + pub id_counter: AtomicUsize, pub clients: RwLock<HashMap<usize, Client>>, } impl Room { pub async fn client_connect(&self, ws: WebSocket) { + debug!("new client connected"); let (mut user_ws_tx, mut user_ws_rx) = ws.split(); let (tx, mut rx) = mpsc::unbounded_channel(); + let mut g = self.clients.write().await; + // ensure write guard to client exists when using id_counter + let id = self + .id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let name = format!("user no. {id}"); + g.insert( + id, + Client { + out: tx, + name: name.clone(), + }, + ); + drop(g); + debug!("assigned id={id}, init connection"); + tokio::task::spawn(async move { while let Some(packet) = rx.recv().await { + debug!("{id} -> {packet:?}"); user_ws_tx .send(Message::text(serde_json::to_string(&packet).unwrap())) .unwrap_or_else(|e| { @@ -33,20 +52,33 @@ impl Room { } }); - let mut g = self.clients.write().await; - let id = g.len(); - let name = format!("user no. {id}"); - g.insert( + self.send_to_client( id, - Client { - out: tx, - name: name.clone(), + ClientboundPacket::Init { + your_id: id, + version: format!("keks-meet {}", env!("CARGO_PKG_VERSION")), }, - ); - drop(g); + ) + .await; - self.broadcast(id, ClientboundPacket::ClientJoin { id, name }) + // send join of this client to all clients + self.broadcast(None, ClientboundPacket::ClientJoin { id, name }) .await; + // send join of all other clients to this one + for (&cid, c) in self.clients.read().await.iter() { + // skip self + if cid != id { + self.send_to_client( + id, + ClientboundPacket::ClientJoin { + id: cid, + name: c.name.clone(), + }, + ) + .await; + } + } + debug!("client should be ready!"); while let Some(result) = user_ws_rx.next().await { let msg = match result { @@ -64,16 +96,18 @@ impl Room { break; } }; + debug!("{id} <- {p:?}"); self.client_message(id, p).await; - }; + } } - self.clients.write().await.remove(&id); + self.broadcast(Some(id), ClientboundPacket::ClientLeave { id }) + .await; } - pub async fn broadcast(&self, sender: usize, packet: ClientboundPacket) { + pub async fn broadcast(&self, sender: Option<usize>, packet: ClientboundPacket) { for (&id, tx) in self.clients.read().await.iter() { - if sender != id { + if sender != Some(id) { let _ = tx.out.send(packet.clone()); } } |