1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState};
10use mas_storage::{
11 Clock, Page, Pagination,
12 compat::{CompatSsoLoginFilter, CompatSsoLoginRepository},
13};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use url::Url;
20use uuid::Uuid;
21
22use crate::{
23 DatabaseError, DatabaseInconsistencyError,
24 filter::{Filter, StatementExt},
25 iden::{CompatSessions, CompatSsoLogins},
26 pagination::QueryBuilderExt,
27 tracing::ExecuteExt,
28};
29
30pub struct PgCompatSsoLoginRepository<'c> {
33 conn: &'c mut PgConnection,
34}
35
36impl<'c> PgCompatSsoLoginRepository<'c> {
37 pub fn new(conn: &'c mut PgConnection) -> Self {
40 Self { conn }
41 }
42}
43
/// One row of the `compat_sso_logins` table as read by the queries in this file.
///
/// `sqlx::FromRow` lets `list` decode dynamically built queries into it, and
/// `#[enum_def]` generates the `CompatSsoLoginLookupIden` enum that `list` uses
/// as column aliases so the selected names match these fields.
///
/// Rows are turned into [`CompatSsoLogin`] through the `TryFrom` impl, which
/// validates the URL and the combination of the nullable columns.
#[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSsoLoginLookup {
    /// Identifier of the login; converted to a `Ulid` on decode.
    compat_sso_login_id: Uuid,
    /// Login token, matched exactly by `find_by_token`.
    login_token: String,
    /// Redirect URI stored as text; parsed back into a `Url` on decode.
    redirect_uri: String,
    created_at: DateTime<Utc>,
    /// Written by `fulfill`; `NULL` while the login is pending.
    fulfilled_at: Option<DateTime<Utc>>,
    /// Written by `exchange`.
    exchanged_at: Option<DateTime<Utc>>,
    /// Compat session linked by `fulfill`; `NULL` while the login is pending.
    compat_session_id: Option<Uuid>,
}
55
56impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
57 type Error = DatabaseInconsistencyError;
58
59 fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
60 let id = res.compat_sso_login_id.into();
61 let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
62 DatabaseInconsistencyError::on("compat_sso_logins")
63 .column("redirect_uri")
64 .row(id)
65 .source(e)
66 })?;
67
68 let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) {
69 (None, None, None) => CompatSsoLoginState::Pending,
70 (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled {
71 fulfilled_at,
72 session_id: session_id.into(),
73 },
74 (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => {
75 CompatSsoLoginState::Exchanged {
76 fulfilled_at,
77 exchanged_at,
78 session_id: session_id.into(),
79 }
80 }
81 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
82 };
83
84 Ok(CompatSsoLogin {
85 id,
86 login_token: res.login_token,
87 redirect_uri,
88 created_at: res.created_at,
89 state,
90 })
91 }
92}
93
94impl Filter for CompatSsoLoginFilter<'_> {
95 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
96 sea_query::Condition::all()
97 .add_option(self.user().map(|user| {
98 Expr::exists(
99 Query::select()
100 .expr(Expr::cust("1"))
101 .from(CompatSessions::Table)
102 .and_where(
103 Expr::col((CompatSessions::Table, CompatSessions::UserId))
104 .eq(Uuid::from(user.id)),
105 )
106 .and_where(
107 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId))
108 .equals((CompatSessions::Table, CompatSessions::CompatSessionId)),
109 )
110 .take(),
111 )
112 }))
113 .add_option(self.state().map(|state| {
114 if state.is_exchanged() {
115 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
116 } else if state.is_fulfilled() {
117 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
118 .is_not_null()
119 .and(
120 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
121 .is_null(),
122 )
123 } else {
124 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
125 }
126 }))
127 }
128}
129
130#[async_trait]
131impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> {
132 type Error = DatabaseError;
133
134 #[tracing::instrument(
135 name = "db.compat_sso_login.lookup",
136 skip_all,
137 fields(
138 db.query.text,
139 compat_sso_login.id = %id,
140 ),
141 err,
142 )]
143 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
144 let res = sqlx::query_as!(
145 CompatSsoLoginLookup,
146 r#"
147 SELECT compat_sso_login_id
148 , login_token
149 , redirect_uri
150 , created_at
151 , fulfilled_at
152 , exchanged_at
153 , compat_session_id
154
155 FROM compat_sso_logins
156 WHERE compat_sso_login_id = $1
157 "#,
158 Uuid::from(id),
159 )
160 .traced()
161 .fetch_optional(&mut *self.conn)
162 .await?;
163
164 let Some(res) = res else { return Ok(None) };
165
166 Ok(Some(res.try_into()?))
167 }
168
169 #[tracing::instrument(
170 name = "db.compat_sso_login.find_for_session",
171 skip_all,
172 fields(
173 db.query.text,
174 %compat_session.id,
175 ),
176 err,
177 )]
178 async fn find_for_session(
179 &mut self,
180 compat_session: &CompatSession,
181 ) -> Result<Option<CompatSsoLogin>, Self::Error> {
182 let res = sqlx::query_as!(
183 CompatSsoLoginLookup,
184 r#"
185 SELECT compat_sso_login_id
186 , login_token
187 , redirect_uri
188 , created_at
189 , fulfilled_at
190 , exchanged_at
191 , compat_session_id
192
193 FROM compat_sso_logins
194 WHERE compat_session_id = $1
195 "#,
196 Uuid::from(compat_session.id),
197 )
198 .traced()
199 .fetch_optional(&mut *self.conn)
200 .await?;
201
202 let Some(res) = res else { return Ok(None) };
203
204 Ok(Some(res.try_into()?))
205 }
206
207 #[tracing::instrument(
208 name = "db.compat_sso_login.find_by_token",
209 skip_all,
210 fields(
211 db.query.text,
212 ),
213 err,
214 )]
215 async fn find_by_token(
216 &mut self,
217 login_token: &str,
218 ) -> Result<Option<CompatSsoLogin>, Self::Error> {
219 let res = sqlx::query_as!(
220 CompatSsoLoginLookup,
221 r#"
222 SELECT compat_sso_login_id
223 , login_token
224 , redirect_uri
225 , created_at
226 , fulfilled_at
227 , exchanged_at
228 , compat_session_id
229
230 FROM compat_sso_logins
231 WHERE login_token = $1
232 "#,
233 login_token,
234 )
235 .traced()
236 .fetch_optional(&mut *self.conn)
237 .await?;
238
239 let Some(res) = res else { return Ok(None) };
240
241 Ok(Some(res.try_into()?))
242 }
243
244 #[tracing::instrument(
245 name = "db.compat_sso_login.add",
246 skip_all,
247 fields(
248 db.query.text,
249 compat_sso_login.id,
250 compat_sso_login.redirect_uri = %redirect_uri,
251 ),
252 err,
253 )]
254 async fn add(
255 &mut self,
256 rng: &mut (dyn RngCore + Send),
257 clock: &dyn Clock,
258 login_token: String,
259 redirect_uri: Url,
260 ) -> Result<CompatSsoLogin, Self::Error> {
261 let created_at = clock.now();
262 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
263 tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
264
265 sqlx::query!(
266 r#"
267 INSERT INTO compat_sso_logins
268 (compat_sso_login_id, login_token, redirect_uri, created_at)
269 VALUES ($1, $2, $3, $4)
270 "#,
271 Uuid::from(id),
272 &login_token,
273 redirect_uri.as_str(),
274 created_at,
275 )
276 .traced()
277 .execute(&mut *self.conn)
278 .await?;
279
280 Ok(CompatSsoLogin {
281 id,
282 login_token,
283 redirect_uri,
284 created_at,
285 state: CompatSsoLoginState::default(),
286 })
287 }
288
289 #[tracing::instrument(
290 name = "db.compat_sso_login.fulfill",
291 skip_all,
292 fields(
293 db.query.text,
294 %compat_sso_login.id,
295 %compat_session.id,
296 compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
297 user.id = %compat_session.user_id,
298 ),
299 err,
300 )]
301 async fn fulfill(
302 &mut self,
303 clock: &dyn Clock,
304 compat_sso_login: CompatSsoLogin,
305 compat_session: &CompatSession,
306 ) -> Result<CompatSsoLogin, Self::Error> {
307 let fulfilled_at = clock.now();
308 let compat_sso_login = compat_sso_login
309 .fulfill(fulfilled_at, compat_session)
310 .map_err(DatabaseError::to_invalid_operation)?;
311
312 let res = sqlx::query!(
313 r#"
314 UPDATE compat_sso_logins
315 SET
316 compat_session_id = $2,
317 fulfilled_at = $3
318 WHERE
319 compat_sso_login_id = $1
320 "#,
321 Uuid::from(compat_sso_login.id),
322 Uuid::from(compat_session.id),
323 fulfilled_at,
324 )
325 .traced()
326 .execute(&mut *self.conn)
327 .await?;
328
329 DatabaseError::ensure_affected_rows(&res, 1)?;
330
331 Ok(compat_sso_login)
332 }
333
334 #[tracing::instrument(
335 name = "db.compat_sso_login.exchange",
336 skip_all,
337 fields(
338 db.query.text,
339 %compat_sso_login.id,
340 ),
341 err,
342 )]
343 async fn exchange(
344 &mut self,
345 clock: &dyn Clock,
346 compat_sso_login: CompatSsoLogin,
347 ) -> Result<CompatSsoLogin, Self::Error> {
348 let exchanged_at = clock.now();
349 let compat_sso_login = compat_sso_login
350 .exchange(exchanged_at)
351 .map_err(DatabaseError::to_invalid_operation)?;
352
353 let res = sqlx::query!(
354 r#"
355 UPDATE compat_sso_logins
356 SET
357 exchanged_at = $2
358 WHERE
359 compat_sso_login_id = $1
360 "#,
361 Uuid::from(compat_sso_login.id),
362 exchanged_at,
363 )
364 .traced()
365 .execute(&mut *self.conn)
366 .await?;
367
368 DatabaseError::ensure_affected_rows(&res, 1)?;
369
370 Ok(compat_sso_login)
371 }
372
373 #[tracing::instrument(
374 name = "db.compat_sso_login.list",
375 skip_all,
376 fields(
377 db.query.text,
378 ),
379 err
380 )]
381 async fn list(
382 &mut self,
383 filter: CompatSsoLoginFilter<'_>,
384 pagination: Pagination,
385 ) -> Result<Page<CompatSsoLogin>, Self::Error> {
386 let (sql, arguments) = Query::select()
387 .expr_as(
388 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
389 CompatSsoLoginLookupIden::CompatSsoLoginId,
390 )
391 .expr_as(
392 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
393 CompatSsoLoginLookupIden::CompatSessionId,
394 )
395 .expr_as(
396 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
397 CompatSsoLoginLookupIden::LoginToken,
398 )
399 .expr_as(
400 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
401 CompatSsoLoginLookupIden::RedirectUri,
402 )
403 .expr_as(
404 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
405 CompatSsoLoginLookupIden::CreatedAt,
406 )
407 .expr_as(
408 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
409 CompatSsoLoginLookupIden::FulfilledAt,
410 )
411 .expr_as(
412 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
413 CompatSsoLoginLookupIden::ExchangedAt,
414 )
415 .from(CompatSsoLogins::Table)
416 .apply_filter(filter)
417 .generate_pagination(
418 (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId),
419 pagination,
420 )
421 .build_sqlx(PostgresQueryBuilder);
422
423 let edges: Vec<CompatSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
424 .traced()
425 .fetch_all(&mut *self.conn)
426 .await?;
427
428 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
429
430 Ok(page)
431 }
432
433 #[tracing::instrument(
434 name = "db.compat_sso_login.count",
435 skip_all,
436 fields(
437 db.query.text,
438 ),
439 err
440 )]
441 async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error> {
442 let (sql, arguments) = Query::select()
443 .expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count())
444 .from(CompatSsoLogins::Table)
445 .apply_filter(filter)
446 .build_sqlx(PostgresQueryBuilder);
447
448 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
449 .traced()
450 .fetch_one(&mut *self.conn)
451 .await?;
452
453 count
454 .try_into()
455 .map_err(DatabaseError::to_invalid_operation)
456 }
457}