mas_storage_pg/compat/sso_login.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState};
10use mas_storage::{
11    Clock, Page, Pagination,
12    compat::{CompatSsoLoginFilter, CompatSsoLoginRepository},
13};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use url::Url;
20use uuid::Uuid;
21
22use crate::{
23    DatabaseError, DatabaseInconsistencyError,
24    filter::{Filter, StatementExt},
25    iden::{CompatSessions, CompatSsoLogins},
26    pagination::QueryBuilderExt,
27    tracing::ExecuteExt,
28};
29
30/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL
31/// connection
32pub struct PgCompatSsoLoginRepository<'c> {
33    conn: &'c mut PgConnection,
34}
35
36impl<'c> PgCompatSsoLoginRepository<'c> {
37    /// Create a new [`PgCompatSsoLoginRepository`] from an active PostgreSQL
38    /// connection
39    pub fn new(conn: &'c mut PgConnection) -> Self {
40        Self { conn }
41    }
42}
43
44#[derive(sqlx::FromRow)]
45#[enum_def]
46struct CompatSsoLoginLookup {
47    compat_sso_login_id: Uuid,
48    login_token: String,
49    redirect_uri: String,
50    created_at: DateTime<Utc>,
51    fulfilled_at: Option<DateTime<Utc>>,
52    exchanged_at: Option<DateTime<Utc>>,
53    compat_session_id: Option<Uuid>,
54}
55
56impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
57    type Error = DatabaseInconsistencyError;
58
59    fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
60        let id = res.compat_sso_login_id.into();
61        let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
62            DatabaseInconsistencyError::on("compat_sso_logins")
63                .column("redirect_uri")
64                .row(id)
65                .source(e)
66        })?;
67
68        let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
69            (None, None, None) => CompatSsoLoginState::Pending,
70            (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
71                fulfilled_at,
72                session_id: session_id.into(),
73            },
74            (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
75                CompatSsoLoginState::Exchanged {
76                    fulfilled_at,
77                    exchanged_at,
78                    session_id: session_id.into(),
79                }
80            }
81            _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
82        };
83
84        Ok(CompatSsoLogin {
85            id,
86            login_token: res.login_token,
87            redirect_uri,
88            created_at: res.created_at,
89            state,
90        })
91    }
92}
93
94impl Filter for CompatSsoLoginFilter<'_> {
95    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
96        sea_query::Condition::all()
97            .add_option(self.user().map(|user| {
98                Expr::exists(
99                    Query::select()
100                        .expr(Expr::cust("1"))
101                        .from(CompatSessions::Table)
102                        .and_where(
103                            Expr::col((CompatSessions::Table, CompatSessions::UserId))
104                                .eq(Uuid::from(user.id)),
105                        )
106                        .and_where(
107                            Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId))
108                                .equals((CompatSessions::Table, CompatSessions::CompatSessionId)),
109                        )
110                        .take(),
111                )
112            }))
113            .add_option(self.state().map(|state| {
114                if state.is_exchanged() {
115                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
116                } else if state.is_fulfilled() {
117                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
118                        .is_not_null()
119                        .and(
120                            Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
121                                .is_null(),
122                        )
123                } else {
124                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
125                }
126            }))
127    }
128}
129
130#[async_trait]
131impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> {
132    type Error = DatabaseError;
133
134    #[tracing::instrument(
135        name = "db.compat_sso_login.lookup",
136        skip_all,
137        fields(
138            db.query.text,
139            compat_sso_login.id = %id,
140        ),
141        err,
142    )]
143    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
144        let res = sqlx::query_as!(
145            CompatSsoLoginLookup,
146            r#"
147                SELECT compat_sso_login_id
148                     , login_token
149                     , redirect_uri
150                     , created_at
151                     , fulfilled_at
152                     , exchanged_at
153                     , compat_session_id
154
155                FROM compat_sso_logins
156                WHERE compat_sso_login_id = $1
157            "#,
158            Uuid::from(id),
159        )
160        .traced()
161        .fetch_optional(&mut *self.conn)
162        .await?;
163
164        let Some(res) = res else { return Ok(None) };
165
166        Ok(Some(res.try_into()?))
167    }
168
169    #[tracing::instrument(
170        name = "db.compat_sso_login.find_for_session",
171        skip_all,
172        fields(
173            db.query.text,
174            %compat_session.id,
175        ),
176        err,
177    )]
178    async fn find_for_session(
179        &mut self,
180        compat_session: &CompatSession,
181    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
182        let res = sqlx::query_as!(
183            CompatSsoLoginLookup,
184            r#"
185                SELECT compat_sso_login_id
186                     , login_token
187                     , redirect_uri
188                     , created_at
189                     , fulfilled_at
190                     , exchanged_at
191                     , compat_session_id
192
193                FROM compat_sso_logins
194                WHERE compat_session_id = $1
195            "#,
196            Uuid::from(compat_session.id),
197        )
198        .traced()
199        .fetch_optional(&mut *self.conn)
200        .await?;
201
202        let Some(res) = res else { return Ok(None) };
203
204        Ok(Some(res.try_into()?))
205    }
206
207    #[tracing::instrument(
208        name = "db.compat_sso_login.find_by_token",
209        skip_all,
210        fields(
211            db.query.text,
212        ),
213        err,
214    )]
215    async fn find_by_token(
216        &mut self,
217        login_token: &str,
218    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
219        let res = sqlx::query_as!(
220            CompatSsoLoginLookup,
221            r#"
222                SELECT compat_sso_login_id
223                     , login_token
224                     , redirect_uri
225                     , created_at
226                     , fulfilled_at
227                     , exchanged_at
228                     , compat_session_id
229
230                FROM compat_sso_logins
231                WHERE login_token = $1
232            "#,
233            login_token,
234        )
235        .traced()
236        .fetch_optional(&mut *self.conn)
237        .await?;
238
239        let Some(res) = res else { return Ok(None) };
240
241        Ok(Some(res.try_into()?))
242    }
243
244    #[tracing::instrument(
245        name = "db.compat_sso_login.add",
246        skip_all,
247        fields(
248            db.query.text,
249            compat_sso_login.id,
250            compat_sso_login.redirect_uri = %redirect_uri,
251        ),
252        err,
253    )]
254    async fn add(
255        &mut self,
256        rng: &mut (dyn RngCore + Send),
257        clock: &dyn Clock,
258        login_token: String,
259        redirect_uri: Url,
260    ) -> Result<CompatSsoLogin, Self::Error> {
261        let created_at = clock.now();
262        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
263        tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
264
265        sqlx::query!(
266            r#"
267                INSERT INTO compat_sso_logins
268                    (compat_sso_login_id, login_token, redirect_uri, created_at)
269                VALUES ($1, $2, $3, $4)
270            "#,
271            Uuid::from(id),
272            &login_token,
273            redirect_uri.as_str(),
274            created_at,
275        )
276        .traced()
277        .execute(&mut *self.conn)
278        .await?;
279
280        Ok(CompatSsoLogin {
281            id,
282            login_token,
283            redirect_uri,
284            created_at,
285            state: CompatSsoLoginState::default(),
286        })
287    }
288
289    #[tracing::instrument(
290        name = "db.compat_sso_login.fulfill",
291        skip_all,
292        fields(
293            db.query.text,
294            %compat_sso_login.id,
295            %compat_session.id,
296            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
297            user.id = %compat_session.user_id,
298        ),
299        err,
300    )]
301    async fn fulfill(
302        &mut self,
303        clock: &dyn Clock,
304        compat_sso_login: CompatSsoLogin,
305        compat_session: &CompatSession,
306    ) -> Result<CompatSsoLogin, Self::Error> {
307        let fulfilled_at = clock.now();
308        let compat_sso_login = compat_sso_login
309            .fulfill(fulfilled_at, compat_session)
310            .map_err(DatabaseError::to_invalid_operation)?;
311
312        let res = sqlx::query!(
313            r#"
314                UPDATE compat_sso_logins
315                SET
316                    compat_session_id = $2,
317                    fulfilled_at = $3
318                WHERE
319                    compat_sso_login_id = $1
320            "#,
321            Uuid::from(compat_sso_login.id),
322            Uuid::from(compat_session.id),
323            fulfilled_at,
324        )
325        .traced()
326        .execute(&mut *self.conn)
327        .await?;
328
329        DatabaseError::ensure_affected_rows(&res, 1)?;
330
331        Ok(compat_sso_login)
332    }
333
334    #[tracing::instrument(
335        name = "db.compat_sso_login.exchange",
336        skip_all,
337        fields(
338            db.query.text,
339            %compat_sso_login.id,
340        ),
341        err,
342    )]
343    async fn exchange(
344        &mut self,
345        clock: &dyn Clock,
346        compat_sso_login: CompatSsoLogin,
347    ) -> Result<CompatSsoLogin, Self::Error> {
348        let exchanged_at = clock.now();
349        let compat_sso_login = compat_sso_login
350            .exchange(exchanged_at)
351            .map_err(DatabaseError::to_invalid_operation)?;
352
353        let res = sqlx::query!(
354            r#"
355                UPDATE compat_sso_logins
356                SET
357                    exchanged_at = $2
358                WHERE
359                    compat_sso_login_id = $1
360            "#,
361            Uuid::from(compat_sso_login.id),
362            exchanged_at,
363        )
364        .traced()
365        .execute(&mut *self.conn)
366        .await?;
367
368        DatabaseError::ensure_affected_rows(&res, 1)?;
369
370        Ok(compat_sso_login)
371    }
372
373    #[tracing::instrument(
374        name = "db.compat_sso_login.list",
375        skip_all,
376        fields(
377            db.query.text,
378        ),
379        err
380    )]
381    async fn list(
382        &mut self,
383        filter: CompatSsoLoginFilter<'_>,
384        pagination: Pagination,
385    ) -> Result<Page<CompatSsoLogin>, Self::Error> {
386        let (sql, arguments) = Query::select()
387            .expr_as(
388                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
389                CompatSsoLoginLookupIden::CompatSsoLoginId,
390            )
391            .expr_as(
392                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
393                CompatSsoLoginLookupIden::CompatSessionId,
394            )
395            .expr_as(
396                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
397                CompatSsoLoginLookupIden::LoginToken,
398            )
399            .expr_as(
400                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
401                CompatSsoLoginLookupIden::RedirectUri,
402            )
403            .expr_as(
404                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
405                CompatSsoLoginLookupIden::CreatedAt,
406            )
407            .expr_as(
408                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
409                CompatSsoLoginLookupIden::FulfilledAt,
410            )
411            .expr_as(
412                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
413                CompatSsoLoginLookupIden::ExchangedAt,
414            )
415            .from(CompatSsoLogins::Table)
416            .apply_filter(filter)
417            .generate_pagination(
418                (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId),
419                pagination,
420            )
421            .build_sqlx(PostgresQueryBuilder);
422
423        let edges: Vec<CompatSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
424            .traced()
425            .fetch_all(&mut *self.conn)
426            .await?;
427
428        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
429
430        Ok(page)
431    }
432
433    #[tracing::instrument(
434        name = "db.compat_sso_login.count",
435        skip_all,
436        fields(
437            db.query.text,
438        ),
439        err
440    )]
441    async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error> {
442        let (sql, arguments) = Query::select()
443            .expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count())
444            .from(CompatSsoLogins::Table)
445            .apply_filter(filter)
446            .build_sqlx(PostgresQueryBuilder);
447
448        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
449            .traced()
450            .fetch_one(&mut *self.conn)
451            .await?;
452
453        count
454            .try_into()
455            .map_err(DatabaseError::to_invalid_operation)
456    }
457}