mas_data_model/upstream_oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use chrono::{DateTime, Utc};
8use serde::Serialize;
9use ulid::Ulid;
10
11use super::UpstreamOAuthLink;
12use crate::InvalidTransitionError;
13
/// State of an upstream OAuth 2.0 authorization session.
///
/// A session starts out [`Pending`] (the [`Default`] state). [`complete`]
/// moves it from [`Pending`] to [`Completed`], and [`consume`] moves it from
/// [`Completed`] to [`Consumed`]. No method on this type produces
/// [`Unlinked`]; presumably that state is set elsewhere (e.g. when loading
/// from storage) — confirm.
///
/// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
/// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
/// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
/// [`Unlinked`]: UpstreamOAuthAuthorizationSessionState::Unlinked
/// [`complete`]: UpstreamOAuthAuthorizationSessionState::complete
/// [`consume`]: UpstreamOAuthAuthorizationSessionState::consume
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum UpstreamOAuthAuthorizationSessionState {
    /// The session has not been completed yet. This is the default state.
    #[default]
    Pending,
    /// The session was completed and associated with an upstream link, but
    /// has not been consumed yet.
    Completed {
        /// When the session was completed.
        completed_at: DateTime<Utc>,
        /// ID of the upstream OAuth link the session was completed with.
        link_id: Ulid,
        /// The ID token, if one was recorded.
        id_token: Option<String>,
        /// The ID token claims, if they were recorded.
        id_token_claims: Option<serde_json::Value>,
        /// Extra callback parameters, if they were recorded.
        extra_callback_parameters: Option<serde_json::Value>,
        /// Userinfo data, if it was recorded.
        userinfo: Option<serde_json::Value>,
    },
    /// The completed session has been consumed. Carries over everything from
    /// `Completed`, plus the consumption time.
    Consumed {
        /// When the session was completed.
        completed_at: DateTime<Utc>,
        /// When the session was consumed.
        consumed_at: DateTime<Utc>,
        /// ID of the upstream OAuth link the session was completed with.
        link_id: Ulid,
        /// The ID token, if one was recorded.
        id_token: Option<String>,
        /// The ID token claims, if they were recorded.
        id_token_claims: Option<serde_json::Value>,
        /// Extra callback parameters, if they were recorded.
        extra_callback_parameters: Option<serde_json::Value>,
        /// Userinfo data, if it was recorded.
        userinfo: Option<serde_json::Value>,
    },
    /// The session has been unlinked.
    ///
    /// Unlike `Completed`/`Consumed`, this variant does not retain the link
    /// ID, the extra callback parameters or the userinfo.
    Unlinked {
        /// When the session was completed.
        completed_at: DateTime<Utc>,
        /// When the session was consumed, if it was.
        consumed_at: Option<DateTime<Utc>>,
        /// When the session was unlinked.
        unlinked_at: DateTime<Utc>,
        /// The ID token, if one was recorded.
        id_token: Option<String>,
        /// The ID token claims, if they were recorded.
        id_token_claims: Option<serde_json::Value>,
    },
}
43
44impl UpstreamOAuthAuthorizationSessionState {
45    /// Mark the upstream OAuth 2.0 authorization session as completed.
46    ///
47    /// # Errors
48    ///
49    /// Returns an error if the upstream OAuth 2.0 authorization session state
50    /// is not [`Pending`].
51    ///
52    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
53    pub fn complete(
54        self,
55        completed_at: DateTime<Utc>,
56        link: &UpstreamOAuthLink,
57        id_token: Option<String>,
58        id_token_claims: Option<serde_json::Value>,
59        extra_callback_parameters: Option<serde_json::Value>,
60        userinfo: Option<serde_json::Value>,
61    ) -> Result<Self, InvalidTransitionError> {
62        match self {
63            Self::Pending => Ok(Self::Completed {
64                completed_at,
65                link_id: link.id,
66                id_token,
67                id_token_claims,
68                extra_callback_parameters,
69                userinfo,
70            }),
71            Self::Completed { .. } | Self::Consumed { .. } | Self::Unlinked { .. } => {
72                Err(InvalidTransitionError)
73            }
74        }
75    }
76
77    /// Mark the upstream OAuth 2.0 authorization session as consumed.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the upstream OAuth 2.0 authorization session state
82    /// is not [`Completed`].
83    ///
84    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
85    pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
86        match self {
87            Self::Completed {
88                completed_at,
89                link_id,
90                id_token,
91                id_token_claims,
92                extra_callback_parameters,
93                userinfo,
94            } => Ok(Self::Consumed {
95                completed_at,
96                link_id,
97                consumed_at,
98                id_token,
99                id_token_claims,
100                extra_callback_parameters,
101                userinfo,
102            }),
103            Self::Pending | Self::Consumed { .. } | Self::Unlinked { .. } => {
104                Err(InvalidTransitionError)
105            }
106        }
107    }
108
109    /// Get the link ID for the upstream OAuth 2.0 authorization session.
110    ///
111    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
112    /// [`Pending`].
113    ///
114    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
115    #[must_use]
116    pub fn link_id(&self) -> Option<Ulid> {
117        match self {
118            Self::Pending | Self::Unlinked { .. } => None,
119            Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id),
120        }
121    }
122
123    /// Get the time at which the upstream OAuth 2.0 authorization session was
124    /// completed.
125    ///
126    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
127    /// [`Pending`].
128    ///
129    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
130    #[must_use]
131    pub fn completed_at(&self) -> Option<DateTime<Utc>> {
132        match self {
133            Self::Pending => None,
134            Self::Completed { completed_at, .. }
135            | Self::Consumed { completed_at, .. }
136            | Self::Unlinked { completed_at, .. } => Some(*completed_at),
137        }
138    }
139
140    /// Get the ID token for the upstream OAuth 2.0 authorization session.
141    ///
142    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
143    /// [`Pending`].
144    ///
145    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
146    #[must_use]
147    pub fn id_token(&self) -> Option<&str> {
148        match self {
149            Self::Pending => None,
150            Self::Completed { id_token, .. }
151            | Self::Consumed { id_token, .. }
152            | Self::Unlinked { id_token, .. } => id_token.as_deref(),
153        }
154    }
155
156    /// Get the ID token claims for the upstream OAuth 2.0 authorization
157    /// session.
158    ///
159    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
160    /// not [`Pending`].
161    ///
162    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
163    #[must_use]
164    pub fn id_token_claims(&self) -> Option<&serde_json::Value> {
165        match self {
166            Self::Pending => None,
167            Self::Completed {
168                id_token_claims, ..
169            }
170            | Self::Consumed {
171                id_token_claims, ..
172            }
173            | Self::Unlinked {
174                id_token_claims, ..
175            } => id_token_claims.as_ref(),
176        }
177    }
178
179    /// Get the extra query parameters that were sent to the upstream provider.
180    ///
181    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
182    /// not [`Pending`].
183    ///
184    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
185    #[must_use]
186    pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> {
187        match self {
188            Self::Pending | Self::Unlinked { .. } => None,
189            Self::Completed {
190                extra_callback_parameters,
191                ..
192            }
193            | Self::Consumed {
194                extra_callback_parameters,
195                ..
196            } => extra_callback_parameters.as_ref(),
197        }
198    }
199
200    #[must_use]
201    pub fn userinfo(&self) -> Option<&serde_json::Value> {
202        match self {
203            Self::Pending | Self::Unlinked { .. } => None,
204            Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
205        }
206    }
207
208    /// Get the time at which the upstream OAuth 2.0 authorization session was
209    /// consumed.
210    ///
211    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
212    /// not [`Consumed`].
213    ///
214    /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
215    #[must_use]
216    pub fn consumed_at(&self) -> Option<DateTime<Utc>> {
217        match self {
218            Self::Pending | Self::Completed { .. } => None,
219            Self::Consumed { consumed_at, .. } => Some(*consumed_at),
220            Self::Unlinked { consumed_at, .. } => *consumed_at,
221        }
222    }
223
224    /// Get the time at which the upstream OAuth 2.0 authorization session was
225    /// unlinked.
226    ///
227    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
228    /// not [`Unlinked`].
229    ///
230    /// [`Unlinked`]: UpstreamOAuthAuthorizationSessionState::Unlinked
231    #[must_use]
232    pub fn unlinked_at(&self) -> Option<DateTime<Utc>> {
233        match self {
234            Self::Pending | Self::Completed { .. } | Self::Consumed { .. } => None,
235            Self::Unlinked { unlinked_at, .. } => Some(*unlinked_at),
236        }
237    }
238
239    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
240    /// [`Pending`].
241    ///
242    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
243    #[must_use]
244    pub fn is_pending(&self) -> bool {
245        matches!(self, Self::Pending)
246    }
247
248    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
249    /// [`Completed`].
250    ///
251    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
252    #[must_use]
253    pub fn is_completed(&self) -> bool {
254        matches!(self, Self::Completed { .. })
255    }
256
257    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
258    /// [`Consumed`].
259    ///
260    /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
261    #[must_use]
262    pub fn is_consumed(&self) -> bool {
263        matches!(self, Self::Consumed { .. })
264    }
265
266    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
267    /// [`Unlinked`].
268    ///
269    /// [`Unlinked`]: UpstreamOAuthAuthorizationSessionState::Unlinked
270    #[must_use]
271    pub fn is_unlinked(&self) -> bool {
272        matches!(self, Self::Unlinked { .. })
273    }
274}
275
/// An upstream OAuth 2.0 authorization session.
///
/// Dereferences to its [`UpstreamOAuthAuthorizationSessionState`], so the
/// state accessors (e.g. `completed_at`, `is_pending`) can be called directly
/// on the session.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthAuthorizationSession {
    /// ID of this session.
    pub id: Ulid,
    /// Current state of the session. Changed by `complete` / `consume`.
    pub state: UpstreamOAuthAuthorizationSessionState,
    /// ID of the upstream provider — presumably the upstream OAuth provider
    /// this session was started against; confirm.
    pub provider_id: Ulid,
    /// NOTE(review): presumably the OAuth 2.0 `state` parameter used for the
    /// authorization request — confirm.
    pub state_str: String,
    /// NOTE(review): presumably the PKCE code verifier, when PKCE is in use —
    /// confirm.
    pub code_challenge_verifier: Option<String>,
    /// NOTE(review): presumably the OpenID Connect `nonce`, when one was sent
    /// — confirm.
    pub nonce: Option<String>,
    /// Time at which the session was created.
    pub created_at: DateTime<Utc>,
}
286
287impl std::ops::Deref for UpstreamOAuthAuthorizationSession {
288    type Target = UpstreamOAuthAuthorizationSessionState;
289
290    fn deref(&self) -> &Self::Target {
291        &self.state
292    }
293}
294
295impl UpstreamOAuthAuthorizationSession {
296    /// Mark the upstream OAuth 2.0 authorization session as completed. Returns
297    /// the updated session.
298    ///
299    /// # Errors
300    ///
301    /// Returns an error if the upstream OAuth 2.0 authorization session state
302    /// is not [`Pending`].
303    ///
304    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
305    pub fn complete(
306        mut self,
307        completed_at: DateTime<Utc>,
308        link: &UpstreamOAuthLink,
309        id_token: Option<String>,
310        id_token_claims: Option<serde_json::Value>,
311        extra_callback_parameters: Option<serde_json::Value>,
312        userinfo: Option<serde_json::Value>,
313    ) -> Result<Self, InvalidTransitionError> {
314        self.state = self.state.complete(
315            completed_at,
316            link,
317            id_token,
318            id_token_claims,
319            extra_callback_parameters,
320            userinfo,
321        )?;
322        Ok(self)
323    }
324
325    /// Mark the upstream OAuth 2.0 authorization session as consumed. Returns
326    /// the updated session.
327    ///
328    /// # Errors
329    ///
330    /// Returns an error if the upstream OAuth 2.0 authorization session state
331    /// is not [`Completed`].
332    ///
333    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
334    pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
335        self.state = self.state.consume(consumed_at)?;
336        Ok(self)
337    }
338}