mas_storage_pg/queue/
worker.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
// Copyright 2024 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

//! A module containing the PostgreSQL implementation of the
//! [`QueueWorkerRepository`].

use async_trait::async_trait;
use chrono::Duration;
use mas_storage::{
    queue::{QueueWorkerRepository, Worker},
    Clock,
};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;

use crate::{DatabaseError, ExecuteExt};

/// PostgreSQL-backed [`QueueWorkerRepository`].
///
/// Holds a mutable borrow of a [`PgConnection`]; every repository method
/// executes its queries on that borrowed connection.
pub struct PgQueueWorkerRepository<'conn> {
    /// Connection all queue-worker queries are executed on.
    conn: &'conn mut PgConnection,
}

impl<'conn> PgQueueWorkerRepository<'conn> {
    /// Build a [`PgQueueWorkerRepository`] that runs its queries on the given
    /// PostgreSQL connection.
    #[must_use]
    pub fn new(conn: &'conn mut PgConnection) -> Self {
        PgQueueWorkerRepository { conn }
    }
}

#[async_trait]
impl QueueWorkerRepository for PgQueueWorkerRepository<'_> {
    type Error = DatabaseError;

    #[tracing::instrument(
        name = "db.queue_worker.register",
        skip_all,
        fields(
            worker.id,
            db.query.text,
        ),
        err,
    )]
    async fn register(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
    ) -> Result<Worker, Self::Error> {
        let now = clock.now();
        let worker_id = Ulid::from_datetime_with_source(now.into(), rng);
        tracing::Span::current().record("worker.id", tracing::field::display(worker_id));

        sqlx::query!(
            r#"
                INSERT INTO queue_workers (queue_worker_id, registered_at, last_seen_at)
                VALUES ($1, $2, $2)
            "#,
            Uuid::from(worker_id),
            now,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(Worker { id: worker_id })
    }

    #[tracing::instrument(
        name = "db.queue_worker.heartbeat",
        skip_all,
        fields(
            %worker.id,
            db.query.text,
        ),
        err,
    )]
    async fn heartbeat(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
        let now = clock.now();
        let res = sqlx::query!(
            r#"
                UPDATE queue_workers
                SET last_seen_at = $2
                WHERE queue_worker_id = $1 AND shutdown_at IS NULL
            "#,
            Uuid::from(worker.id),
            now,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // If no row was updated, the worker was shutdown so we return an error
        DatabaseError::ensure_affected_rows(&res, 1)?;

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_worker.shutdown",
        skip_all,
        fields(
            %worker.id,
            db.query.text,
        ),
        err,
    )]
    async fn shutdown(&mut self, clock: &dyn Clock, worker: &Worker) -> Result<(), Self::Error> {
        let now = clock.now();
        let res = sqlx::query!(
            r#"
                UPDATE queue_workers
                SET shutdown_at = $2
                WHERE queue_worker_id = $1
            "#,
            Uuid::from(worker.id),
            now,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        // Remove the leader lease if we were holding it
        let res = sqlx::query!(
            r#"
                DELETE FROM queue_leader
                WHERE queue_worker_id = $1
            "#,
            Uuid::from(worker.id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // If we were holding the leader lease, notify workers
        if res.rows_affected() > 0 {
            sqlx::query!(
                r#"
                    NOTIFY queue_leader_stepdown
                "#,
            )
            .traced()
            .execute(&mut *self.conn)
            .await?;
        }

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_worker.shutdown_dead_workers",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn shutdown_dead_workers(
        &mut self,
        clock: &dyn Clock,
        threshold: Duration,
    ) -> Result<(), Self::Error> {
        // Here the threshold is usually set to a few minutes, so we don't need to use
        // the database time, as we can assume worker clocks have less than a minute
        // skew between each other, else other things would break
        let now = clock.now();
        sqlx::query!(
            r#"
                UPDATE queue_workers
                SET shutdown_at = $1
                WHERE shutdown_at IS NULL
                  AND last_seen_at < $2
            "#,
            now,
            now - threshold,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_worker.remove_leader_lease_if_expired",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn remove_leader_lease_if_expired(
        &mut self,
        _clock: &dyn Clock,
    ) -> Result<(), Self::Error> {
        // `expires_at` is a rare exception where we use the database time, as this
        // would be very sensitive to clock skew between workers
        sqlx::query!(
            r#"
                DELETE FROM queue_leader
                WHERE expires_at < NOW()
            "#,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(())
    }

    #[tracing::instrument(
        name = "db.queue_worker.try_get_leader_lease",
        skip_all,
        fields(
            %worker.id,
            db.query.text,
        ),
        err,
    )]
    async fn try_get_leader_lease(
        &mut self,
        clock: &dyn Clock,
        worker: &Worker,
    ) -> Result<bool, Self::Error> {
        let now = clock.now();
        // The queue_leader table is meant to only have a single row, which conflicts on
        // the `active` column

        // If there is a conflict, we update the `expires_at` column ONLY IF the current
        // leader is ourselves.

        // `expires_at` is a rare exception where we use the database time, as this
        // would be very sensitive to clock skew between workers
        let res = sqlx::query!(
            r#"
                INSERT INTO queue_leader (elected_at, expires_at, queue_worker_id)
                VALUES ($1, NOW() + INTERVAL '5 seconds', $2)
                ON CONFLICT (active)
                DO UPDATE SET expires_at = EXCLUDED.expires_at
                WHERE queue_leader.queue_worker_id = $2
            "#,
            now,
            Uuid::from(worker.id)
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // We can then detect whether we are the leader or not by checking how many rows
        // were affected by the upsert
        let am_i_the_leader = res.rows_affected() == 1;

        Ok(am_i_the_leader)
    }
}