use std::{future::Future, time::Instant};
use opentelemetry::{metrics::Histogram, KeyValue};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use crate::{utils::FnWrapper, MetricsAttributes};
/// A [`Layer`] that wraps services in a [`DurationRecorderService`], which
/// records how long each call takes into a `u64` histogram.
///
/// The three type parameters are the attribute sources consulted for the
/// request, the successful response and the error respectively. Each of them
/// defaults to `()`.
#[derive(Clone, Debug)]
pub struct DurationRecorderLayer<OnRequest = (), OnResponse = (), OnError = ()> {
/// Histogram cloned into every service produced by this layer.
histogram: Histogram<u64>,
/// Attribute source cloned into each service and applied to requests.
on_request: OnRequest,
/// Attribute source cloned into each service and applied to successful responses.
on_response: OnResponse,
/// Attribute source cloned into each service and applied to errors.
on_error: OnError,
}
impl DurationRecorderLayer {
#[must_use]
pub fn new(name: &'static str) -> Self {
let histogram = crate::meter().u64_histogram(name).init();
Self {
histogram,
on_request: (),
on_response: (),
on_error: (),
}
}
}
impl<OnRequest, OnResponse, OnError> DurationRecorderLayer<OnRequest, OnResponse, OnError> {
    /// Swaps the request attribute source for `on_request`.
    ///
    /// The histogram and the response/error sources are carried over
    /// unchanged; the previous request source is discarded.
    #[must_use]
    pub fn on_request<NewOnRequest>(
        self,
        on_request: NewOnRequest,
    ) -> DurationRecorderLayer<NewOnRequest, OnResponse, OnError> {
        let Self {
            histogram,
            on_response,
            on_error,
            ..
        } = self;

        DurationRecorderLayer {
            histogram,
            on_request,
            on_response,
            on_error,
        }
    }

    /// Same as [`Self::on_request`], but takes a closure mapping a borrowed
    /// request to a `Vec<KeyValue>` and wraps it in a [`FnWrapper`].
    #[must_use]
    pub fn on_request_fn<F, T>(
        self,
        on_request: F,
    ) -> DurationRecorderLayer<FnWrapper<F>, OnResponse, OnError>
    where
        F: Fn(&T) -> Vec<KeyValue>,
    {
        let wrapped = FnWrapper(on_request);
        self.on_request(wrapped)
    }

    /// Swaps the response attribute source for `on_response`.
    ///
    /// The histogram and the request/error sources are carried over
    /// unchanged; the previous response source is discarded.
    #[must_use]
    pub fn on_response<NewOnResponse>(
        self,
        on_response: NewOnResponse,
    ) -> DurationRecorderLayer<OnRequest, NewOnResponse, OnError> {
        let Self {
            histogram,
            on_request,
            on_error,
            ..
        } = self;

        DurationRecorderLayer {
            histogram,
            on_request,
            on_response,
            on_error,
        }
    }

    /// Same as [`Self::on_response`], but takes a closure mapping a borrowed
    /// response to a `Vec<KeyValue>` and wraps it in a [`FnWrapper`].
    #[must_use]
    pub fn on_response_fn<F, T>(
        self,
        on_response: F,
    ) -> DurationRecorderLayer<OnRequest, FnWrapper<F>, OnError>
    where
        F: Fn(&T) -> Vec<KeyValue>,
    {
        let wrapped = FnWrapper(on_response);
        self.on_response(wrapped)
    }

    /// Swaps the error attribute source for `on_error`.
    ///
    /// The histogram and the request/response sources are carried over
    /// unchanged; the previous error source is discarded.
    #[must_use]
    pub fn on_error<NewOnError>(
        self,
        on_error: NewOnError,
    ) -> DurationRecorderLayer<OnRequest, OnResponse, NewOnError> {
        let Self {
            histogram,
            on_request,
            on_response,
            ..
        } = self;

        DurationRecorderLayer {
            histogram,
            on_request,
            on_response,
            on_error,
        }
    }

    /// Same as [`Self::on_error`], but takes a closure mapping a borrowed
    /// error to a `Vec<KeyValue>` and wraps it in a [`FnWrapper`].
    #[must_use]
    pub fn on_error_fn<F, T>(
        self,
        on_error: F,
    ) -> DurationRecorderLayer<OnRequest, OnResponse, FnWrapper<F>>
    where
        F: Fn(&T) -> Vec<KeyValue>,
    {
        let wrapped = FnWrapper(on_error);
        self.on_error(wrapped)
    }
}
impl<S, OnRequest, OnResponse, OnError> Layer<S>
    for DurationRecorderLayer<OnRequest, OnResponse, OnError>
