// Copyright (C) 2020 Éloïs SANCHEZ.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use std::{
net::{IpAddr, SocketAddr},
time::Duration,
};
use bytes::Bytes;
use crate::anti_spam::{AntiSpam, AntiSpamResponse};
use crate::*;
/// Time budget, in milliseconds, passed to `tokio::time::timeout` when
/// processing a request body (bincode or JSON batch).
const MAX_BATCH_REQ_PROCESS_DURATION_IN_MILLIS: u64 = 5_000;
/// Rejection cause for a request the server refuses to process.
///
/// Wraps the [`anyhow::Error`] describing why the request was refused.
pub struct BadRequest(pub anyhow::Error);

impl std::fmt::Debug for BadRequest {
    /// Prints the wrapped error through its `Display` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(cause) = self;
        f.write_fmt(format_args!("{}", cause))
    }
}
// Marker impl: allows `BadRequest` to be passed to `warp::reject::custom`.
impl warp::reject::Reject for BadRequest {}
/// Rejection cause used when the processing of a request exceeds its time
/// budget.
pub struct ReqExecTooLong;

impl std::fmt::Debug for ReqExecTooLong {
    /// Prints a fixed message; the type carries no data.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("server error: request execution too long")
    }
}
// Marker impl: allows `ReqExecTooLong` to be passed to `warp::reject::custom`.
impl warp::reject::Reject for ReqExecTooLong {}
/// Thin wrapper around an `async_graphql::BatchRequest` (a single request or
/// a batch of requests) that adds the helpers used by the HTTP filters.
struct GraphQlRequest {
// The wrapped single-or-batch request.
inner: async_graphql::BatchRequest,
}
impl GraphQlRequest {
    /// Attaches `data` to every request wrapped by `self`.
    ///
    /// `D` has to be `Copy` because the same value is handed to each request
    /// of a batch.
    fn data<D: std::any::Any + Copy + Send + Sync>(self, data: D) -> Self {
        match self.inner {
            async_graphql::BatchRequest::Single(request) => {
                Self::new(async_graphql::BatchRequest::Single(request.data(data)))
            }
            async_graphql::BatchRequest::Batch(requests) => {
                Self::new(async_graphql::BatchRequest::Batch(
                    requests.into_iter().map(|req| req.data(data)).collect(),
                ))
            }
        }
    }

    /// Executes the wrapped request(s) against `schema`.
    ///
    /// A single request yields a single response. For a batch, one future per
    /// request is pushed into a `FuturesOrdered` and the responses are
    /// collected from it.
    // `from_iter` is called explicitly so that the stream type is named at
    // the call site; hence the clippy allowance.
    #[allow(clippy::from_iter_instead_of_collect)]
    async fn execute(self, schema: GvaSchema) -> async_graphql::BatchResponse {
        use std::iter::FromIterator as _;

        match self.inner {
            async_graphql::BatchRequest::Single(request) => {
                async_graphql::BatchResponse::Single(schema.execute(request).await)
            }
            async_graphql::BatchRequest::Batch(requests) => async_graphql::BatchResponse::Batch(
                futures::stream::FuturesOrdered::from_iter(
                    requests
                        .into_iter()
                        // Each future owns its own clone of the schema.
                        .zip(std::iter::repeat(schema))
                        .map(|(request, schema)| async move { schema.execute(request).await }),
                )
                .collect()
                .await,
            ),
        }
    }

    /// Number of requests wrapped by `self` (`1` for a single request).
    fn len(&self) -> usize {
        match &self.inner {
            async_graphql::BatchRequest::Single(_) => 1,
            async_graphql::BatchRequest::Batch(requests) => requests.len(),
        }
    }

    /// Wraps an already built batch request.
    fn new(inner: async_graphql::BatchRequest) -> Self {
        Self { inner }
    }

