aboutsummaryrefslogtreecommitdiff
path: root/src/transaction/mod.rs
blob: 3368c47b5dc8d4019ab50cfcb00995c99fc35a67 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
pub mod auth;

use crate::{
    encoding::{
        headers::{CSeq, CallID, ContentLength},
        request::Request,
        response::Response,
        Message,
    },
    transport::Transport,
};
use anyhow::{anyhow, Result};
use std::{
    collections::HashMap,
    sync::atomic::{AtomicU32, Ordering},
};
use tokio::sync::{
    mpsc::{self, channel},
    RwLock,
};

/// Pairs outgoing requests with their incoming responses on top of a transport.
///
/// Outgoing requests are stamped with a `CSeq` taken from an internal counter;
/// responses read from the transport are routed back to the caller that issued
/// the request carrying the same `CSeq`.
pub struct TransactionUser<T> {
    /// Transport through which messages are sent and received.
    transport: T,
    /// Counter providing the sequence number of the next outgoing request.
    sequence: AtomicU32,
    /// Response channels of issued requests, keyed by the `CSeq` each request
    /// was sent with.
    pending_requests: RwLock<HashMap<CSeq, mpsc::Sender<Response>>>,
}

impl<T: Transport> TransactionUser<T> {
    /// Creates a transaction user on top of `transport`.
    ///
    /// The sequence counter starts at zero and no requests are pending.
    pub fn new(transport: T) -> Self {
        Self {
            sequence: 0.into(),
            pending_requests: Default::default(),
            transport,
        }
    }

    /// Receives messages from the transport until a request arrives and
    /// returns that request.
    ///
    /// Responses received in the meantime are forwarded to the channel that
    /// [`transact`](Self::transact) registered for the same `CSeq`.
    ///
    /// # Errors
    ///
    /// Returns an error if the transport fails to receive, if a response has
    /// no (or an unreadable) `CSeq` header, if no pending request matches the
    /// response's `CSeq`, or if the response cannot be delivered because the
    /// receiving half of the channel is gone. In the last case the stale
    /// entry is also removed from the pending table.
    pub async fn process_incoming(&self) -> Result<Request> {
        loop {
            match self.transport.recv().await? {
                Message::Request(req) => break Ok(req),
                Message::Response(resp) => {
                    let cseq: CSeq = resp
                        .headers
                        .get()
                        .ok_or_else(|| anyhow!("response cseq missing"))??;

                    // Clone the sender under a short-lived read lock. The
                    // guard is released at the end of this statement, so a
                    // full response channel can never stall `transact`, which
                    // needs the write lock to register new requests.
                    let sender = self
                        .pending_requests
                        .read()
                        .await
                        .get(&cseq)
                        .cloned()
                        .ok_or_else(|| anyhow!("message was not requested"))?;

                    if let Err(err) = sender.send(resp).await {
                        // The receiver was dropped: nobody is waiting for this
                        // transaction any more, so forget it instead of
                        // keeping the dead sender around forever.
                        self.pending_requests.write().await.remove(&cseq);
                        return Err(err.into());
                    }
                }
            }
        }
    }

    /// Sends `resp` as the answer to `req`.
    ///
    /// The `CSeq` and `CallID` headers are copied from the request into the
    /// response, and `ContentLength` is set to the length of the response
    /// body before the response is handed to the transport.
    ///
    /// # Errors
    ///
    /// Returns an error if the request lacks a readable `CSeq` or `CallID`
    /// header, or if the transport fails to send.
    pub async fn respond(&self, req: &Request, mut resp: Response) -> Result<()> {
        resp.headers.insert(
            req.headers
                .get::<CSeq>()
                .ok_or_else(|| anyhow!("cseq is mandatory"))??,
        );
        resp.headers.insert(
            req.headers
                .get::<CallID>()
                .ok_or_else(|| anyhow!("call-id is mandatory"))??,
        );
        resp.headers.insert(ContentLength(resp.body.len()));
        self.transport.send(Message::Response(resp)).await?;
        Ok(())
    }

    /// Sends `request` and returns a receiver for the responses to it.
    ///
    /// The request is stamped with a `CSeq` built from the next sequence
    /// number and the request's method, and with a `ContentLength` matching
    /// its body. Responses carrying the same `CSeq` are delivered to the
    /// returned receiver by [`process_incoming`](Self::process_incoming).
    ///
    /// # Errors
    ///
    /// Returns an error if the transport fails to send; the request is then
    /// no longer tracked as pending.
    pub async fn transact(&self, mut request: Request) -> Result<mpsc::Receiver<Response>> {
        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
        let cseq = CSeq(seq, request.method);
        request.headers.insert(cseq);
        request.headers.insert(ContentLength(request.body.len()));

        let (tx, rx) = channel(4);

        // Register the channel BEFORE the request goes out. Otherwise a fast
        // response could be read by `process_incoming` before the entry
        // exists and be rejected as "message was not requested".
        self.pending_requests.write().await.insert(cseq, tx);

        if let Err(err) = self.transport.send(Message::Request(request)).await {
            // Nothing was sent, so no response will ever arrive: roll back.
            self.pending_requests.write().await.remove(&cseq);
            return Err(err.into());
        }

        Ok(rx)
    }
}