You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

670 lines
15 KiB
Rust

use std::{
cmp::Ordering,
future::Future,
pin::Pin,
task::{Context, Poll},
time::Instant,
};
use brotli::{enc::BrotliEncoderParams, CompressorWriter};
use flate2::{
write::{DeflateEncoder, GzEncoder},
Compression,
};
use hyper::{
body::{Buf, Bytes, HttpBody},
header::HeaderValue,
Method, Request, Response, Uri,
};
use tower::{Layer, Service};
/// Tower [`Layer`] that wraps a service in a [`LogService`], which prints one
/// line to stdout for every successful response.
#[derive(Clone)]
pub struct Log;
impl<S> Layer<S> for Log {
    type Service = LogService<S>;
    fn layer(&self, inner: S) -> Self::Service {
        // The layer carries no configuration; it only wraps `inner`.
        LogService::new(inner)
    }
}
/// Service produced by the [`Log`] layer; it delegates to `inner` and logs
/// the method, URI, status and latency of each successful response.
#[derive(Clone)]
pub struct LogService<S> {
    inner: S,
}
impl<S> LogService<S> {
    /// Wraps `inner` without adding any further state.
    fn new(inner: S) -> Self {
        LogService { inner }
    }
}
/// Forwards every request to the wrapped service, remembering the request's
/// method and URI so the response can be logged when it resolves.
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for LogService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = LogServiceFuture<S::Future>;
    fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        // Readiness is entirely decided by the inner service.
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Copy what the log line needs before `req` moves into `inner`.
        let method = req.method().clone();
        let uri = req.uri().clone();
        LogServiceFuture {
            start: None,
            method,
            uri,
            inner: self.inner.call(req),
        }
    }
}
/// Future returned by [`LogService`]. It resolves to the inner service's
/// result and, on success, prints `METHOD URI STATUS [Nms]` to stdout.
///
/// The latency is measured from the first time this future is polled, not
/// from when the request was received by `call`.
pub struct LogServiceFuture<InnerFut> {
    uri: Uri,
    method: Method,
    inner: InnerFut,
    /// Set on the first poll; `None` until then.
    start: Option<Instant>,
}
impl<ResBody, InnerFut, InnerFutError> Future for LogServiceFuture<InnerFut>
where
    InnerFut: Future<Output = Result<Response<ResBody>, InnerFutError>>,
{
    type Output = InnerFut::Output;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `inner` is treated as structurally pinned. It is only ever
        // accessed through `Pin::new_unchecked` below and is never moved out
        // of `self`. The other fields are plain data that are never pinned.
        let this = unsafe { self.get_unchecked_mut() };
        // `get_or_insert_with` only reads the clock on the first poll.
        let start = *this.start.get_or_insert_with(Instant::now);
        let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
        let res = match inner.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Ok(res)) => res,
            // Errors are passed through without being logged.
            Poll::Ready(err) => return Poll::Ready(err),
        };
        println!(
            "{} {} {} [{}ms]",
            this.method,
            this.uri,
            res.status(),
            start.elapsed().as_millis()
        );
        Poll::Ready(Ok(res))
    }
}
/// Relative preference of a compression algorithm. Higher values win.
pub type Preference = u8;
/// Compression algorithms and their associated preference level. Higher is
/// better.
///
/// An algorithm left as `None` is never chosen, even if the client accepts it.
#[derive(Clone)]
pub struct AlgorithmPreferences {
    /// Brotli (`br`) preference and encoder parameters.
    pub brotli: Option<(Preference, BrotliEncoderParams)>,
    /// Gzip (`gzip`) preference and compression level.
    pub gzip: Option<(Preference, Compression)>,
    /// Deflate (`deflate`) preference and compression level.
    pub deflate: Option<(Preference, Compression)>,
}
pub fn brotli_default_params(
quality: u8,
favor_cpu_efficiency: bool,
) -> BrotliEncoderParams {
BrotliEncoderParams {
quality: quality as i32,
favor_cpu_efficiency,
..Default::default()
}
}
/// Conditions a response must satisfy before its body is compressed.
#[derive(Clone)]
pub struct CompressPredicate<ContentTypePredicateFn> {
    /// Responses whose known length is below this many bytes are sent
    /// uncompressed.
    pub min_size: usize,
    /// Called with the response's `content-type`; returning `false` leaves the
    /// body uncompressed.
    pub content_type_predicate: ContentTypePredicateFn,
}
/// Tower [`Layer`] that negotiates a compression algorithm per request and
/// compresses matching response bodies.
#[derive(Clone)]
pub struct Compress<ContentTypePredicateFn> {
    pub algorithms: AlgorithmPreferences,
    pub predicate: CompressPredicate<ContentTypePredicateFn>,
}
impl<S, CTPF> Layer<S> for Compress<CTPF>
where
    CTPF: Clone,
{
    type Service = CompressService<S, CTPF>;
    fn layer(&self, inner: S) -> Self::Service {
        // Each service gets its own copy of the configuration.
        let compress = self.clone();
        CompressService { inner, compress }
    }
}
/// Service produced by the [`Compress`] layer.
#[derive(Clone)]
pub struct CompressService<S, CTPF> {
    compress: Compress<CTPF>,
    inner: S,
}
/// Algorithm negotiated for a single response, with its encoder settings.
pub enum Algorithm {
    Brotli(BrotliEncoderParams),
    Deflate(Compression),
    Gzip(Compression),
}
/// Negotiates the response encoding from the request's `accept-encoding`
/// header and forwards the request to the wrapped service.
impl<S, CTPF, ReqBody, ResBody> Service<Request<ReqBody>>
    for CompressService<S, CTPF>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    CTPF: Clone + FnOnce(&str) -> bool,
    ResBody: HttpBody,
{
    type Response = Response<CompressBody<ResBody>>;
    type Error = S::Error;
    type Future = CompressFuture<S::Future, CTPF>;
    fn poll_ready(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Decide on the algorithm now, while the request headers are
        // still available; a missing header means no compression.
        let chosen_algorithm = req
            .headers()
            .get("accept-encoding")
            .and_then(|accepted| {
                choose_algorithm(&self.compress.algorithms, accepted)
            });
        CompressFuture {
            inner: self.inner.call(req),
            chosen_algorithm,
            predicate: self.compress.predicate.clone(),
        }
    }
}
/// Picks the compression algorithm for a response from the request's
/// `accept-encoding` header value.
///
/// Each coding the client lists is looked up in `preferences`. Among those the
/// server has enabled, the highest [`Preference`] wins. On a tie, the coding
/// listed first by the client is kept.
///
/// Two rules follow RFC 9110:
/// - Coding names are compared case-insensitively, because content codings
///   are case-insensitive.
/// - Codings the client refuses with a zero weight (e.g. `gzip;q=0`) are
///   skipped.
///
/// Returns `None` if the header is not a valid `&str` (visible ASCII) or if no
/// listed coding is enabled.
fn choose_algorithm(
    preferences: &AlgorithmPreferences,
    accept_encodings: &HeaderValue,
) -> Option<Algorithm> {
    let accept_encodings = accept_encodings.to_str().ok()?;
    accept_encodings
        .split(',')
        .filter_map(|entry| {
            let mut parts = entry.split(';');
            // `split` always yields at least one item: the coding name.
            let coding = parts.next()?.trim();
            if has_zero_qvalue(parts) {
                None
            } else {
                Some(coding)
            }
        })
        .filter_map(|coding| {
            if coding.eq_ignore_ascii_case("br") {
                preferences.brotli.as_ref().map(|(preference, params)| {
                    (Algorithm::Brotli(params.clone()), *preference)
                })
            } else if coding.eq_ignore_ascii_case("gzip") {
                preferences
                    .gzip
                    .as_ref()
                    .map(|(preference, level)| (Algorithm::Gzip(*level), *preference))
            } else if coding.eq_ignore_ascii_case("deflate") {
                preferences.deflate.as_ref().map(|(preference, level)| {
                    (Algorithm::Deflate(*level), *preference)
                })
            } else {
                None
            }
        })
        // Strictly-greater replacement keeps the first of equally preferred codings.
        .reduce(|best, candidate| if candidate.1 > best.1 { candidate } else { best })
        .map(|(algorithm, _)| algorithm)
}
/// Returns `true` if the `;`-separated parameters of one `accept-encoding`
/// entry contain a `q` weight of zero, meaning the client marks that coding as
/// not acceptable. Unparseable weights are treated as acceptable.
fn has_zero_qvalue<'a>(mut params: impl Iterator<Item = &'a str>) -> bool {
    params.any(|param| match param.split_once('=') {
        Some((key, value)) => {
            key.trim().eq_ignore_ascii_case("q")
                && matches!(value.trim().parse::<f32>(), Ok(q) if q <= 0.0)
        }
        None => false,
    })
}
/// Future returned by [`CompressService`]. It awaits the inner response and
/// then wraps its body via [`choose_body`] using the negotiated algorithm and
/// a fresh clone of the predicate.
pub struct CompressFuture<InnerFut, ContentTypePredicateFn> {
    inner: InnerFut,
    chosen_algorithm: Option<Algorithm>,
    predicate: CompressPredicate<ContentTypePredicateFn>,
}
impl<InnerFut, CTPF, ResBody, InnerFutError> Future
    for CompressFuture<InnerFut, CTPF>