where
    OnRequest: Clone,
    OnResponse: Clone,
    OnError: Clone,
{
    type Service = DurationRecorderService<S, OnRequest, OnResponse, OnError>;

    /// Wraps `inner` in a [`DurationRecorderService`] that receives its own
    /// clone of this layer's histogram and of each attribute source.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            histogram,
            on_request,
            on_response,
            on_error,
        } = self;

        DurationRecorderService {
            inner,
            histogram: histogram.clone(),
            on_request: on_request.clone(),
            on_response: on_response.clone(),
            on_error: on_error.clone(),
        }
    }
}
/// A [`Service`] wrapper that times every call to the inner service and
/// records the elapsed milliseconds into a `u64` histogram.
///
/// Instances are produced by [`DurationRecorderLayer`]; the attribute source
/// type parameters each default to `()`.
#[derive(Clone, Debug)]
pub struct DurationRecorderService<S, OnRequest = (), OnResponse = (), OnError = ()> {
/// The wrapped service that actually handles requests.
inner: S,
/// Histogram the measured durations are recorded into.
histogram: Histogram<u64>,
/// Attribute source evaluated against each request when it is dispatched.
on_request: OnRequest,
/// Attribute source cloned into each call's future for successful responses.
on_response: OnResponse,
/// Attribute source cloned into each call's future for errors.
on_error: OnError,
}
pin_project! {
/// Future returned by [`DurationRecorderService`].
///
/// It resolves to the inner future's output; once that output is ready, the
/// time elapsed since `start` is recorded into the histogram.
pub struct DurationRecorderFuture<F, OnResponse = (), OnError = ()> {
// The wrapped future; marked `#[pin]` so it can be polled through the projection.
#[pin]
inner: F,
// Instant captured by the service when the call was dispatched.
start: Instant,
// Histogram the elapsed milliseconds are recorded into.
histogram: Histogram<u64>,
// Attributes already extracted from the request before it was handed to the inner service.
attributes_from_request: Vec<KeyValue>,
// Attribute source consulted when the inner future resolves to `Ok`.
from_response: OnResponse,
// Attribute source consulted when the inner future resolves to `Err`.
from_error: OnError,
}
}
impl<F, R, E, OnResponse, OnError> Future for DurationRecorderFuture<F, OnResponse, OnError>
where
    F: Future<Output = Result<R, E>>,
    OnResponse: MetricsAttributes<R>,
    OnError: MetricsAttributes<E>,
{
    type Output = F::Output;

    /// Polls the inner future and, once it completes, records the elapsed
    /// time before handing the result back untouched.
    ///
    /// The recorded value is the number of whole milliseconds since `start`,
    /// saturating at `u64::MAX`. The attributes attached to the measurement
    /// are the ones captured from the request, followed by those derived from
    /// the response (on `Ok`) or from the error (on `Err`).
    ///
    /// Nothing is recorded while the inner future is still pending.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.project();

        // Returns `Poll::Pending` early until the inner future has an output.
        let result = std::task::ready!(this.inner.poll(cx));

        let duration = this.start.elapsed();
        // `as_millis` yields a `u128`; clamp instead of failing on overflow.
        let duration_ms: u64 = duration.as_millis().try_into().unwrap_or(u64::MAX);

        // This future is finished once it returns `Ready` and must not be
        // polled again, so the request attributes can be moved out instead of
        // cloned, saving an allocation and a per-attribute clone per call.
        let mut attributes = std::mem::take(this.attributes_from_request);
        match &result {
            Ok(response) => attributes.extend(this.from_response.attributes(response)),
            Err(error) => attributes.extend(this.from_error.attributes(error)),
        }

        this.histogram.record(duration_ms, &attributes);
        std::task::Poll::Ready(result)
    }
}
impl<S, R, OnRequest, OnResponse, OnError> Service<R>
for DurationRecorderService<S, OnRequest, OnResponse, OnError>
where
S: Service<R>,
OnRequest: MetricsAttributes<R>,
OnResponse: MetricsAttributes<S::Response> + Clone,
OnError: MetricsAttributes<S::Error> + Clone,
{
type Response = S::Response;
type Error = S::Error;
type Future = DurationRecorderFuture<S::Future, OnResponse, OnError>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: R) -> Self::Future {
let start = Instant::now();
let attributes_from_request = self.on_request.attributes(&request).collect();
let inner = self.inner.call(request);
DurationRecorderFuture {
inner,
start,
histogram: self.histogram.clone(),
attributes_from_request,
from_response: self.on_response.clone(),
from_error: self.on_error.clone(),
}
}
}