    /// Wraps one request as a `BatchRequest::Single`.
    fn single(request: async_graphql::Request) -> Self {
        Self::new(async_graphql::BatchRequest::Single(request))
    }
}
enum ServerResponse {
Bincode(Vec),
GraphQl(async_graphql::BatchResponse),
}
impl warp::reply::Reply for ServerResponse {
    /// Converts the payload into an HTTP response.
    ///
    /// Bincode bytes are replied as they are. GraphQL responses are
    /// serialized to JSON with an `application/json` content type, then the
    /// cache-control helper is given a chance to add its header.
    fn into_response(self) -> warp::reply::Response {
        let gql_batch_resp = match self {
            Self::Bincode(bytes) => return bytes.into_response(),
            Self::GraphQl(gql_batch_resp) => gql_batch_resp,
        };

        let json_reply = warp::reply::json(&gql_batch_resp);
        let mut http_resp =
            warp::reply::with_header(json_reply, "content-type", "application/json")
                .into_response();
        add_cache_control_batch(&mut http_resp, &gql_batch_resp);
        http_resp
    }
}
/// Sets the `cache-control` header of `http_resp` from a single-or-batch
/// GraphQL response.
///
/// An HTTP response carries one `cache-control` header for its whole body.
/// For a batch the header is therefore derived from the batch as a whole
/// (`BatchResponse::is_ok` and `BatchResponse::cache_control`) instead of
/// letting every individual response overwrite the header in turn, which
/// would let the last cacheable response decide for all the others.
fn add_cache_control_batch(
    http_resp: &mut warp::reply::Response,
    batch_resp: &async_graphql::BatchResponse,
) {
    match batch_resp {
        async_graphql::BatchResponse::Single(resp) => add_cache_control(http_resp, resp),
        async_graphql::BatchResponse::Batch(_) => {
            // Never advertise caching for a body that contains an error.
            if batch_resp.is_ok() {
                insert_cache_control_header(http_resp, &batch_resp.cache_control());
            }
        }
    }
}

/// Sets the `cache-control` header of `http_resp` from one GraphQL response.
///
/// Nothing is written when the response is not ok.
fn add_cache_control(http_resp: &mut warp::reply::Response, resp: &async_graphql::Response) {
    if resp.is_ok() {
        insert_cache_control_header(http_resp, &resp.cache_control);
    }
}

