1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
pub mod auth;
use crate::{
encoding::{
headers::{CSeq, CallID, ContentLength},
request::Request,
response::Response,
Message,
},
transport::Transport,
};
use anyhow::{anyhow, Result};
use std::{
collections::HashMap,
sync::atomic::{AtomicU32, Ordering},
};
use tokio::sync::{
mpsc::{self, channel},
RwLock,
};
/// Transaction-user layer on top of a transport.
///
/// Sends requests with generated CSeq headers (see `transact`), answers incoming
/// requests (see `respond`), and routes incoming responses to the channel of the
/// request they belong to, keyed by CSeq (see `process_incoming`).
pub struct TransactionUser<T> {
    // Underlying transport used to send and receive `Message`s.
    transport: T,
    // Counter from which `transact` allocates the numeric part of each CSeq.
    sequence: AtomicU32,
    // Outstanding requests: CSeq of a sent request -> sending half of the channel
    // returned to the caller of `transact`.
    pending_requests: RwLock<HashMap<CSeq, mpsc::Sender<Response>>>,
}
impl<T: Transport> TransactionUser<T> {
    /// Creates a transaction user on top of `transport`.
    ///
    /// CSeq numbering starts at 0 and no requests are pending.
    pub fn new(transport: T) -> Self {
        Self {
            sequence: 0.into(),
            pending_requests: Default::default(),
            transport,
        }
    }

    /// Receives messages until the next incoming request arrives, and returns it.
    ///
    /// Responses received in the meantime are routed, by their CSeq header, to the
    /// channel returned from the [`transact`](Self::transact) call that sent the
    /// matching request.
    ///
    /// # Errors
    ///
    /// Returns an error if `transport.recv()` fails, if a response's CSeq header is
    /// missing or cannot be read, if a response's CSeq matches no pending request,
    /// or if the receiver for the matching request has been dropped (the stale
    /// entry is removed before the error is returned).
    pub async fn process_incoming(&self) -> Result<Request> {
        loop {
            let mesg = self.transport.recv().await?;
            match mesg {
                Message::Request(req) => break Ok(req),
                Message::Response(resp) => {
                    let cseq = resp
                        .headers
                        .get::<CSeq>()
                        .ok_or_else(|| anyhow!("response cseq missing"))??;
                    // Clone the sender out and release the lock before awaiting: the
                    // channel is bounded, so `send` may wait for capacity, and holding
                    // the lock meanwhile would block `transact` (which takes the write
                    // lock) and can deadlock if the receiver's owner is in `transact`.
                    let sender = self
                        .pending_requests
                        .read()
                        .await
                        .get(&cseq)
                        .cloned()
                        .ok_or_else(|| anyhow!("message was not requested"))?;
                    let delivered = sender.send(resp).await;
                    if delivered.is_err() {
                        // The caller dropped its receiver; forget the transaction so
                        // the entry does not linger in the map.
                        self.pending_requests.write().await.remove(&cseq);
                    }
                    delivered?;
                }
            }
        }
    }

    /// Sends `resp` as the response to `req`.
    ///
    /// Copies the CSeq and Call-ID headers from `req` into `resp` and sets
    /// Content-Length from the response body length before sending.
    ///
    /// # Errors
    ///
    /// Returns an error if `req` lacks a CSeq or Call-ID header (or either cannot
    /// be read), or if the transport fails to send.
    pub async fn respond(&self, req: &Request, mut resp: Response) -> Result<()> {
        // NOTE(review): RFC 3261 §8.2.6.2 also requires From, To and Via to be
        // copied from the request; they are not handled here — confirm whether
        // callers add them.
        let cseq = req
            .headers
            .get::<CSeq>()
            .ok_or_else(|| anyhow!("cseq is mandatory"))??;
        let call_id = req
            .headers
            .get::<CallID>()
            .ok_or_else(|| anyhow!("call-id is mandatory"))??;
        resp.headers.insert(cseq);
        resp.headers.insert(call_id);
        resp.headers.insert(ContentLength(resp.body.len()));
        self.transport.send(Message::Response(resp)).await?;
        Ok(())
    }

    /// Sends `request` as a new transaction and returns a channel that receives
    /// its responses.
    ///
    /// Assigns the next CSeq number (paired with the request's method) and sets
    /// Content-Length from the body length. Responses are delivered by
    /// [`process_incoming`](Self::process_incoming); the channel buffers up to 4
    /// responses.
    ///
    /// # Errors
    ///
    /// Returns an error if the transport fails to send; the transaction is
    /// unregistered in that case.
    pub async fn transact(&self, mut request: Request) -> Result<mpsc::Receiver<Response>> {
        // Relaxed is sufficient: the atomic `fetch_add` alone gives every call a
        // distinct number; no other memory is synchronized through this counter.
        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
        let cseq = CSeq(seq, request.method);
        request.headers.insert(cseq);
        request.headers.insert(ContentLength(request.body.len()));
        let (tx, rx) = channel(4);
        // Register before sending. Otherwise a fast response could be handled by a
        // concurrent `process_incoming` before the entry exists, and be rejected
        // as "message was not requested".
        // NOTE(review): entries are only removed on send failure or when a response
        // finds the receiver dropped — nothing removes them on a final response.
        self.pending_requests.write().await.insert(cseq, tx);
        let sent = self.transport.send(Message::Request(request)).await;
        if sent.is_err() {
            self.pending_requests.write().await.remove(&cseq);
        }
        sent?;
        Ok(rx)
    }
}
|