mas_storage/oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Device, Session, User};
12use oauth2_types::scope::Scope;
13use rand_core::RngCore;
14use ulid::Ulid;
15
16use crate::{Clock, Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter};
17
/// The lifecycle state of an OAuth 2.0 session, used to filter sessions
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OAuth2SessionState {
    /// The session has not been finished yet
    Active,

    /// The session has been marked as finished
    Finished,
}

impl OAuth2SessionState {
    /// Returns `true` if this is [`OAuth2SessionState::Active`]
    #[must_use]
    pub fn is_active(self) -> bool {
        matches!(self, Self::Active)
    }

    /// Returns `true` if this is [`OAuth2SessionState::Finished`]
    #[must_use]
    pub fn is_finished(self) -> bool {
        matches!(self, Self::Finished)
    }
}
33
/// The kind of OAuth 2.0 client a session belongs to, used to filter sessions
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClientKind {
    /// A statically-defined client (presumably declared in configuration)
    Static,

    /// A client that is not statically defined (presumably registered
    /// dynamically)
    Dynamic,
}

impl ClientKind {
    /// Returns `true` if this is [`ClientKind::Static`]
    #[must_use]
    pub fn is_static(self) -> bool {
        matches!(self, Self::Static)
    }

    /// Returns `true` if this is [`ClientKind::Dynamic`]
    #[must_use]
    pub fn is_dynamic(self) -> bool {
        matches!(self, Self::Dynamic)
    }
}
45
46/// Filter parameters for listing OAuth 2.0 sessions
47#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
48pub struct OAuth2SessionFilter<'a> {
49    user: Option<&'a User>,
50    any_user: Option<bool>,
51    browser_session: Option<&'a BrowserSession>,
52    browser_session_filter: Option<BrowserSessionFilter<'a>>,
53    device: Option<&'a Device>,
54    client: Option<&'a Client>,
55    client_kind: Option<ClientKind>,
56    state: Option<OAuth2SessionState>,
57    scope: Option<&'a Scope>,
58    last_active_before: Option<DateTime<Utc>>,
59    last_active_after: Option<DateTime<Utc>>,
60}
61
62impl<'a> OAuth2SessionFilter<'a> {
63    /// Create a new [`OAuth2SessionFilter`] with default values
64    #[must_use]
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    /// List sessions for a specific user
70    #[must_use]
71    pub fn for_user(mut self, user: &'a User) -> Self {
72        self.user = Some(user);
73        self
74    }
75
76    /// Get the user filter
77    ///
78    /// Returns [`None`] if no user filter was set
79    #[must_use]
80    pub fn user(&self) -> Option<&'a User> {
81        self.user
82    }
83
84    /// List sessions which belong to any user
85    #[must_use]
86    pub fn for_any_user(mut self) -> Self {
87        self.any_user = Some(true);
88        self
89    }
90
91    /// List sessions which belong to no user
92    #[must_use]
93    pub fn for_no_user(mut self) -> Self {
94        self.any_user = Some(false);
95        self
96    }
97
98    /// Get the 'any user' filter
99    ///
100    /// Returns [`None`] if no 'any user' filter was set
101    #[must_use]
102    pub fn any_user(&self) -> Option<bool> {
103        self.any_user
104    }
105
106    /// List sessions started by a specific browser session
107    #[must_use]
108    pub fn for_browser_session(mut self, browser_session: &'a BrowserSession) -> Self {
109        self.browser_session = Some(browser_session);
110        self
111    }
112
113    /// List sessions started by a set of browser sessions
114    #[must_use]
115    pub fn for_browser_sessions(
116        mut self,
117        browser_session_filter: BrowserSessionFilter<'a>,
118    ) -> Self {
119        self.browser_session_filter = Some(browser_session_filter);
120        self
121    }
122
123    /// Get the browser session filter
124    ///
125    /// Returns [`None`] if no browser session filter was set
126    #[must_use]
127    pub fn browser_session(&self) -> Option<&'a BrowserSession> {
128        self.browser_session
129    }
130
131    /// Get the browser sessions filter
132    ///
133    /// Returns [`None`] if no browser session filter was set
134    #[must_use]
135    pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
136        self.browser_session_filter
137    }
138
139    /// List sessions for a specific client
140    #[must_use]
141    pub fn for_client(mut self, client: &'a Client) -> Self {
142        self.client = Some(client);
143        self
144    }
145
146    /// Get the client filter
147    ///
148    /// Returns [`None`] if no client filter was set
149    #[must_use]
150    pub fn client(&self) -> Option<&'a Client> {
151        self.client
152    }
153
154    /// List only static clients
155    #[must_use]
156    pub fn only_static_clients(mut self) -> Self {
157        self.client_kind = Some(ClientKind::Static);
158        self
159    }
160
161    /// List only dynamic clients
162    #[must_use]
163    pub fn only_dynamic_clients(mut self) -> Self {
164        self.client_kind = Some(ClientKind::Dynamic);
165        self
166    }
167
168    /// Get the client kind filter
169    ///
170    /// Returns [`None`] if no client kind filter was set
171    #[must_use]
172    pub fn client_kind(&self) -> Option<ClientKind> {
173        self.client_kind
174    }
175
176    /// Only return sessions with a last active time before the given time
177    #[must_use]
178    pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
179        self.last_active_before = Some(last_active_before);
180        self
181    }
182
183    /// Only return sessions with a last active time after the given time
184    #[must_use]
185    pub fn with_last_active_after(mut self, last_active_after: DateTime<Utc>) -> Self {
186        self.last_active_after = Some(last_active_after);
187        self
188    }
189
190    /// Get the last active before filter
191    ///
192    /// Returns [`None`] if no client filter was set
193    #[must_use]
194    pub fn last_active_before(&self) -> Option<DateTime<Utc>> {
195        self.last_active_before
196    }
197
198    /// Get the last active after filter
199    ///
200    /// Returns [`None`] if no client filter was set
201    #[must_use]
202    pub fn last_active_after(&self) -> Option<DateTime<Utc>> {
203        self.last_active_after
204    }
205
206    /// Only return active sessions
207    #[must_use]
208    pub fn active_only(mut self) -> Self {
209        self.state = Some(OAuth2SessionState::Active);
210        self
211    }
212
213    /// Only return finished sessions
214    #[must_use]
215    pub fn finished_only(mut self) -> Self {
216        self.state = Some(OAuth2SessionState::Finished);
217        self
218    }
219
220    /// Get the state filter
221    ///
222    /// Returns [`None`] if no state filter was set
223    #[must_use]
224    pub fn state(&self) -> Option<OAuth2SessionState> {
225        self.state
226    }
227
228    /// Only return sessions with the given scope
229    #[must_use]
230    pub fn with_scope(mut self, scope: &'a Scope) -> Self {
231        self.scope = Some(scope);
232        self
233    }
234
235    /// Get the scope filter
236    ///
237    /// Returns [`None`] if no scope filter was set
238    #[must_use]
239    pub fn scope(&self) -> Option<&'a Scope> {
240        self.scope
241    }
242
243    /// Only return sessions that have the given device in their scope
244    #[must_use]
245    pub fn for_device(mut self, device: &'a Device) -> Self {
246        self.device = Some(device);
247        self
248    }
249
250    /// Get the device filter
251    ///
252    /// Returns [`None`] if no device filter was set
253    #[must_use]
254    pub fn device(&self) -> Option<&'a Device> {
255        self.device
256    }
257}
258
259/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
260/// saved in the storage backend
261#[async_trait]
262pub trait OAuth2SessionRepository: Send + Sync {
263    /// The error type returned by the repository
264    type Error;
265
266    /// Lookup an [`Session`] by its ID
267    ///
268    /// Returns `None` if no [`Session`] was found
269    ///
270    /// # Parameters
271    ///
272    /// * `id`: The ID of the [`Session`] to lookup
273    ///
274    /// # Errors
275    ///
276    /// Returns [`Self::Error`] if the underlying repository fails
277    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
278
279    /// Create a new [`Session`] with the given parameters
280    ///
281    /// Returns the newly created [`Session`]
282    ///
283    /// # Parameters
284    ///
285    /// * `rng`: The random number generator to use
286    /// * `clock`: The clock used to generate timestamps
287    /// * `client`: The [`Client`] which created the [`Session`]
288    /// * `user`: The [`User`] for which the session should be created, if any
289    /// * `user_session`: The [`BrowserSession`] of the user which completed the
290    ///   authorization, if any
291    /// * `scope`: The [`Scope`] of the [`Session`]
292    ///
293    /// # Errors
294    ///
295    /// Returns [`Self::Error`] if the underlying repository fails
296    async fn add(
297        &mut self,
298        rng: &mut (dyn RngCore + Send),
299        clock: &dyn Clock,
300        client: &Client,
301        user: Option<&User>,
302        user_session: Option<&BrowserSession>,
303        scope: Scope,
304    ) -> Result<Session, Self::Error>;
305
306    /// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
307    ///
308    /// Returns the newly created [`Session`]
309    ///
310    /// # Parameters
311    ///
312    /// * `rng`: The random number generator to use
313    /// * `clock`: The clock used to generate timestamps
314    /// * `client`: The [`Client`] which created the [`Session`]
315    /// * `user_session`: The [`BrowserSession`] of the user which completed the
316    ///   authorization
317    /// * `scope`: The [`Scope`] of the [`Session`]
318    ///
319    /// # Errors
320    ///
321    /// Returns [`Self::Error`] if the underlying repository fails
322    async fn add_from_browser_session(
323        &mut self,
324        rng: &mut (dyn RngCore + Send),
325        clock: &dyn Clock,
326        client: &Client,
327        user_session: &BrowserSession,
328        scope: Scope,
329    ) -> Result<Session, Self::Error> {
330        self.add(
331            rng,
332            clock,
333            client,
334            Some(&user_session.user),
335            Some(user_session),
336            scope,
337        )
338        .await
339    }
340
341    /// Create a new [`Session`] for a [`Client`] using the client credentials
342    /// flow
343    ///
344    /// Returns the newly created [`Session`]
345    ///
346    /// # Parameters
347    ///
348    /// * `rng`: The random number generator to use
349    /// * `clock`: The clock used to generate timestamps
350    /// * `client`: The [`Client`] which created the [`Session`]
351    /// * `scope`: The [`Scope`] of the [`Session`]
352    ///
353    /// # Errors
354    ///
355    /// Returns [`Self::Error`] if the underlying repository fails
356    async fn add_from_client_credentials(
357        &mut self,
358        rng: &mut (dyn RngCore + Send),
359        clock: &dyn Clock,
360        client: &Client,
361        scope: Scope,
362    ) -> Result<Session, Self::Error> {
363        self.add(rng, clock, client, None, None, scope).await
364    }
365
366    /// Mark a [`Session`] as finished
367    ///
368    /// Returns the updated [`Session`]
369    ///
370    /// # Parameters
371    ///
372    /// * `clock`: The clock used to generate timestamps
373    /// * `session`: The [`Session`] to mark as finished
374    ///
375    /// # Errors
376    ///
377    /// Returns [`Self::Error`] if the underlying repository fails
378    async fn finish(&mut self, clock: &dyn Clock, session: Session)
379    -> Result<Session, Self::Error>;
380
381    /// Mark all the [`Session`] matching the given filter as finished
382    ///
383    /// Returns the number of sessions affected
384    ///
385    /// # Parameters
386    ///
387    /// * `clock`: The clock used to generate timestamps
388    /// * `filter`: The filter parameters
389    ///
390    /// # Errors
391    ///
392    /// Returns [`Self::Error`] if the underlying repository fails
393    async fn finish_bulk(
394        &mut self,
395        clock: &dyn Clock,
396        filter: OAuth2SessionFilter<'_>,
397    ) -> Result<usize, Self::Error>;
398
399    /// List [`Session`]s matching the given filter and pagination parameters
400    ///
401    /// # Parameters
402    ///
403    /// * `filter`: The filter parameters
404    /// * `pagination`: The pagination parameters
405    ///
406    /// # Errors
407    ///
408    /// Returns [`Self::Error`] if the underlying repository fails
409    async fn list(
410        &mut self,
411        filter: OAuth2SessionFilter<'_>,
412        pagination: Pagination,
413    ) -> Result<Page<Session>, Self::Error>;
414
415    /// Count [`Session`]s matching the given filter
416    ///
417    /// # Parameters
418    ///
419    /// * `filter`: The filter parameters
420    ///
421    /// # Errors
422    ///
423    /// Returns [`Self::Error`] if the underlying repository fails
424    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
425
426    /// Record a batch of [`Session`] activity
427    ///
428    /// # Parameters
429    ///
430    /// * `activity`: A list of tuples containing the session ID, the last
431    ///   activity timestamp and the IP address of the client
432    ///
433    /// # Errors
434    ///
435    /// Returns [`Self::Error`] if the underlying repository fails
436    async fn record_batch_activity(
437        &mut self,
438        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
439    ) -> Result<(), Self::Error>;
440
441    /// Record the user agent of a [`Session`]
442    ///
443    /// # Parameters
444    ///
445    /// * `session`: The [`Session`] to record the user agent for
446    /// * `user_agent`: The user agent to record
447    async fn record_user_agent(
448        &mut self,
449        session: Session,
450        user_agent: String,
451    ) -> Result<Session, Self::Error>;
452
453    /// Set the human name of a [`Session`]
454    ///
455    /// # Parameters
456    ///
457    /// * `session`: The [`Session`] to set the human name for
458    /// * `human_name`: The human name to set
459    async fn set_human_name(
460        &mut self,
461        session: Session,
462        human_name: Option<String>,
463    ) -> Result<Session, Self::Error>;
464}
465
// Invoke the crate's `repository_impl!` macro (imported from the crate root)
// for `OAuth2SessionRepository`.
// NOTE(review): presumably the macro generates forwarding impls of this trait
// for wrapper types — confirm against its definition.
// The signatures listed here mirror the trait declaration one-for-one,
// including the methods with default bodies (`add_from_browser_session`,
// `add_from_client_credentials`). Keep them in sync whenever the trait
// changes.
repository_impl!(OAuth2SessionRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user: Option<&User>,
        user_session: Option<&BrowserSession>,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_browser_session(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user_session: &BrowserSession,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_client_credentials(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn finish(&mut self, clock: &dyn Clock, session: Session)
        -> Result<Session, Self::Error>;

    async fn finish_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: OAuth2SessionFilter<'_>,
    ) -> Result<usize, Self::Error>;

    async fn list(
        &mut self,
        filter: OAuth2SessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<Session>, Self::Error>;

    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn record_batch_activity(
        &mut self,
        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error>;

    async fn record_user_agent(
        &mut self,
        session: Session,
        user_agent: String,
    ) -> Result<Session, Self::Error>;

    async fn set_human_name(
        &mut self,
        session: Session,
        human_name: Option<String>,
    ) -> Result<Session, Self::Error>;
);