Skip to main content

mas_storage_pg/
tracing.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8//! Records each executed SQL statement as `db.query.text` on the current
9//! tracing span, and accumulates per-context DB query count and timing onto the
10//! [`LogContext`].
11//!
12//! Recording happens at the *executor* layer rather than at `.traced()` time:
13//! [`ExecuteExt::traced`] wraps the query in a [`Traced`], whose `fetch_*` /
14//! `execute` methods substitute a [`RecordingExecutor`] for the real executor.
15//! The recording executor reads the SQL, records it, and times the query as it
16//! runs.
17
18use std::{
19    pin::Pin,
20    task::{Context, Poll, ready},
21    time::Instant,
22};
23
24use futures_util::{
25    FutureExt, StreamExt,
26    future::BoxFuture,
27    stream::{BoxStream, Stream},
28};
29use mas_context::LogContext;
30use opentelemetry_semantic_conventions::{
31    attribute::DB_QUERY_TEXT, trace::DB_RESPONSE_RETURNED_ROWS,
32};
33use sqlx::{
34    Database, Describe, Either, Error, Execute, Executor, IntoArguments,
35    query::{Map, Query, QueryAs, QueryScalar},
36};
37use tracing::Span;
38
39/// An extension trait that wraps a sqlx query so its SQL and timing get
40/// recorded when it is executed.
41///
42/// The span attached should have the `db.query.text` and
43/// `db.response.returned_rows` attribute set.
44pub trait ExecuteExt: Sized {
45    /// Wrap the query so that, when executed, its SQL is recorded as
46    /// `db.query.text` on the current span and its count/timing are added to
47    /// the current [`LogContext`].
48    #[must_use]
49    fn traced(self) -> Traced<Self> {
50        self.record(&Span::current())
51    }
52
53    /// Like [`ExecuteExt::traced`], but records onto the given span instead of
54    /// the current one. Use when the query runs under a span other than the
55    /// one current at the call site.
56    #[must_use]
57    fn record(self, span: &Span) -> Traced<Self> {
58        Traced {
59            query: self,
60            span: span.clone(),
61        }
62    }
63}
64
65pin_project_lite::pin_project! {
66    /// A stream that records every row fetched from the database
67    /// and tracks the elapsed wall-clock time.
68    struct RecordingStream<St, Db> {
69        #[pin]
70        inner: St,
71        span: Span,
72        database: std::marker::PhantomData<Db>,
73        start: Instant,
74        fetched: usize,
75    }
76}
77
78impl<Db: Database, St: Stream<Item = Result<Either<Db::QueryResult, Db::Row>, Error>>> Stream
79    for RecordingStream<St, Db>
80{
81    type Item = St::Item;
82
83    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St::Item>> {
84        let this = self.project();
85        let ret = match ready!(this.inner.poll_next(cx)) {
86            Some(Ok(Either::Left(query_result))) => Some(Ok(Either::Left(query_result))),
87            Some(Ok(Either::Right(row))) => {
88                *this.fetched += 1;
89                Some(Ok(Either::Right(row)))
90            }
91            Some(Err(err)) => Some(Err(err)),
92            // Stream is terminated; log query stats and return `None`.
93            None => {
94                let elapsed = this.start.elapsed();
95                this.span.record(DB_RESPONSE_RETURNED_ROWS, *this.fetched);
96                LogContext::maybe_record_query_stats(*this.fetched, elapsed);
97                None
98            }
99        };
100        Poll::Ready(ret)
101    }
102}
103
104/// An [`Executor`] wrapper that records the SQL of each query onto a span and
105/// accumulates count/timing onto the [`LogContext`]. Only `fetch_many` and
106/// `fetch_optional` are required; every other `Executor` method funnels through
107/// them. Note that no stats are recorded in case of an error.
108#[derive(Debug)]
109struct RecordingExecutor<E> {
110    inner: E,
111    span: Span,
112}
113
114impl<E> RecordingExecutor<E> {
115    fn new(inner: E, span: Span) -> Self {
116        Self { inner, span }
117    }
118}
119
120impl<'c, E> Executor<'c> for RecordingExecutor<E>
121where
122    E: Executor<'c>,
123{
124    type Database = E::Database;
125
126    fn fetch_many<'e, 'q: 'e, Q>(
127        self,
128        query: Q,
129    ) -> BoxStream<
130        'e,
131        Result<
132            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
133            Error,
134        >,
135    >
136    where
137        'c: 'e,
138        Q: 'q + Execute<'q, E::Database>,
139    {
140        self.span.record(DB_QUERY_TEXT, query.sql());
141
142        RecordingStream {
143            inner: self.inner.fetch_many(query),
144            database: std::marker::PhantomData::<E::Database>,
145            span: self.span,
146            start: Instant::now(),
147            fetched: 0,
148        }
149        .boxed()
150    }
151
152    fn fetch_optional<'e, 'q: 'e, Q>(
153        self,
154        query: Q,
155    ) -> BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
156    where
157        'c: 'e,
158        Q: 'q + Execute<'q, E::Database>,
159    {
160        self.span.record(DB_QUERY_TEXT, query.sql());
161        let inner = self.inner.fetch_optional(query);
162        async move {
163            let start = Instant::now();
164            let result = inner.await?;
165            #[expect(clippy::bool_to_int_with_if, reason = "clearer if explicit")]
166            let fetched = if result.is_some() { 1 } else { 0 };
167            self.span.record(DB_RESPONSE_RETURNED_ROWS, fetched);
168            LogContext::maybe_record_query_stats(fetched, start.elapsed());
169
170            Ok(result)
171        }
172        .boxed()
173    }
174
175    fn prepare_with<'e, 'q: 'e>(
176        self,
177        sql: &'q str,
178        parameters: &'e [<Self::Database as Database>::TypeInfo],
179    ) -> BoxFuture<'e, Result<<Self::Database as Database>::Statement<'q>, Error>>
180    where
181        'c: 'e,
182    {
183        self.inner.prepare_with(sql, parameters)
184    }
185
186    fn describe<'e, 'q: 'e>(
187        self,
188        sql: &'q str,
189    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
190    where
191        'c: 'e,
192    {
193        self.inner.describe(sql)
194    }
195}
196
197/// A query wrapped by [`ExecuteExt::traced`], carrying the span to record onto.
198pub struct Traced<Q> {
199    query: Q,
200    span: Span,
201}
202
203// Implementation of the [`ExecuteExt`] trait for each concrete query type we
204// care about. We avoid a blanket impl to avoid the methods being available on
205// all types.
206impl<DB: Database, A> ExecuteExt for Query<'_, DB, A> {}
207impl<DB: Database, O, A> ExecuteExt for QueryAs<'_, DB, O, A> {}
208impl<DB: Database, O, A> ExecuteExt for QueryScalar<'_, DB, O, A> {}
209impl<DB: Database, F, A> ExecuteExt for Map<'_, DB, F, A> {}
210
211// Each concrete query type needs its own delegating impl: `Map`/`QueryAs`/
212// `QueryScalar` apply their row-mapping in their *own* inherent `fetch_*`
213// methods (returning the mapped output), so we must call those, wrapping the
214// executor with a [`RecordingExecutor`] to record the span.
215
216impl<'q, DB: Database, A> Traced<Query<'q, DB, A>>
217where
218    A: 'q + Send + IntoArguments<'q, DB>,
219{
220    pub async fn execute<'e, 'c, E>(self, executor: E) -> Result<DB::QueryResult, Error>
221    where
222        'c: 'e,
223        'q: 'e,
224        A: 'e,
225        E: Executor<'c, Database = DB>,
226    {
227        self.query
228            .execute(RecordingExecutor::new(executor, self.span))
229            .await
230    }
231
232    pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<DB::Row, Error>
233    where
234        'c: 'e,
235        'q: 'e,
236        A: 'e,
237        E: Executor<'c, Database = DB>,
238    {
239        self.query
240            .fetch_one(RecordingExecutor::new(executor, self.span))
241            .await
242    }
243
244    pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<DB::Row>, Error>
245    where
246        'c: 'e,
247        'q: 'e,
248        A: 'e,
249        E: Executor<'c, Database = DB>,
250    {
251        self.query
252            .fetch_optional(RecordingExecutor::new(executor, self.span))
253            .await
254    }
255
256    pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<DB::Row>, Error>
257    where
258        'c: 'e,
259        'q: 'e,
260        A: 'e,
261        E: Executor<'c, Database = DB>,
262    {
263        self.query
264            .fetch_all(RecordingExecutor::new(executor, self.span))
265            .await
266    }
267}
268
269impl<'q, DB: Database, F, O, A> Traced<Map<'q, DB, F, A>>
270where
271    F: FnMut(DB::Row) -> Result<O, Error> + Send,
272    O: Send + Unpin,
273    A: 'q + Send + IntoArguments<'q, DB>,
274{
275    pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<O, Error>
276    where
277        'c: 'e,
278        'q: 'e,
279        E: 'e + Executor<'c, Database = DB>,
280        F: 'e,
281        O: 'e,
282    {
283        self.query
284            .fetch_one(RecordingExecutor::new(executor, self.span))
285            .await
286    }
287
288    pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<O>, Error>
289    where
290        'c: 'e,
291        'q: 'e,
292        E: 'e + Executor<'c, Database = DB>,
293        F: 'e,
294        O: 'e,
295    {
296        self.query
297            .fetch_optional(RecordingExecutor::new(executor, self.span))
298            .await
299    }
300
301    pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<O>, Error>
302    where
303        'c: 'e,
304        'q: 'e,
305        E: 'e + Executor<'c, Database = DB>,
306        F: 'e,
307        O: 'e,
308    {
309        self.query
310            .fetch_all(RecordingExecutor::new(executor, self.span))
311            .await
312    }
313}
314
315impl<'q, DB: Database, O, A> Traced<QueryAs<'q, DB, O, A>>
316where
317    A: 'q + IntoArguments<'q, DB>,
318    O: Send + Unpin + for<'r> sqlx::FromRow<'r, DB::Row>,
319{
320    pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<O, Error>
321    where
322        'c: 'e,
323        'q: 'e,
324        O: 'e,
325        A: 'e,
326        E: 'e + Executor<'c, Database = DB>,
327    {
328        self.query
329            .fetch_one(RecordingExecutor::new(executor, self.span))
330            .await
331    }
332
333    pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<O>, Error>
334    where
335        'c: 'e,
336        'q: 'e,
337        O: 'e,
338        A: 'e,
339        E: 'e + Executor<'c, Database = DB>,
340    {
341        self.query
342            .fetch_optional(RecordingExecutor::new(executor, self.span))
343            .await
344    }
345
346    pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<O>, Error>
347    where
348        'c: 'e,
349        'q: 'e,
350        O: 'e,
351        A: 'e,
352        E: 'e + Executor<'c, Database = DB>,
353    {
354        self.query
355            .fetch_all(RecordingExecutor::new(executor, self.span))
356            .await
357    }
358}
359
360impl<'q, DB: Database, O, A> Traced<QueryScalar<'q, DB, O, A>>
361where
362    O: Send + Unpin,
363    A: 'q + IntoArguments<'q, DB>,
364    (O,): Send + Unpin + for<'r> sqlx::FromRow<'r, DB::Row>,
365{
366    pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<O, Error>
367    where
368        'c: 'e,
369        'q: 'e,
370        O: 'e,
371        A: 'e,
372        E: 'e + Executor<'c, Database = DB>,
373    {
374        self.query
375            .fetch_one(RecordingExecutor::new(executor, self.span))
376            .await
377    }
378
379    pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<O>, Error>
380    where
381        'c: 'e,
382        'q: 'e,
383        O: 'e,
384        A: 'e,
385        E: 'e + Executor<'c, Database = DB>,
386    {
387        self.query
388            .fetch_optional(RecordingExecutor::new(executor, self.span))
389            .await
390    }
391
392    pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<O>, Error>
393    where
394        'c: 'e,
395        'q: 'e,
396        O: 'e,
397        A: 'e,
398        E: 'e + Executor<'c, Database = DB>,
399    {
400        self.query
401            .fetch_all(RecordingExecutor::new(executor, self.span))
402            .await
403    }
404}
405
#[cfg(test)]
mod tests {
    use mas_context::LogContext;
    use sqlx::PgPool;

    use crate::tracing::ExecuteExt;

    /// Two traced queries run inside a [`LogContext`] should leave that
    /// context with a query count of 2 and a fetched-row total of 4
    /// (one row from the first query, three from the second).
    #[sqlx::test]
    async fn test_db_stats_recorded(pool: PgPool) {
        let ctx = LogContext::new("test");

        ctx.run(|| async {
            // Single-row query, consumed through `fetch_one`.
            sqlx::query("SELECT 1")
                .traced()
                .fetch_one(&pool)
                .await
                .unwrap();

            // Three-row query, consumed through `fetch_all`.
            sqlx::query("SELECT 1 FROM UNNEST(ARRAY[1, 2, 3])")
                .traced()
                .fetch_all(&pool)
                .await
                .unwrap();
        })
        .await;

        let stats = ctx.stats();
        assert_eq!(stats.db_queries, 2);
        assert_eq!(stats.db_rows_fetched, 4);

        let rendered = stats.to_string();
        assert!(rendered.contains("queries: 2, fetched: 4"));
    }
}