aboutsummaryrefslogtreecommitdiff
path: root/src/proxy.rs
blob: 10d7e3d69bfcdbda2c7c0ebdfbed6795756c2f45 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
use crate::{helper::TokioIo, ServiceError, State};
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{
    body::Incoming,
    header::UPGRADE,
    http::{
        uri::{PathAndQuery, Scheme},
        HeaderValue,
    },
    upgrade::OnUpgrade,
    Request, Uri,
};
use log::{debug, error, warn};
use std::{net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;

pub async fn proxy_request(
    state: &Arc<State>,
    mut req: Request<Incoming>,
    addr: SocketAddr,
    backend: &SocketAddr,
) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> {
    let scheme_secure = req.uri().scheme() == Some(&Scheme::HTTPS);
    *req.uri_mut() = Uri::builder()
        .path_and_query(
            req.uri()
                .clone()
                .path_and_query()
                .cloned()
                .unwrap_or(PathAndQuery::from_static("/")),
        )
        .build()
        .unwrap();

    req.headers_mut().insert(
        "x-forwarded-for",
        HeaderValue::from_str(&format!("{addr}")).unwrap(),
    );
    req.headers_mut().insert(
        "x-forwarded-proto",
        if scheme_secure {
            HeaderValue::from_static("https")
        } else {
            HeaderValue::from_static("http")
        },
    );

    let do_upgrade = req.headers().contains_key(UPGRADE);
    let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>();

    if let Some(_limit_guard) = state.l_outgoing.obtain() {
        debug!("\tforwarding to {}", backend);
        let mut resp = {
            let client_stream = TokioIo(
                TcpStream::connect(backend)
                    .await
                    .map_err(|_| ServiceError::CantConnect)?,
            );

            let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
                .await
                .map_err(ServiceError::Hyper)?;
            tokio::task::spawn(async move {
                if let Err(err) = conn.await {
                    warn!("connection failed: {:?}", err);
                }
            });
            sender
                .send_request(req)
                .await
                .map_err(ServiceError::Hyper)?
        };

        if do_upgrade {
            let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
            tokio::task::spawn(async move {
                debug!("about upgrading connection, sending empty response");
                match (
                    on_upgrade_upstream.unwrap().await,
                    on_upgrade_downstream.unwrap().await,
                ) {
                    (Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
                        debug!("upgrade successful");
                        match tokio::io::copy_bidirectional(
                            &mut TokioIo(upgraded_downstream),
                            &mut TokioIo(upgraded_upstream),
                        )
                        .await
                        {
                            Ok((from_client, from_server)) => {
                                debug!("proxy socket terminated: {from_server} sent, {from_client} received")
                            }
                            Err(e) => warn!("proxy socket error: {e}"),
                        }
                    }
                    (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"),
                }
            });
        }
        Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed()))
    } else {
        Err(ServiceError::Limit)
    }
}