1use std::{
19 pin::Pin,
20 task::{Context, Poll, ready},
21 time::Instant,
22};
23
24use futures_util::{
25 FutureExt, StreamExt,
26 future::BoxFuture,
27 stream::{BoxStream, Stream},
28};
29use mas_context::LogContext;
30use opentelemetry_semantic_conventions::{
31 attribute::DB_QUERY_TEXT, trace::DB_RESPONSE_RETURNED_ROWS,
32};
33use sqlx::{
34 Database, Describe, Either, Error, Execute, Executor, IntoArguments,
35 query::{Map, Query, QueryAs, QueryScalar},
36};
37use tracing::Span;
38
39pub trait ExecuteExt: Sized {
45 #[must_use]
49 fn traced(self) -> Traced<Self> {
50 self.record(&Span::current())
51 }
52
53 #[must_use]
57 fn record(self, span: &Span) -> Traced<Self> {
58 Traced {
59 query: self,
60 span: span.clone(),
61 }
62 }
63}
64
pin_project_lite::pin_project! {
    /// Stream wrapper that counts the rows yielded by an inner result stream
    /// (as produced by `Executor::fetch_many`) and, once the inner stream
    /// reports its end, records the row count and the elapsed time.
    struct RecordingStream<St, Db> {
        // The wrapped stream; structurally pinned because it is polled
        // through `Pin<&mut Self>`.
        #[pin]
        inner: St,
        // Span on which the number of returned rows is recorded.
        span: Span,
        // Lets the `Stream` impl name `Db::QueryResult` and `Db::Row`
        // without storing a database value.
        database: std::marker::PhantomData<Db>,
        // Captured when the stream is built; used to compute the elapsed time.
        start: Instant,
        // Number of `Either::Right` (row) items seen so far.
        fetched: usize,
    }
}
77
78impl<Db: Database, St: Stream<Item = Result<Either<Db::QueryResult, Db::Row>, Error>>> Stream
79 for RecordingStream<St, Db>
80{
81 type Item = St::Item;
82
83 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St::Item>> {
84 let this = self.project();
85 let ret = match ready!(this.inner.poll_next(cx)) {
86 Some(Ok(Either::Left(query_result))) => Some(Ok(Either::Left(query_result))),
87 Some(Ok(Either::Right(row))) => {
88 *this.fetched += 1;
89 Some(Ok(Either::Right(row)))
90 }
91 Some(Err(err)) => Some(Err(err)),
92 None => {
94 let elapsed = this.start.elapsed();
95 this.span.record(DB_RESPONSE_RETURNED_ROWS, *this.fetched);
96 LogContext::maybe_record_query_stats(*this.fetched, elapsed);
97 None
98 }
99 };
100 Poll::Ready(ret)
101 }
102}
103
104#[derive(Debug)]
109struct RecordingExecutor<E> {
110 inner: E,
111 span: Span,
112}
113
114impl<E> RecordingExecutor<E> {
115 fn new(inner: E, span: Span) -> Self {
116 Self { inner, span }
117 }
118}
119
120impl<'c, E> Executor<'c> for RecordingExecutor<E>
121where
122 E: Executor<'c>,
123{
124 type Database = E::Database;
125
126 fn fetch_many<'e, 'q: 'e, Q>(
127 self,
128 query: Q,
129 ) -> BoxStream<
130 'e,
131 Result<
132 Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
133 Error,
134 >,
135 >
136 where
137 'c: 'e,
138 Q: 'q + Execute<'q, E::Database>,
139 {
140 self.span.record(DB_QUERY_TEXT, query.sql());
141
142 RecordingStream {
143 inner: self.inner.fetch_many(query),
144 database: std::marker::PhantomData::<E::Database>,
145 span: self.span,
146 start: Instant::now(),
147 fetched: 0,
148 }
149 .boxed()
150 }
151
152 fn fetch_optional<'e, 'q: 'e, Q>(
153 self,
154 query: Q,
155 ) -> BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
156 where
157 'c: 'e,
158 Q: 'q + Execute<'q, E::Database>,
159 {
160 self.span.record(DB_QUERY_TEXT, query.sql());
161 let inner = self.inner.fetch_optional(query);
162 async move {
163 let start = Instant::now();
164 let result = inner.await?;
165 #[expect(clippy::bool_to_int_with_if, reason = "clearer if explicit")]
166 let fetched = if result.is_some() { 1 } else { 0 };
167 self.span.record(DB_RESPONSE_RETURNED_ROWS, fetched);
168 LogContext::maybe_record_query_stats(fetched, start.elapsed());
169
170 Ok(result)
171 }
172 .boxed()
173 }
174
175 fn prepare_with<'e, 'q: 'e>(
176 self,
177 sql: &'q str,
178 parameters: &'e [<Self::Database as Database>::TypeInfo],
179 ) -> BoxFuture<'e, Result<<Self::Database as Database>::Statement<'q>, Error>>
180 where
181 'c: 'e,
182 {
183 self.inner.prepare_with(sql, parameters)
184 }
185
186 fn describe<'e, 'q: 'e>(
187 self,
188 sql: &'q str,
189 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
190 where
191 'c: 'e,
192 {
193 self.inner.describe(sql)
194 }
195}
196
/// A `sqlx` query paired with the [`Span`] its execution is recorded on.
///
/// Built with [`ExecuteExt::traced`] or [`ExecuteExt::record`].
pub struct Traced<Q> {
    /// The wrapped query; it is run through a `RecordingExecutor`.
    query: Q,
    /// Span receiving the SQL text and the number of returned rows.
    span: Span,
}
202
// Each `sqlx` query builder type gets the default `traced`/`record` methods.
impl<DB: Database, A> ExecuteExt for Query<'_, DB, A> {}
impl<DB: Database, O, A> ExecuteExt for QueryAs<'_, DB, O, A> {}
impl<DB: Database, O, A> ExecuteExt for QueryScalar<'_, DB, O, A> {}
impl<DB: Database, F, A> ExecuteExt for Map<'_, DB, F, A> {}
210
211impl<'q, DB: Database, A> Traced<Query<'q, DB, A>>
217where
218 A: 'q + Send + IntoArguments<'q, DB>,
219{
220 pub async fn execute<'e, 'c, E>(self, executor: E) -> Result<DB::QueryResult, Error>
221 where
222 'c: 'e,
223 'q: 'e,
224 A: 'e,
225 E: Executor<'c, Database = DB>,
226 {
227 self.query
228 .execute(RecordingExecutor::new(executor, self.span))
229 .await
230 }
231
232 pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<DB::Row, Error>
233 where
234 'c: 'e,
235 'q: 'e,
236 A: 'e,
237 E: Executor<'c, Database = DB>,
238 {
239 self.query
240 .fetch_one(RecordingExecutor::new(executor, self.span))
241 .await
242 }
243
244 pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<DB::Row>, Error>
245 where
246 'c: 'e,
247 'q: 'e,
248 A: 'e,
249 E: Executor<'c, Database = DB>,
250 {
251 self.query
252 .fetch_optional(RecordingExecutor::new(executor, self.span))
253 .await
254 }
255
256 pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<DB::Row>, Error>
257 where
258 'c: 'e,
259 'q: 'e,
260 A: 'e,
261 E: Executor<'c, Database = DB>,
262 {
263 self.query
264 .fetch_all(RecordingExecutor::new(executor, self.span))
265 .await
266 }
267}
268
269impl<'q, DB: Database, F, O, A> Traced<Map<'q, DB, F, A>>
270where
271 F: FnMut(DB::Row) -> Result<O, Error> + Send,
272 O: Send + Unpin,
273 A: 'q + Send + IntoArguments<'q, DB>,
274{
275 pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<O, Error>
276 where
277 'c: 'e,
278 'q: 'e,
279 E: 'e + Executor<'c, Database = DB>,
280 F: 'e,
281 O: 'e,
282 {
283 self.query
284 .fetch_one(RecordingExecutor::new(executor, self.span))
285 .await
286 }
287
288 pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<O>, Error>
289 where
290 'c: 'e,
291 'q: 'e,
292 E: 'e + Executor<'c, Database = DB>,
293 F: 'e,
294 O: 'e,
295 {
296 self.query
297 .fetch_optional(RecordingExecutor::new(executor, self.span))
298 .await
299 }
300
301 pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<O>, Error>
302 where
303 'c: 'e,
304 'q: 'e,
305 E: 'e + Executor<'c, Database = DB>,
306 F: 'e,
307 O: 'e,
308 {
309 self.query
310 .fetch_all(RecordingExecutor::new(executor, self.span))
311 .await
312 }
313}
314
315impl<'q, DB: Database, O, A> Traced<QueryAs<'q, DB, O, A>>
316where
317 A: 'q + IntoArguments<'q, DB>,
318 O: Send + Unpin + for<'r> sqlx::FromRow<'r, DB::Row>,
319{
320 pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<O, Error>
321 where
322 'c: 'e,
323 'q: 'e,
324 O: 'e,
325 A: 'e,
326 E: 'e + Executor<'c, Database = DB>,
327 {
328 self.query
329 .fetch_one(RecordingExecutor::new(executor, self.span))
330 .await
331 }
332
333 pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<O>, Error>
334 where
335 'c: 'e,
336 'q: 'e,
337 O: 'e,
338 A: 'e,
339 E: 'e + Executor<'c, Database = DB>,
340 {
341 self.query
342 .fetch_optional(RecordingExecutor::new(executor, self.span))
343 .await
344 }
345
346 pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<O>, Error>
347 where
348 'c: 'e,
349 'q: 'e,
350 O: 'e,
351 A: 'e,
352 E: 'e + Executor<'c, Database = DB>,
353 {
354 self.query
355 .fetch_all(RecordingExecutor::new(executor, self.span))
356 .await
357 }
358}
359
360impl<'q, DB: Database, O, A> Traced<QueryScalar<'q, DB, O, A>>
361where
362 O: Send + Unpin,
363 A: 'q + IntoArguments<'q, DB>,
364 (O,): Send + Unpin + for<'r> sqlx::FromRow<'r, DB::Row>,
365{
366 pub async fn fetch_one<'e, 'c, E>(self, executor: E) -> Result<O, Error>
367 where
368 'c: 'e,
369 'q: 'e,
370 O: 'e,
371 A: 'e,
372 E: 'e + Executor<'c, Database = DB>,
373 {
374 self.query
375 .fetch_one(RecordingExecutor::new(executor, self.span))
376 .await
377 }
378
379 pub async fn fetch_optional<'e, 'c, E>(self, executor: E) -> Result<Option<O>, Error>
380 where
381 'c: 'e,
382 'q: 'e,
383 O: 'e,
384 A: 'e,
385 E: 'e + Executor<'c, Database = DB>,
386 {
387 self.query
388 .fetch_optional(RecordingExecutor::new(executor, self.span))
389 .await
390 }
391
392 pub async fn fetch_all<'e, 'c, E>(self, executor: E) -> Result<Vec<O>, Error>
393 where
394 'c: 'e,
395 'q: 'e,
396 O: 'e,
397 A: 'e,
398 E: 'e + Executor<'c, Database = DB>,
399 {
400 self.query
401 .fetch_all(RecordingExecutor::new(executor, self.span))
402 .await
403 }
404}
405
#[cfg(test)]
mod tests {
    use mas_context::LogContext;
    use sqlx::PgPool;

    use crate::tracing::ExecuteExt;

    /// Runs two traced queries inside a [`LogContext`]. It then checks that
    /// both the query count and the fetched row count end up in the context's
    /// statistics.
    #[sqlx::test]
    async fn test_db_stats_recorded(pool: PgPool) {
        let log_context = LogContext::new("test");

        log_context
            .run(|| async {
                // One row, fetched through `fetch_one`.
                let single_row = sqlx::query("SELECT 1").traced();
                single_row.fetch_one(&pool).await.unwrap();

                // Three rows, fetched through `fetch_all`.
                let three_rows = sqlx::query("SELECT 1 FROM UNNEST(ARRAY[1, 2, 3])").traced();
                three_rows.fetch_all(&pool).await.unwrap();
            })
            .await;

        let stats = log_context.stats();
        let summary = stats.to_string();
        assert_eq!(stats.db_queries, 2);
        assert_eq!(stats.db_rows_fetched, 4);
        assert!(summary.contains("queries: 2, fetched: 4"));
    }
}