aboutsummaryrefslogtreecommitdiff
path: root/src/proxy.rs
blob: aff1ae50764dc3ad5481443783dc32a1214c7578 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use crate::ServiceError;
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::{
    body::Incoming,
    header::{SERVER, UPGRADE},
    http::{
        uri::{PathAndQuery, Scheme},
        HeaderValue,
    },
    upgrade::OnUpgrade,
    Request, Uri,
};
use log::{debug, error, warn};
use std::net::SocketAddr;
use tokio::net::TcpStream;

pub async fn proxy_request(
    mut req: Request<Incoming>,
    addr: SocketAddr,
    backend: &SocketAddr,
) -> Result<hyper::Response<BoxBody<bytes::Bytes, ServiceError>>, ServiceError> {
    let scheme_secure = req.uri().scheme() == Some(&Scheme::HTTPS);
    *req.uri_mut() = Uri::builder()
        .path_and_query(
            req.uri()
                .clone()
                .path_and_query()
                .cloned()
                .unwrap_or(PathAndQuery::from_static("/")),
        )
        .build()
        .unwrap();

    req.headers_mut().insert(
        "x-forwarded-for",
        HeaderValue::from_str(&format!("{addr}")).unwrap(),
    );
    req.headers_mut().insert(
        "x-forwarded-proto",
        if scheme_secure {
            HeaderValue::from_static("https")
        } else {
            HeaderValue::from_static("http")
        },
    );

    let do_upgrade = req.headers().contains_key(UPGRADE);
    let on_upgrade_downstream = req.extensions_mut().remove::<OnUpgrade>();

    debug!("\tforwarding to {}", backend);
    let mut resp = {
        let client_stream = TcpStream::connect(backend)
            .await
            .map_err(|_| ServiceError::CantConnect)?;

        let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
            .await
            .map_err(ServiceError::Hyper)?;
        tokio::task::spawn(async move {
            if let Err(err) = conn.await {
                warn!("connection failed: {:?}", err);
            }
        });
        sender
            .send_request(req)
            .await
            .map_err(ServiceError::Hyper)?
    };

    let server_header = resp.headers().get(SERVER).cloned();
    resp.headers_mut().insert(
        SERVER,
        HeaderValue::from_str(&if let Some(o) = server_header {
            format!("{} via gnix", o.to_str().unwrap())
        } else {
            format!("gnix")
        })
        .unwrap(),
    );

    if do_upgrade {
        let on_upgrade_upstream = resp.extensions_mut().remove::<OnUpgrade>();
        tokio::task::spawn(async move {
            debug!("about upgrading connection, sending empty response");
            match (
                on_upgrade_upstream.unwrap().await,
                on_upgrade_downstream.unwrap().await,
            ) {
                (Ok(mut upgraded_upstream), Ok(mut upgraded_downstream)) => {
                    debug!("upgrade successful");
                    match tokio::io::copy_bidirectional(
                        &mut upgraded_downstream,
                        &mut upgraded_upstream,
                    )
                    .await
                    {
                        Ok((from_client, from_server)) => {
                            debug!("proxy socket terminated: {from_server} sent, {from_client} received")
                        }
                        Err(e) => warn!("proxy socket error: {e}"),
                    }
                }
                (a, b) => error!("upgrade error: upstream={a:?} downstream={b:?}"),
            }
        });
    }
    Ok(resp.map(|b| b.map_err(ServiceError::Hyper).boxed()))
}