summaryrefslogtreecommitdiff
path: root/client-native-rift/src/file.rs
diff options
context:
space:
mode:
Diffstat (limited to 'client-native-rift/src/file.rs')
-rw-r--r--client-native-rift/src/file.rs112
1 files changed, 109 insertions, 3 deletions
diff --git a/client-native-rift/src/file.rs b/client-native-rift/src/file.rs
index bfe32fd..93dde62 100644
--- a/client-native-rift/src/file.rs
+++ b/client-native-rift/src/file.rs
@@ -1,8 +1,12 @@
+use crate::RequestHandler;
use bytes::Bytes;
use humansize::DECIMAL;
-use libkeks::{peer::Peer, protocol::ProvideInfo, DynFut, LocalResource};
+use libkeks::{
+ peer::Peer, protocol::ProvideInfo, webrtc::data_channel::RTCDataChannel, DynFut, LocalResource,
+};
use log::{debug, error, info};
use std::{
+ future::Future,
path::PathBuf,
pin::Pin,
sync::{
@@ -11,8 +15,8 @@ use std::{
},
};
use tokio::{
- fs::File,
- io::{AsyncRead, AsyncReadExt},
+ fs::{File, OpenOptions},
+ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
sync::RwLock,
};
@@ -107,3 +111,105 @@ impl LocalResource for FileSender {
})
}
}
+
+pub struct DownloadHandler {
+ pub path: Option<PathBuf>,
+}
+impl RequestHandler for DownloadHandler {
+ fn on_connect(
+ &self,
+ resource: ProvideInfo,
+ channel: Arc<RTCDataChannel>,
+ ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + Sync>> {
+ let path = self.path.clone().unwrap_or_else(|| {
+ resource
+ .label
+ .clone()
+ .unwrap_or("download".to_owned())
+ .replace("/", "_")
+ .replace("..", "_")
+ .into()
+ });
+ if path.exists() {}
+ Box::pin(async move {
+ let pos = Arc::new(AtomicUsize::new(0));
+ let writer: Arc<RwLock<Option<Pin<Box<dyn AsyncWrite + Send + Sync>>>>> =
+ Arc::new(RwLock::new(None));
+ {
+ let writer = writer.clone();
+ let path = path.clone();
+ let channel2 = channel.clone();
+ channel.on_open(Box::new(move || {
+ let path = path.clone();
+ let writer = writer.clone();
+ Box::pin(async move {
+ info!("channel opened");
+ match OpenOptions::new()
+ .write(true)
+ .read(false)
+ .create_new(true)
+ .open(path)
+ .await
+ {
+ Ok(file) => {
+ *writer.write().await = Some(Box::pin(file));
+ }
+ Err(e) => {
+ error!("cannot write download: {e}");
+ channel2.close().await.unwrap();
+ }
+ }
+ })
+ }));
+ }
+ {
+ let writer = writer.clone();
+ channel.on_close(Box::new(move || {
+ let writer = writer.clone();
+ Box::pin(async move {
+ info!("channel closed");
+ *writer.write().await = None;
+ })
+ }));
+ }
+ {
+ let writer = writer.clone();
+ channel.on_message(Box::new(move |mesg| {
+ let writer = writer.clone();
+ let pos = pos.clone();
+ Box::pin(async move {
+ // TODO
+ if mesg.is_string {
+ let s = String::from_utf8((&mesg.data).to_vec()).unwrap();
+ if &s == "end" {
+ info!("transfer complete")
+ }
+ } else {
+ let pos = pos.fetch_add(mesg.data.len(), Ordering::Relaxed);
+ info!(
+ "recv {:?} ({} of {})",
+ mesg.data.len(),
+ humansize::format_size(pos, DECIMAL),
+ humansize::format_size(resource.size.unwrap_or(0), DECIMAL),
+ );
+ writer
+ .write()
+ .await
+ .as_mut()
+ .unwrap()
+ .write_all(&mesg.data)
+ .await
+ .unwrap();
+ }
+ })
+ }))
+ }
+ channel.on_error(Box::new(move |err| {
+ Box::pin(async move {
+ error!("data channel errored: {err}");
+ })
+ }));
+ Ok(())
+ })
+ }
+}