aboutsummaryrefslogtreecommitdiff
path: root/src/modules/proxy.rs
blob: 2fa35381b7fea3b4c3403824a1341b1e63120858 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse};
use crate::ServiceError;
use futures::Future;
use http_body_util::BodyExt;
use hyper::{http::HeaderValue, upgrade::OnUpgrade, StatusCode};
use hyper_util::rt::TokioIo;
use log::{debug, warn};
use serde::Deserialize;
use serde_yml::Value;
use std::{net::SocketAddr, pin::Pin, sync::Arc};
use tokio::net::TcpStream;

/// Node kind that builds `proxy` nodes from their YAML configuration.
///
/// Stateless: every call to `instanciate` produces an independent node.
#[derive(Debug, Clone, Copy, Default)]
pub struct ProxyKind;

/// Configuration of a single `proxy` node, deserialized from its YAML value.
#[derive(Debug, Deserialize)]
struct Proxy {
    /// Whether to set `x-forwarded-for`, `x-forwarded-port`, `x-forwarded-scheme`,
    /// `x-forwarded-proto` and `x-forwarded-host` on forwarded requests.
    /// Off when the key is absent.
    #[serde(default)]
    set_forwarded_for: bool,
    /// Whether to set `x-real-ip` to the client's IP on forwarded requests.
    /// On when the key is absent.
    #[serde(default = "ret_true")]
    set_real_ip: bool,
    /// Socket address of the upstream server that requests are forwarded to.
    backend: SocketAddr,
}

/// Serde `default = "..."` helper that always yields `true`.
///
/// Used so that a boolean option can default to enabled when omitted from the
/// configuration (plain `#[serde(default)]` would give `false`).
fn ret_true() -> bool {
    true
}

impl NodeKind for ProxyKind {
    /// Name of this node kind: `"proxy"`.
    fn name(&self) -> &'static str {
        "proxy"
    }

    /// Deserializes `config` into a `Proxy` node and returns it as a shared
    /// trait object.
    ///
    /// # Errors
    ///
    /// Propagates (via `anyhow`) the `serde_yml` error when `config` does not
    /// match the `Proxy` layout, e.g. a missing or unparsable `backend` address.
    fn instanciate(&self, config: Value) -> anyhow::Result<Arc<dyn Node>> {
        let proxy: Proxy = serde_yml::from_value(config)?;
        let node: Arc<dyn Node> = Arc::new(proxy);
        Ok(node)
    }
}
impl Node for Proxy {
    fn handle<'a>(
        &'a self,
        context: &'a mut NodeContext,
        mut request: NodeRequest,
    ) -> Pin<Box<dyn Future<Output = Result<NodeResponse, ServiceError>> + Send + Sync + 'a>> {
        Box::pin(async move {
            if self.set_real_ip {
                request.headers_mut().insert(
                    "x-real-ip",
                    HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(),
                );
            }
            if self.set_forwarded_for {
                request.headers_mut().insert(
                    "x-forwarded-for",
                    HeaderValue::from_str(&format!("{}", context.addr.ip())).unwrap(),
                );
                request.headers_mut().insert(
                    "x-forwarded-port",
                    HeaderValue::from_str(&context.listen_addr.port().to_string()).unwrap(),
                );
                let scheme =
                    HeaderValue::from_str(if context.secure { "https" } else { "http" }).unwrap();
                request
                    .headers_mut()
                    .insert("x-forwarded-scheme", scheme.clone());
                request.headers_mut().insert("x-forwarded-proto", scheme);
                if let Some(host) = request.headers().get("host").cloned() {
                    request.headers_mut().insert("x-forwarded-host", host);
                }
            }

            let on_upgrade_downstream = request.extensions_mut().remove::<OnUpgrade>();

            let _limit_guard = context.state.l_outgoing.try_acquire()?;
            debug!("\tforwarding to {}", self.backend);
            let mut resp = {
                let client_stream = TokioIo::new(
                    TcpStream::connect(self.backend)
                        .await
                        .map_err(|_| ServiceError::CantConnect)?,
                );

                let (mut sender, conn) = hyper::client::conn::http1::handshake(client_stream)
                    .await
                    .map_err(ServiceError::Hyper)?;
                tokio::task::spawn(async move {
                    if let Err(err) = conn.with_upgrades().await {
                        warn!("connection failed: {:?}", err);
                    }
                });
                sender
                    .send_request(request)
                    .await
                    .map_err(ServiceError::Hyper)?
            };

            if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
                let on_upgrade_upstream = resp
                    .extensions_mut()
                    .remove::<OnUpgrade>()
                    .ok_or(ServiceError::UpgradeFailed)?;
                let on_upgrade_downstream =
                    on_upgrade_downstream.ok_or(ServiceError::UpgradeFailed)?;
                tokio::task::spawn(async move {
                    debug!("about to upgrade connection");
                    match (on_upgrade_upstream.await, on_upgrade_downstream.await) {
                        (Ok(upgraded_upstream), Ok(upgraded_downstream)) => {
                            debug!("upgrade successful");
                            match tokio::io::copy_bidirectional(
                                &mut TokioIo::new(upgraded_downstream),
                                &mut TokioIo::new(upgraded_upstream),
                            )
                            .await
                            {
                                Ok((from_client, from_server)) => {
                                    debug!("proxy socket terminated: {from_server} sent, {from_client} received")
                                }
                                Err(e) => warn!("proxy socket error: {e}"),
                            }
                        }
                        (a, b) => warn!("upgrade error: upstream={a:?} downstream={b:?}"),
                    }
                });
            }

            let resp = resp.map(|b| b.map_err(ServiceError::Hyper).boxed());
            Ok(resp)
        })
    }
}