/// Writes `cache_control` into the `cache-control` header of `http_resp`.
///
/// Nothing is written when `cache_control` has no header value or when that
/// value cannot be parsed into a header value.
fn insert_cache_control_header(
    http_resp: &mut warp::reply::Response,
    cache_control: &async_graphql::CacheControl,
) {
    if let Some(raw_value) = cache_control.value() {
        if let Ok(value) = raw_value.parse() {
            http_resp.headers_mut().insert("cache-control", value);
        }
    }
}
pub(crate) fn graphql(
conf: &GvaConf,
gva_schema: GvaSchema,
opts: async_graphql::http::MultipartOptions,
) -> impl warp::Filter + Clone {
let anti_spam = AntiSpam::from(conf);
let opts = Arc::new(opts);
warp::path::path(conf.path.clone())
.and(warp::method())
.and(warp::query::raw().or(warp::any().map(String::new)).unify())
.and(warp::addr::remote())
.and(warp::header::optional::("X-Real-IP"))
.and(warp::header::optional::("content-type"))
.and(warp::body::stream())
.and(warp::any().map(move || anti_spam.clone()))
.and(warp::any().map(move || gva_schema.clone()))
.and(warp::any().map(move || opts.clone()))
.and_then(
|method,
query: String,
remote_addr: Option,
x_real_ip: Option,
content_type: Option,
body,
anti_spam: AntiSpam,
gva_schema: GvaSchema,
opts: Arc| async move {
let AntiSpamResponse {
is_whitelisted,
is_ok,
} = anti_spam
.verify(x_real_ip.or_else(|| remote_addr.map(|ra| ra.ip())))
.await;
if is_ok {
if method == http::Method::GET {
let request: async_graphql::Request = serde_urlencoded::from_str(&query)
.map_err(|err| warp::reject::custom(BadRequest(err.into())))?;
Ok(ServerResponse::GraphQl(
GraphQlRequest::single(request.data(QueryContext { is_whitelisted }))
.execute(gva_schema)
.await,
))
} else {
let body_stream = futures::TryStreamExt::map_err(body, |err| {
std::io::Error::new(std::io::ErrorKind::Other, err)
})
.map_ok(|mut buf| {
let remaining = warp::Buf::remaining(&buf);
warp::Buf::copy_to_bytes(&mut buf, remaining)
});
if content_type.as_deref() == Some("application/bincode") {
tokio::time::timeout(
Duration::from_millis(MAX_BATCH_REQ_PROCESS_DURATION_IN_MILLIS),
process_bincode_batch_queries(body_stream, is_whitelisted),
)
.await
.map_err(|_| warp::reject::custom(ReqExecTooLong))?
} else {
tokio::time::timeout(
Duration::from_millis(MAX_BATCH_REQ_PROCESS_DURATION_IN_MILLIS),
process_json_batch_queries(
body_stream.into_async_read(),
content_type,
gva_schema,
is_whitelisted,
*opts,
),
)
.await
.map_err(|_| warp::reject::custom(ReqExecTooLong))?
}
}
} else {
Err(warp::reject::custom(BadRequest(anyhow::Error::msg(
r#"{ "error": "too many requests" }"#,
))))
}
},
)
}
async fn process_bincode_batch_queries(
body_reader: impl 'static + futures::TryStream + Send + Unpin,
is_whitelisted: bool,
) -> Result {
Ok(ServerResponse::Bincode(
duniter_bca::execute(body_reader, is_whitelisted).await,
))
}
async fn process_json_batch_queries(
body_reader: impl 'static + futures::AsyncRead + Send + Unpin,
content_type: Option,
gva_schema: GvaSchema,
is_whitelisted: bool,
opts: async_graphql::http::MultipartOptions,
) -> Result {
let batch_request = GraphQlRequest::new(
async_graphql::http::receive_batch_body(
content_type,
body_reader,
async_graphql::http::MultipartOptions::clone(&opts),
)
.await
.map_err(|err| warp::reject::custom(BadRequest(err.into())))?,
);
if is_whitelisted || batch_request.len() <= anti_spam::MAX_BATCH_SIZE {
Ok(ServerResponse::GraphQl(
batch_request
.data(QueryContext { is_whitelisted })
.execute(gva_schema)
.await,
))
} else {
Err(warp::reject::custom(BadRequest(anyhow::Error::msg(
r#"{ "error": "The batch contains too many requests" }"#,
))))
}
}
pub(crate) fn graphql_ws(
conf: &GvaConf,
schema: GvaSchema,
) -> impl warp::Filter + Clone {
let anti_spam = AntiSpam::from(conf);
warp::path::path(conf.subscriptions_path.clone())
.and(warp::addr::remote())
.and(warp::header::optional::("X-Real-IP"))
.and(warp::ws())
.and(warp::any().map(move || schema.clone()))
.and(warp::any().map(move || anti_spam.clone()))
.and_then(
|remote_addr: Option,
x_real_ip: Option,
ws: warp::ws::Ws,
schema: GvaSchema,
anti_spam: AntiSpam| async move {
let AntiSpamResponse {
is_whitelisted: _,
is_ok,
} = anti_spam
.verify(x_real_ip.or_else(|| remote_addr.map(|ra| ra.ip())))
.await;
if is_ok {
Ok((ws, schema))
} else {
Err(warp::reject::custom(BadRequest(anyhow::Error::msg(
r#"{ "error": "too many requests" }"#,
))))
}
},
)
.and_then(|(ws, schema): (warp::ws::Ws, GvaSchema)| {
let reply = ws.on_upgrade(move |websocket| {
let (ws_sender, ws_receiver) = websocket.split();
async move {
let _ = async_graphql::http::WebSocket::new(
schema,
ws_receiver
.take_while(|msg| futures::future::ready(msg.is_ok()))
.map(Result::unwrap)
.map(warp::ws::Message::into_bytes),
async_graphql::http::WebSocketProtocols::SubscriptionsTransportWS,
)
.map(|ws_msg| match ws_msg {
async_graphql::http::WsMessage::Text(s) => warp::ws::Message::text(s),
async_graphql::http::WsMessage::Close(code, reason) => {
warp::ws::Message::close_with(code, reason)
}
})
.map(Ok)
.forward(ws_sender)
.await;
}
});
futures::future::ready(Ok::<_, Rejection>(warp::reply::with_header(
reply,
"Sec-WebSocket-Protocol",
"graphql-ws",
)))
})
}