From d3e3cfdaa53ae354b1de4f540426af3e6af8096d Mon Sep 17 00:00:00 2001 From: metamuffin Date: Mon, 18 Mar 2024 09:51:38 +0100 Subject: reworking rift: part two --- client-native-rift/src/file.rs | 112 ++++++++++++++++++- client-native-rift/src/main.rs | 244 ++++++++++++++++------------------------- client-native-rift/src/port.rs | 66 ++++++++++- 3 files changed, 265 insertions(+), 157 deletions(-) (limited to 'client-native-rift/src') diff --git a/client-native-rift/src/file.rs b/client-native-rift/src/file.rs index bfe32fd..93dde62 100644 --- a/client-native-rift/src/file.rs +++ b/client-native-rift/src/file.rs @@ -1,8 +1,12 @@ +use crate::RequestHandler; use bytes::Bytes; use humansize::DECIMAL; -use libkeks::{peer::Peer, protocol::ProvideInfo, DynFut, LocalResource}; +use libkeks::{ + peer::Peer, protocol::ProvideInfo, webrtc::data_channel::RTCDataChannel, DynFut, LocalResource, +}; use log::{debug, error, info}; use std::{ + future::Future, path::PathBuf, pin::Pin, sync::{ @@ -11,8 +15,8 @@ use std::{ }, }; use tokio::{ - fs::File, - io::{AsyncRead, AsyncReadExt}, + fs::{File, OpenOptions}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, sync::RwLock, }; @@ -107,3 +111,105 @@ impl LocalResource for FileSender { }) } } + +pub struct DownloadHandler { + pub path: Option, +} +impl RequestHandler for DownloadHandler { + fn on_connect( + &self, + resource: ProvideInfo, + channel: Arc, + ) -> Pin> + Send + Sync>> { + let path = self.path.clone().unwrap_or_else(|| { + resource + .label + .clone() + .unwrap_or("download".to_owned()) + .replace("/", "_") + .replace("..", "_") + .into() + }); + if path.exists() {} + Box::pin(async move { + let pos = Arc::new(AtomicUsize::new(0)); + let writer: Arc>>>> = + Arc::new(RwLock::new(None)); + { + let writer = writer.clone(); + let path = path.clone(); + let channel2 = channel.clone(); + channel.on_open(Box::new(move || { + let path = path.clone(); + let writer = writer.clone(); + Box::pin(async move { + info!("channel opened"); + match OpenOptions::new() + .write(true) + .read(false) + .create_new(true) + .open(path) + .await + { + Ok(file) => { + *writer.write().await = Some(Box::pin(file)); + } + Err(e) => { + error!("cannot write download: {e}"); + channel2.close().await.unwrap(); + } + } + }) + })); + } + { + let writer = writer.clone(); + channel.on_close(Box::new(move || { + let writer = writer.clone(); + Box::pin(async move { + info!("channel closed"); + *writer.write().await = None; + }) + })); + } + { + let writer = writer.clone(); + channel.on_message(Box::new(move |mesg| { + let writer = writer.clone(); + let pos = pos.clone(); + Box::pin(async move { + // TODO + if mesg.is_string { + let s = String::from_utf8((&mesg.data).to_vec()).unwrap(); + if &s == "end" { + info!("transfer complete") + } + } else { + let pos = pos.fetch_add(mesg.data.len(), Ordering::Relaxed); + info!( + "recv {:?} ({} of {})", + mesg.data.len(), + humansize::format_size(pos, DECIMAL), + humansize::format_size(resource.size.unwrap_or(0), DECIMAL), + ); + writer + .write() + .await + .as_mut() + .unwrap() + .write_all(&mesg.data) + .await + .unwrap(); + } + }) + })) + } + channel.on_error(Box::new(move |err| { + Box::pin(async move { + error!("data channel errored: {err}"); + }) + })); + Ok(()) + }) + } +} diff --git a/client-native-rift/src/main.rs b/client-native-rift/src/main.rs index 57a4947..c44f8c2 100644 --- a/client-native-rift/src/main.rs +++ b/client-native-rift/src/main.rs @@ -8,18 +8,22 @@ pub mod file; pub mod port; use clap::{ColorChoice, Parser}; -use file::FileSender; +use file::{DownloadHandler, FileSender}; use libkeks::{ instance::Instance, peer::{Peer, TransportChannel}, protocol::ProvideInfo, + webrtc::data_channel::RTCDataChannel, Config, DynFut, EventHandler, }; -use log::{info, warn}; -use port::PortExposer; +use log::{error, info, warn}; +use port::{ForwardHandler, PortExposer}; use rustyline::{error::ReadlineError, DefaultEditor}; -use std::{collections::HashMap, os::unix::prelude::MetadataExt, path::PathBuf, sync::Arc}; -use tokio::{fs, sync::RwLock}; +use std::{ + collections::HashMap, future::Future, os::unix::prelude::MetadataExt, path::PathBuf, pin::Pin, + sync::Arc, +}; +use tokio::{fs, net::TcpListener, sync::RwLock}; use users::get_current_username; fn main() { @@ -67,6 +71,22 @@ pub enum Command { Forward { id: String, port: Option }, } +struct State { + requested: HashMap>, +} +pub trait RequestHandler: Send + Sync + 'static { + fn on_connect( + &self, + resource: ProvideInfo, + channel: Arc, + ) -> Pin> + Send + Sync>>; +} + +#[derive(Clone)] +struct Handler { + state: Arc>, +} + fn get_username() -> String { get_current_username() .map(|u| u.to_str().unwrap().to_string()) @@ -141,6 +161,11 @@ async fn run() -> anyhow::Result<()> { for (rid, r) in p.remote_provided.read().await.iter() { if rid == &id { if r.kind == "file" { + state + .write() + .await + .requested + .insert(id.clone(), Box::new(DownloadHandler { path })); p.request_resource(id).await; } else { warn!("not a file"); @@ -163,10 +188,52 @@ async fn run() -> anyhow::Result<()> { })) .await; } - Command::Forward { id, port } => {} + Command::Forward { id, port } => { + let peers = inst.peers.read().await; + 'outer: for peer in peers.values() { + for (rid, r) in peer.remote_provided.read().await.iter() { + if rid == &id { + if r.kind == "port" { + let peer = peer.to_owned(); + let state = state.clone(); + tokio::task::spawn(async move { + let Ok(listener) = + TcpListener::bind(("127.0.0.1", port.unwrap_or(0))) + .await + else { + error!("cannot bind tcp listener"); + return; + }; + info!( + "tcp listener bound to {}", + listener.local_addr().unwrap() + ); + while let Ok((stream, addr)) = listener.accept().await { + info!("new connection from {addr:?}"); + state.write().await.requested.insert( + id.clone(), + Box::new(ForwardHandler { + stream: Arc::new(RwLock::new(Some(stream))), + }), + ); + peer.request_resource(id.clone()).await; + } + }); + } else { + warn!("not a port"); + } + break 'outer; + } + } + } + } }, Err(err) => err.print().unwrap(), }, + Err(ReadlineError::Eof) => { + info!("exit"); + break; + } Err(ReadlineError::Interrupted) => { info!("interrupted; exiting..."); break; @@ -175,64 +242,25 @@ async fn run() -> anyhow::Result<()> { } } - // match &args.action { - // Action::Send { filename } => { - // inst.add_local_resource(Box::new(FileSender { - // info: ProvideInfo { - // id: "the-file".to_string(), // we only share a single file so its fine - // kind: "file".to_string(), - // track_kind: None, - // label: Some(filename.clone().unwrap_or("stdin".to_string())), - // size: if let Some(filename) = &filename { - // Some(fs::metadata(filename).await.unwrap().size() as usize) - // } else { - // None - // }, - // }, - // reader_factory: args.action, - // })) - // .await; - // } - // _ => (), - // } Ok(()) } -struct State { - requested: HashMap>, -} -pub trait RequestHandler: Send + Sync + 'static { - -} - -#[derive(Clone)] -struct Handler { - state: Arc>, -} - impl EventHandler for Handler { fn peer_join(&self, _peer: Arc) -> libkeks::DynFut<()> { Box::pin(async move {}) } - fn peer_leave(&self, _peer: Arc) -> libkeks::DynFut<()> { Box::pin(async move {}) } - fn resource_added(&self, peer: Arc, info: libkeks::protocol::ProvideInfo) -> DynFut<()> { - let id = info.id.clone(); - Box::pin(async move { - // match &args.action { - // Action::Receive { .. } => { - // if info.kind == "file" { - // peer.request_resource(id).await; - // } - // } - // _ => (), - // } - }) + fn resource_added( + &self, + _peer: Arc, + _info: libkeks::protocol::ProvideInfo, + ) -> DynFut<()> { + Box::pin(async move {}) } fn resource_removed(&self, _peer: Arc, _id: String) -> DynFut<()> { - Box::pin(async {}) + Box::pin(async move {}) } fn resource_connected( @@ -242,106 +270,20 @@ impl EventHandler for Handler { channel: TransportChannel, ) -> libkeks::DynFut<()> { let resource = resource.clone(); - let s = self.clone(); + let k = self.clone(); Box::pin(async move { - // match channel { - // TransportChannel::Track(_) => warn!("wrong type"), - // TransportChannel::DataChannel(dc) => { - // if resource.kind != "file" { - // return error!("we got a non-file resource for some reason…"); - // } - // let pos = Arc::new(AtomicUsize::new(0)); - // let writer: Arc>>>> = - // Arc::new(RwLock::new(None)); - // { - // let writer = writer.clone(); - // let s = s.clone(); - // dc.on_open(Box::new(move || { - // let s = s.clone(); - // let writer = writer.clone(); - // Box::pin(async move { - // info!("channel opened"); - // *writer.write().await = Some(s.args.action.create_writer().await) - // }) - // })); - // } - // { - // let writer = writer.clone(); - // dc.on_close(Box::new(move || { - // let writer = writer.clone(); - // Box::pin(async move { - // info!("channel closed"); - // *writer.write().await = None; - // exit(0); - // }) - // })); - // } - // { - // let writer = writer.clone(); - // dc.on_message(Box::new(move |mesg| { - // let writer = writer.clone(); - // let pos = pos.clone(); - // Box::pin(async move { - // // TODO - // if mesg.is_string { - // let s = String::from_utf8((&mesg.data).to_vec()).unwrap(); - // if &s == "end" { - // info!("EOF reached") - // } - // } else { - // let pos = pos.fetch_add(mesg.data.len(), Ordering::Relaxed); - // info!( - // "recv {:?} ({} of {})", - // mesg.data.len(), - // humansize::format_size(pos, DECIMAL), - // humansize::format_size(resource.size.unwrap_or(0), DECIMAL), - // ); - // writer - // .write() - // .await - // .as_mut() - // .unwrap() - // .write_all(&mesg.data) - // .await - // .unwrap(); - // } - // }) - // })) - // } - // dc.on_error(Box::new(move |err| { - // Box::pin(async move { - // error!("data channel errored: {err}"); - // }) - // })); - // } - // } + if let Some(handler) = k.state.write().await.requested.get(&resource.id) { + match channel { + TransportChannel::Track(_) => warn!("wrong type"), + TransportChannel::DataChannel(channel) => { + if let Err(e) = handler.on_connect(resource, channel).await { + warn!("request handler error: {e}"); + } + } + } + } else { + warn!("got {:?}, which was not requested", resource.id); + } }) } } - -// impl Action { -// pub async fn create_writer(&self) -> Pin> { -// match self { -// Action::Receive { filename } => { -// if let Some(filename) = filename { -// Box::pin(File::create(filename).await.unwrap()) -// } else { -// Box::pin(stdout()) -// } -// } -// _ => unreachable!(), -// } -// } -// pub async fn create_reader(&self) -> Pin> { -// match self { -// Action::Send { filename } => { -// if let Some(filename) = filename { -// Box::pin(File::open(filename).await.unwrap()) -// } else { -// Box::pin(stdin()) -// } -// } -// _ => unreachable!(), -// } -// } -// } diff --git a/client-native-rift/src/port.rs b/client-native-rift/src/port.rs index ae56f2c..857623e 100644 --- a/client-native-rift/src/port.rs +++ b/client-native-rift/src/port.rs @@ -1,9 +1,12 @@ +use crate::RequestHandler; use bytes::Bytes; -use libkeks::{peer::Peer, protocol::ProvideInfo, DynFut, LocalResource}; +use libkeks::{ + peer::Peer, protocol::ProvideInfo, webrtc::data_channel::RTCDataChannel, DynFut, LocalResource, +}; use log::{debug, error, info, warn}; -use std::{pin::Pin, sync::Arc}; +use std::{future::Future, pin::Pin, sync::Arc}; use tokio::{ - io::{AsyncRead, AsyncReadExt}, + io::{AsyncRead, AsyncReadExt, AsyncWriteExt}, net::TcpStream, sync::RwLock, }; @@ -97,3 +100,60 @@ impl LocalResource for PortExposer { }) } } + +pub struct ForwardHandler { + pub stream: Arc>>, +} +impl RequestHandler for ForwardHandler { + fn on_connect( + &self, + _resource: ProvideInfo, + channel: Arc, + ) -> Pin> + Send + Sync>> { + let stream = self.stream.clone(); + Box::pin(async move { + let stream = stream.write().await.take().unwrap(); + let (mut read, write) = stream.into_split(); + let write = Arc::new(RwLock::new(write)); + + let channel2 = channel.clone(); + channel.on_open(Box::new(move || { + Box::pin(async move { + info!("channel open"); + let channel = channel2.clone(); + tokio::task::spawn(async move { + let mut buf = [0u8; 1 << 15]; + loop { + let Ok(size) = read.read(&mut buf).await else { + break; + }; + channel + .send(&Bytes::copy_from_slice(&buf[..size])) + .await + .unwrap(); + } + }); + }) + })); + channel.on_close(Box::new(move || { + Box::pin(async move { + info!("channel closed"); + }) + })); + channel.on_error(Box::new(move |err| { + Box::pin(async move { error!("channel error: {err}") }) + })); + { + let write = write.clone(); + channel.on_message(Box::new(move |message| { + let write = write.clone(); + Box::pin(async move { + write.write().await.write_all(&message.data).await.unwrap(); + }) + })); + } + + Ok(()) + }) + } +} -- cgit v1.2.3-70-g09d2