use futures_util::{SinkExt, StreamExt, TryFutureExt}; use listenfd::ListenFd; use log::error; use std::collections::HashMap; use std::convert::Infallible; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; use tokio::sync::{mpsc, RwLock}; use warp::hyper::Server; use warp::ws::{Message, WebSocket}; use warp::Filter; static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1); type Users = Arc>>>; #[tokio::main] async fn main() { env_logger::init_from_env("LOG"); let users = Users::default(); let users = warp::any().map(move || users.clone()); let signaling = warp::path("signaling") .and(warp::ws()) .and(users) .map(|ws: warp::ws::Ws, users| ws.on_upgrade(move |socket| user_connected(socket, users))); let static_routes = { let index = warp::get() .and(warp::path!()) .and(warp::fs::file("../client-web/public/index.html")); let assets = warp::path("_assets").and(warp::fs::dir("../client-web/public/assets")); index }; let routes = static_routes.or(signaling).or(four_oh_four_tm); // if listender fd is passed from the outside world, use it. let mut listenfd = ListenFd::from_env(); let server = if let Some(l) = listenfd.take_tcp_listener(0).unwrap() { Server::from_tcp(l).unwrap() } else { Server::bind(&([127, 0, 0, 1], 3030).into()) }; let service = warp::service(routes); server .serve(hyper::service::make_service_fn(|_| { let service = service.clone(); async move { Ok::<_, Infallible>(service) } })) .await .unwrap(); } async fn user_connected(ws: WebSocket, users: Users) { let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed); eprintln!("new chat user: {}", my_id); let (mut user_ws_tx, mut user_ws_rx) = ws.split(); let (tx, mut rx) = mpsc::unbounded_channel(); tokio::task::spawn(async move { while let Some(message) = rx.recv().await { user_ws_tx .send(message) .unwrap_or_else(|e| { eprintln!("websocket send error: {}", e); }) .await; } }); users.write().await.insert(my_id, tx); while let Some(result) = user_ws_rx.next().await { let msg = match result { Ok(msg) => msg, Err(e) => { error!("websocket error(uid={my_id}): {e}"); break; } }; user_message(my_id, msg, &users).await; } users.write().await.remove(&my_id); } async fn user_message(my_id: usize, msg: Message, users: &Users) { // Skip any non-Text messages... let msg = if let Ok(s) = msg.to_str() { s } else { return; }; let new_msg = format!(": {}", my_id, msg); for (&uid, tx) in users.read().await.iter() { if my_id != uid { if let Err(_disconnected) = tx.send(Message::text(new_msg.clone())) { // The tx is disconnected, our `user_disconnected` code // should be happening in another task, nothing more to // do here. } } } } async fn user_disconnected(my_id: usize, users: &Users) { eprintln!("good bye user: {}", my_id); }