aboutsummaryrefslogtreecommitdiff
path: root/src/transaction/mod.rs
blob: 2953cb70756dad836410dfb64d47980c6469e4e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use crate::{
    encoding::{headers::CSeq, request::Request, response::Response},
    transport::Transport,
};
use anyhow::{anyhow, Result};
use std::{
    collections::HashMap,
    sync::atomic::{AtomicU32, Ordering},
};
use tokio::sync::{
    mpsc::{self, channel},
    RwLock,
};

/// Tags outgoing requests with a `CSeq` header and routes incoming responses
/// back to the caller whose request carried the matching `CSeq`.
pub struct TransactionUser<T> {
    /// Transport used to send requests and receive responses.
    transport: T,
    /// Next `CSeq` sequence number to hand out; incremented once per request.
    sequence: AtomicU32,
    /// Outstanding transactions keyed by the `CSeq` stamped on the request.
    /// Each value is the sending half of the channel returned by `transact`.
    pending_requests: RwLock<HashMap<CSeq, mpsc::Sender<Response>>>,
}

impl<T: Transport> TransactionUser<T> {
    /// Capacity of each per-transaction response channel. While a channel is
    /// full, `process_responses` waits for the caller to drain it.
    const RESPONSE_CHANNEL_CAPACITY: usize = 4;

    /// Creates a transaction user over `transport`, numbering `CSeq`s from 0.
    pub fn new(transport: T) -> Self {
        Self {
            sequence: 0.into(),
            pending_requests: Default::default(),
            transport,
        }
    }

    /// Receives one response from the transport and forwards it to the
    /// pending transaction whose `CSeq` matches.
    ///
    /// Handles a single response per call; presumably the caller drives this
    /// in a loop.
    ///
    /// # Errors
    ///
    /// - the transport fails to receive;
    /// - the response has no `CSeq` header, or the header lookup itself
    ///   returns an error (presumably a malformed header);
    /// - no pending transaction has that `CSeq`;
    /// - the transaction's receiver has been dropped (its entry is removed).
    pub async fn process_responses(&self) -> Result<()> {
        let resp = self.transport.recv().await?;
        let cseq = resp
            .headers
            .get()
            .ok_or_else(|| anyhow!("response cseq missing"))??;

        // Clone the sender under a short-lived read lock so the map is not
        // locked while awaiting `send`: with a full channel, holding the lock
        // here would block every `transact` call until the receiver drains.
        let tx = self
            .pending_requests
            .read()
            .await
            .get(&cseq)
            .cloned()
            .ok_or_else(|| anyhow!("message was not requested"))?;

        if let Err(err) = tx.send(resp).await {
            // The receiver is gone, so no one will read this transaction's
            // responses again; drop the entry instead of leaking it.
            self.pending_requests.write().await.remove(&cseq);
            return Err(err.into());
        }
        Ok(())
    }

    /// Stamps `request` with the next `CSeq`, sends it over the transport and
    /// returns the channel on which `process_responses` delivers responses
    /// carrying that `CSeq`.
    ///
    /// # Errors
    ///
    /// Returns the transport's error if sending fails; the transaction is
    /// unregistered in that case.
    pub async fn transact(&self, mut request: Request) -> Result<mpsc::Receiver<Response>> {
        // Relaxed is enough: only the uniqueness of each number matters.
        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
        let cseq = CSeq(seq, request.method);
        request.headers.insert(cseq);

        let (tx, rx) = channel(Self::RESPONSE_CHANNEL_CAPACITY);

        // Register before sending: otherwise a concurrently running
        // `process_responses` could receive the response before the entry
        // exists and reject it as "message was not requested".
        self.pending_requests.write().await.insert(cseq, tx);

        if let Err(err) = self.transport.send(request).await {
            self.pending_requests.write().await.remove(&cseq);
            return Err(err.into());
        }

        Ok(rx)
    }
}