use super::{Node, NodeContext, NodeKind, NodeRequest, NodeResponse}; use crate::{config::DynNode, error::ServiceError}; use anyhow::Result; use futures::Future; use headers::{HeaderMapExt, Upgrade}; use hyper::Method; use serde::Deserialize; use std::{pin::Pin, sync::Arc}; pub struct SwitchKind; #[derive(Deserialize)] pub struct Switch { condition: Condition, case_true: DynNode, case_false: DynNode, } impl NodeKind for SwitchKind { fn name(&self) -> &'static str { "switch" } fn instanciate(&self, config: serde_yaml::Value) -> Result> { Ok(Arc::new(serde_yaml::from_value::(config)?)) } } impl Node for Switch { fn handle<'a>( &'a self, context: &'a mut NodeContext, request: NodeRequest, ) -> Pin> + Send + Sync + 'a>> { Box::pin(async move { let cond = self.condition.test(&request); if cond { &self.case_true } else { &self.case_false } .handle(context, request) .await }) } } #[derive(Deserialize)] #[serde(rename_all = "snake_case")] enum Condition { IsWebsocketUpgrade, IsPost, IsGet, HasHeader(String), PathStartsWith(String), PathIs(String), } impl Condition { pub fn test(&self, req: &NodeRequest) -> bool { match self { Condition::IsWebsocketUpgrade => { req.headers().typed_get::() == Some(Upgrade::websocket()) } Condition::HasHeader(name) => req.headers().contains_key(name), Condition::PathStartsWith(path_prefix) => req.uri().path().starts_with(path_prefix), Condition::PathIs(path) => req.uri().path() == path, Condition::IsPost => req.method() == Method::POST, Condition::IsGet => req.method() == Method::GET, } } }