// Source path: root/src/transaction/mod.rs
// Blob: 3be9544617178e66bba33943ffd323d8f69c0cf0
use crate::{
    encoding::{headers::CSeq, request::Request, response::Response, Message},
    transport::Transport,
};
use anyhow::{anyhow, Result};
use std::{
    collections::HashMap,
    sync::atomic::{AtomicU32, Ordering},
};
use tokio::sync::{
    mpsc::{self, channel},
    RwLock,
};

/// Pairs outgoing requests with the responses that answer them, on top of a
/// transport `T`.
///
/// Requests are issued with `transact`, which hands back a channel receiver;
/// `process_incoming` reads from the transport and routes each response to the
/// channel registered under the response's `CSeq`.
pub struct TransactionUser<T> {
    /// Transport used to send requests and to receive incoming messages.
    transport: T,
    /// Counter supplying the number placed in each outgoing request's `CSeq`.
    sequence: AtomicU32,
    /// Response channels for requests issued through `transact`, keyed by the
    /// `CSeq` the request was sent with.
    pending_requests: RwLock<HashMap<CSeq, mpsc::Sender<Response>>>,
}

impl<T: Transport> TransactionUser<T> {
    pub fn new(transport: T) -> Self {
        Self {
            sequence: 0.into(),
            pending_requests: Default::default(),
            transport,
        }
    }

    pub async fn process_incoming(&self) -> Result<Request> {
        loop {
            let mesg = self.transport.recv().await?;
            match mesg {
                Message::Request(req) => break Ok(req),
                Message::Response(resp) => {
                    let cseq = resp
                        .headers
                        .get()
                        .ok_or(anyhow!("response cseq missing"))??;
                    self.pending_requests
                        .write()
                        .await
                        .get_mut(&cseq)
                        .ok_or(anyhow!("message was not requested"))?
                        .send(resp)
                        .await?;
                }
            }
        }
    }

    pub async fn transact(&self, mut request: Request) -> Result<mpsc::Receiver<Response>> {
        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
        let cseq = CSeq(seq, request.method);
        request.headers.insert(cseq);

        let (tx, rx) = channel(4);

        self.transport.send(Message::Request(request)).await?;
        self.pending_requests.write().await.insert(cseq, tx);

        Ok(rx)
    }
}