where
    InnerFut: Future<Output = Result<Response<ResBody>, InnerFutError>>,
    CTPF: FnOnce(&str) -> bool + Clone,
    ResBody: HttpBody,
{
    type Output = Result<Response<CompressBody<ResBody>>, InnerFutError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: only `inner` is treated as pinned and it is never moved out
        // of `self`; the other fields are accessed through plain references.
        let this = unsafe { self.get_unchecked_mut() };
        let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
        let outcome = match inner.poll(cx) {
            Poll::Ready(outcome) => outcome,
            Poll::Pending => return Poll::Pending,
        };
        // Errors pass through untouched; only successful responses are wrapped.
        Poll::Ready(outcome.map(|res| {
            choose_body(res, &this.chosen_algorithm, this.predicate.clone())
        }))
    }
}
/// Response body produced by the compression layer: either the original body
/// passed through unchanged, or one of the compressing adapters around it.
pub enum CompressBody<B> {
    None(B),
    Brotli(BrotliBody<B>),
    Gzip(GzipBody<B>),
    Deflate(DeflateBody<B>),
}
impl<B> HttpBody for CompressBody<B>
where
    B: HttpBody,
    B::Data: Send + 'static,
{
    // Chunks are boxed so all variants share a single data type.
    type Data = Box<dyn Buf + Send + 'static>;
    type Error = B::Error;
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        // Erases the concrete chunk type of one variant's poll result.
        fn erase<D, E>(
            polled: Poll<Option<Result<D, E>>>,
        ) -> Poll<Option<Result<Box<dyn Buf + Send + 'static>, E>>>
        where
            D: Buf + Send + 'static,
        {
            polled.map(|chunk| {
                chunk.map(|result| {
                    result.map(|data| Box::new(data) as Box<dyn Buf + Send + 'static>)
                })
            })
        }
        // SAFETY: the variant payload is structurally pinned; it is only
        // re-pinned here and never moved out of `self`.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            CompressBody::None(body) => erase(unsafe { Pin::new_unchecked(body) }.poll_data(cx)),
            CompressBody::Brotli(body) => erase(unsafe { Pin::new_unchecked(body) }.poll_data(cx)),
            CompressBody::Gzip(body) => erase(unsafe { Pin::new_unchecked(body) }.poll_data(cx)),
            CompressBody::Deflate(body) => {
                erase(unsafe { Pin::new_unchecked(body) }.poll_data(cx))
            }
        }
    }
    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
        // SAFETY: same structural-pinning argument as in `poll_data`.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            CompressBody::Brotli(body) => unsafe { Pin::new_unchecked(body) }.poll_trailers(cx),
            CompressBody::Gzip(body) => unsafe { Pin::new_unchecked(body) }.poll_trailers(cx),
            CompressBody::Deflate(body) => unsafe { Pin::new_unchecked(body) }.poll_trailers(cx),
            CompressBody::None(body) => unsafe { Pin::new_unchecked(body) }.poll_trailers(cx),
        }
    }
}
/// Body adapter that brotli-compresses the chunks of an inner body.
///
/// After each inner chunk is written, the compressor is flushed. Whatever
/// compressed bytes it has produced are then yielded as one chunk. When the
/// inner body ends, the compressor is consumed and its remaining bytes are
/// yielded. Trailers are forwarded from the inner body unchanged.
pub struct BrotliBody<B> {
    inner: B,
    /// `Some` while streaming; `None` once the stream has been finished.
    compressor: Option<CompressorWriter<Vec<u8>>>,
}
impl<B: HttpBody> HttpBody for BrotliBody<B>
where
    B::Data: Send + 'static,
{
    type Data = Bytes;
    type Error = B::Error;
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        use std::io::Write;
        // SAFETY: `inner` is treated as structurally pinned: it is only
        // re-pinned below and never moved out of `self`. `compressor` is never
        // pinned and is accessed through a plain `&mut`.
        let this = unsafe { self.get_unchecked_mut() };
        let compressor = match this.compressor.as_mut() {
            Some(compressor) => compressor,
            // The brotli stream was already finished and emitted.
            None => return Poll::Ready(None),
        };
        let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
        match inner.poll_data(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(Some(Ok(chunk))) => {
                // The sink is an in-memory `Vec`; write/flush results are
                // ignored as before.
                let _ = std::io::copy(&mut chunk.reader(), compressor);
                let _ = compressor.flush();
                // Move the produced bytes out instead of cloning them.
                let produced = std::mem::take(compressor.get_mut());
                Poll::Ready(Some(Ok(Bytes::from(produced))))
            }
            Poll::Ready(None) => {
                // Consuming the compressor finishes the brotli stream.
                // NOTE(review): the final bytes come from brotli's
                // `into_inner`; confirm against the crate docs.
                let tail = this
                    .compressor
                    .take()
                    .map(|compressor| compressor.into_inner())
                    .unwrap_or_default();
                Poll::Ready(Some(Ok(Bytes::from(tail))))
            }
        }
    }
    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
        // SAFETY: `inner` is structurally pinned and never moved out.
        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().inner) };
        inner.poll_trailers(cx)
    }
}
impl<B> BrotliBody<B> {
    /// Wraps `inner`, compressing it with the given brotli `params`.
    /// `4096` is the buffer size handed to `CompressorWriter::with_params`.
    pub fn new(inner: B, params: &BrotliEncoderParams) -> Self {
        Self {
            inner,
            compressor: Some(CompressorWriter::with_params(
                Vec::new(),
                4096,
                params,
            )),
        }
    }
}
/// Body adapter that gzip-compresses the chunks of an inner body.
///
/// Each inner chunk is fed to the encoder, and whatever compressed output the
/// encoder has written so far is yielded (possibly an empty chunk). When the
/// inner body ends, the gzip stream is finished and the remaining bytes are
/// yielded. Trailers are forwarded from the inner body unchanged.
pub struct GzipBody<B> {
    inner: B,
    /// `Some` while streaming; `None` once the gzip stream has been finished.
    encoder: Option<GzEncoder<Vec<u8>>>,
}
impl<B> GzipBody<B> {
    /// Wraps `inner`, compressing it at the given gzip `compression` level.
    pub fn new(inner: B, compression: Compression) -> Self {
        Self {
            inner,
            encoder: Some(GzEncoder::new(Vec::new(), compression)),
        }
    }
}
impl<B> HttpBody for GzipBody<B>
where
    B: HttpBody,
{
    type Data = Bytes;
    type Error = B::Error;
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        // SAFETY: `inner` is treated as structurally pinned: it is only
        // re-pinned below and never moved out of `self`. `encoder` is never
        // pinned and is accessed through a plain `&mut`.
        let this = unsafe { self.get_unchecked_mut() };
        let encoder = match this.encoder.as_mut() {
            Some(encoder) => encoder,
            // The gzip stream was already finished and emitted.
            None => return Poll::Ready(None),
        };
        let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
        match inner.poll_data(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(Some(Ok(chunk))) => {
                // The sink is an in-memory `Vec`; the result is ignored as before.
                let _ = std::io::copy(&mut chunk.reader(), encoder);
                // Move the produced bytes out instead of cloning them.
                let produced = std::mem::take(encoder.get_mut());
                Poll::Ready(Some(Ok(Bytes::from(produced))))
            }
            Poll::Ready(None) => {
                let tail = match this.encoder.take() {
                    Some(encoder) => encoder
                        .finish()
                        .expect("finishing a gzip stream into a Vec<u8> cannot fail"),
                    None => Vec::new(),
                };
                Poll::Ready(Some(Ok(Bytes::from(tail))))
            }
        }
    }
    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
        // SAFETY: `inner` is structurally pinned and never moved out.
        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().inner) };
        inner.poll_trailers(cx)
    }
}
/// Body adapter that deflate-compresses the chunks of an inner body.
///
/// Each inner chunk is fed to the encoder, and whatever compressed output the
/// encoder has written so far is yielded (possibly an empty chunk). When the
/// inner body ends, the deflate stream is finished and the remaining bytes are
/// yielded. Trailers are forwarded from the inner body unchanged.
pub struct DeflateBody<B> {
    inner: B,
    /// `Some` while streaming; `None` once the deflate stream has been finished.
    encoder: Option<DeflateEncoder<Vec<u8>>>,
}
impl<B> DeflateBody<B> {
    /// Wraps `inner`, compressing it at the given deflate `compression` level.
    pub fn new(inner: B, compression: Compression) -> Self {
        Self {
            inner,
            encoder: Some(DeflateEncoder::new(Vec::new(), compression)),
        }
    }
}
impl<B> HttpBody for DeflateBody<B>
where
    B: HttpBody,
{
    type Data = Bytes;
    type Error = B::Error;
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        // SAFETY: `inner` is treated as structurally pinned: it is only
        // re-pinned below and never moved out of `self`. `encoder` is never
        // pinned and is accessed through a plain `&mut`.
        let this = unsafe { self.get_unchecked_mut() };
        let encoder = match this.encoder.as_mut() {
            Some(encoder) => encoder,
            // The deflate stream was already finished and emitted.
            None => return Poll::Ready(None),
        };
        let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
        match inner.poll_data(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(Some(Ok(chunk))) => {
                // The sink is an in-memory `Vec`; the result is ignored as before.
                let _ = std::io::copy(&mut chunk.reader(), encoder);
                // Move the produced bytes out instead of cloning them.
                let produced = std::mem::take(encoder.get_mut());
                Poll::Ready(Some(Ok(Bytes::from(produced))))
            }
            Poll::Ready(None) => {
                let tail = match this.encoder.take() {
                    Some(encoder) => encoder
                        .finish()
                        .expect("finishing a deflate stream into a Vec<u8> cannot fail"),
                    None => Vec::new(),
                };
                Poll::Ready(Some(Ok(Bytes::from(tail))))
            }
        }
    }
    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
        // SAFETY: `inner` is structurally pinned and never moved out.
        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().inner) };
        inner.poll_trailers(cx)
    }
}
fn replace_body<B, B_>(
r: Response<B>,
replace: impl FnOnce(B) -> B_,
) -> Response<B_> {
let (parts, body) = r.into_parts();
Response::from_parts(parts, replace(body))
}
/// Wraps the body of `res` in the matching [`CompressBody`] variant.
///
/// The body is passed through uncompressed (`CompressBody::None`) when any of
/// these hold:
/// - no algorithm was negotiated;
/// - the response already has a `content-encoding`, so it is never encoded twice;
/// - it has a `transfer-encoding` other than `chunked`;
/// - its known length is below `predicate.min_size`;
/// - it has no readable `content-type`;
/// - `predicate.content_type_predicate` rejects its content type.
///
/// The length checked against `min_size` is the `content-length` header if it
/// parses, otherwise the body's `size_hint` lower bound. Chunked responses
/// skip the size check.
///
/// When compressing, `content-encoding` is set to the chosen coding. Any
/// `content-length` header is removed, because it described the uncompressed
/// body and would no longer match the bytes actually sent.
pub fn choose_body<B, CTPF>(
    res: Response<B>,
    chosen_algorithm: &Option<Algorithm>,
    predicate: CompressPredicate<CTPF>,
) -> Response<CompressBody<B>>
where
    CTPF: FnOnce(&str) -> bool,
    B: HttpBody,
{
    let chosen_algorithm = match chosen_algorithm {
        None => return replace_body(res, CompressBody::None),
        Some(a) => a,
    };
    let headers = res.headers();
    // Already encoded (e.g. a pre-compressed asset): never encode twice.
    if headers.contains_key("content-encoding") {
        return replace_body(res, CompressBody::None);
    }
    let content_length = match headers
        .get("transfer-encoding")
        .and_then(|v| v.to_str().ok())
    {
        None => Some(
            headers
                .get("content-length")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<usize>().ok())
                .unwrap_or_else(|| {
                    // Saturate instead of truncating on 32-bit targets.
                    res.body().size_hint().lower().min(usize::MAX as u64) as usize
                }),
        ),
        Some("chunked") => None,
        _ => return replace_body(res, CompressBody::None),
    };
    if matches!(content_length, Some(l) if l < predicate.min_size) {
        return replace_body(res, CompressBody::None);
    }
    let content_type =
        match headers.get("content-type").and_then(|v| v.to_str().ok()) {
            Some(ct) => ct,
            None => return replace_body(res, CompressBody::None),
        };
    if !(predicate.content_type_predicate)(content_type) {
        return replace_body(res, CompressBody::None);
    }
    let (encoding, mut res) = match chosen_algorithm {
        Algorithm::Brotli(params) => (
            "br",
            replace_body(res, |b| CompressBody::Brotli(BrotliBody::new(b, params))),
        ),
        Algorithm::Gzip(level) => (
            "gzip",
            replace_body(res, |b| CompressBody::Gzip(GzipBody::new(b, *level))),
        ),
        Algorithm::Deflate(level) => (
            "deflate",
            replace_body(res, |b| {
                CompressBody::Deflate(DeflateBody::new(b, *level))
            }),
        ),
    };
    let headers = res.headers_mut();
    // The original length no longer describes the encoded body.
    headers.remove("content-length");
    headers.insert("content-encoding", HeaderValue::from_static(encoding));
    res
}