mas_handlers/activity_tracker/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7mod bound;
8mod worker;
9
10use std::net::IpAddr;
11
12use chrono::{DateTime, Utc};
13use mas_data_model::{BrowserSession, Clock, CompatSession, Session};
14use mas_storage::BoxRepositoryFactory;
15use tokio_util::{sync::CancellationToken, task::TaskTracker};
16use ulid::Ulid;
17
18pub use self::bound::Bound;
19use self::worker::Worker;
20
/// Capacity of the bounded channel between [`ActivityTracker`] handles and the
/// background worker. Once this many messages are queued, senders wait for the
/// worker to catch up.
const MESSAGE_QUEUE_SIZE: usize = 1000;
22
/// The kind of session an activity record refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
enum SessionKind {
    /// Activity on an OAuth 2.0 `Session`
    OAuth2,
    /// Activity on a `CompatSession`
    Compat,
    /// Session associated with personal access tokens
    Personal,
    /// Activity on a `BrowserSession`
    Browser,
}

impl SessionKind {
    /// Lowercase label identifying this kind of session.
    const fn as_str(self) -> &'static str {
        match self {
            Self::OAuth2 => "oauth2",
            Self::Compat => "compat",
            Self::Personal => "personal",
            Self::Browser => "browser",
        }
    }
}
42
43enum Message {
44    Record {
45        kind: SessionKind,
46        id: Ulid,
47        date_time: DateTime<Utc>,
48        ip: Option<IpAddr>,
49    },
50    Flush(tokio::sync::oneshot::Sender<()>),
51}
52
53#[derive(Clone)]
54pub struct ActivityTracker {
55    channel: tokio::sync::mpsc::Sender<Message>,
56}
57
58impl ActivityTracker {
59    /// Create a new activity tracker
60    ///
61    /// It will spawn the background worker and a loop to flush the tracker on
62    /// the task tracker, and both will shut themselves down, flushing one last
63    /// time, when the cancellation token is cancelled.
64    #[must_use]
65    pub fn new(
66        repository_factory: BoxRepositoryFactory,
67        flush_interval: std::time::Duration,
68        task_tracker: &TaskTracker,
69        cancellation_token: CancellationToken,
70    ) -> Self {
71        let worker = Worker::new(repository_factory);
72        let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
73        let tracker = ActivityTracker { channel: sender };
74
75        // Spawn the flush loop and the worker
76        task_tracker.spawn(
77            tracker
78                .clone()
79                .flush_loop(flush_interval, cancellation_token.clone()),
80        );
81        task_tracker.spawn(worker.run(receiver, cancellation_token));
82
83        tracker
84    }
85
86    /// Bind the activity tracker to an IP address.
87    #[must_use]
88    pub fn bind(self, ip: Option<IpAddr>) -> Bound {
89        Bound::new(self, ip)
90    }
91
92    /// Record activity in an OAuth 2.0 session.
93    pub async fn record_oauth2_session(
94        &self,
95        clock: &dyn Clock,
96        session: &Session,
97        ip: Option<IpAddr>,
98    ) {
99        let res = self
100            .channel
101            .send(Message::Record {
102                kind: SessionKind::OAuth2,
103                id: session.id,
104                date_time: clock.now(),
105                ip,
106            })
107            .await;
108
109        if let Err(e) = res {
110            tracing::error!("Failed to record OAuth2 session: {}", e);
111        }
112    }
113
114    /// Record activity in a personal access token session.
115    pub async fn record_personal_access_token_session(
116        &self,
117        clock: &dyn Clock,
118        session: &Session,
119        ip: Option<IpAddr>,
120    ) {
121        let res = self
122            .channel
123            .send(Message::Record {
124                kind: SessionKind::Personal,
125                id: session.id,
126                date_time: clock.now(),
127                ip,
128            })
129            .await;
130
131        if let Err(e) = res {
132            tracing::error!("Failed to record Personal session: {}", e);
133        }
134    }
135
136    /// Record activity in a compat session.
137    pub async fn record_compat_session(
138        &self,
139        clock: &dyn Clock,
140        compat_session: &CompatSession,
141        ip: Option<IpAddr>,
142    ) {
143        let res = self
144            .channel
145            .send(Message::Record {
146                kind: SessionKind::Compat,
147                id: compat_session.id,
148                date_time: clock.now(),
149                ip,
150            })
151            .await;
152
153        if let Err(e) = res {
154            tracing::error!("Failed to record compat session: {}", e);
155        }
156    }
157
158    /// Record activity in a browser session.
159    pub async fn record_browser_session(
160        &self,
161        clock: &dyn Clock,
162        browser_session: &BrowserSession,
163        ip: Option<IpAddr>,
164    ) {
165        let res = self
166            .channel
167            .send(Message::Record {
168                kind: SessionKind::Browser,
169                id: browser_session.id,
170                date_time: clock.now(),
171                ip,
172            })
173            .await;
174
175        if let Err(e) = res {
176            tracing::error!("Failed to record browser session: {}", e);
177        }
178    }
179
180    /// Manually flush the activity tracker.
181    pub async fn flush(&self) {
182        let (tx, rx) = tokio::sync::oneshot::channel();
183        let res = self.channel.send(Message::Flush(tx)).await;
184
185        match res {
186            Ok(()) => {
187                if let Err(e) = rx.await {
188                    tracing::error!(
189                        error = &e as &dyn std::error::Error,
190                        "Failed to flush activity tracker"
191                    );
192                }
193            }
194            Err(e) => {
195                tracing::error!(
196                    error = &e as &dyn std::error::Error,
197                    "Failed to flush activity tracker"
198                );
199            }
200        }
201    }
202
203    /// Regularly flush the activity tracker.
204    async fn flush_loop(
205        self,
206        interval: std::time::Duration,
207        cancellation_token: CancellationToken,
208    ) {
209        // This guard on the shutdown token is to ensure that if this task crashes for
210        // any reason, the server will shut down
211        let _guard = cancellation_token.clone().drop_guard();
212        let mut interval = tokio::time::interval(interval);
213        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
214
215        loop {
216            tokio::select! {
217                biased;
218
219                () = cancellation_token.cancelled() => {
220                    // The cancellation token was cancelled, so we should exit
221                    return;
222                }
223
224                // First check if the channel is closed, then check if the timer expired
225                () = self.channel.closed() => {
226                    // The channel was closed, so we should exit
227                    return;
228                }
229
230
231                _ = interval.tick() => {
232                    self.flush().await;
233                }
234            }
235        }
236    }
237}