{-# options_haddock prune #-}
{-# options_ghc -fno-warn-orphans #-}

-- | Interpreters for 'Jwt'
module Polysemy.Account.Api.Interpreter.Jwt where

import Conc (interpretAtomic)
import qualified Crypto.JOSE as JOSE
import Crypto.JOSE (JWK, KeyMaterialGenParam (OKPGenParam), OKPCrv (Ed25519), genJWK)
import Polysemy.Db (DbError, InitDbError)
import Polysemy.Hasql (Database, interpretAtomicStateDb, interpretTable)
import Servant.Auth.JWT (ToJWT)
import Servant.Auth.Server (FromJWT, JWTSettings, defaultJWTSettings, makeJWT)
import Sqel (tableName)
import Sqel.Data.TableSchema (TableSchema)
import Sqel.Names (named)
import Sqel.PgType (tableSchema)
import qualified Sqel.Prim as Sqel

import Polysemy.Account.Api.Effect.Jwt (GenJwk (GenJwk), Jwt (..), genJwk)
import Polysemy.Account.Data.AuthToken (AuthToken (AuthToken))
import Polysemy.Account.Data.AuthedAccount (AuthedAccount)

-- Orphan instances: both bodies are empty, so the classes' default method implementations are used.
-- Presumably those defaults go through the JSON instances of 'AuthedAccount', which is why the id and payload
-- types carry 'FromJSON' / 'ToJSON' constraints.
instance (FromJSON i, FromJSON p) => FromJWT (AuthedAccount i p)

instance (ToJSON i, ToJSON p) => ToJWT (AuthedAccount i p)

-- | Generate a new 'Ed25519' key by running 'genJWK' in 'IO'.
generateKey ::
  Member (Embed IO) r =>
  Sem r JWK
generateKey =
  embed $ genJWK $ OKPGenParam Ed25519

-- | Generate a new key with 'generateKey', store it in the 'AtomicState' and return it.
--
-- Any key already present in the state is overwritten by the 'atomicPut'.
generateAndStoreKey ::
  Members [AtomicState (Maybe JWK), Embed IO] r =>
  Sem r JWK
generateAndStoreKey = do
  -- Reuse the shared generator instead of duplicating the 'genJWK' call, so both code paths always
  -- produce the same kind of key.
  k <- generateKey
  k <$ atomicPut (Just k)

-- | Interpret 'GenJwk' using 'Ed25519'.
--
-- Every 'GenJwk' request is answered by calling 'generateKey'; nothing is cached.
interpretGenJwk ::
  Member (Embed IO) r =>
  InterpreterFor GenJwk r
interpretGenJwk =
  interpret \case
    GenJwk -> generateKey

-- | Return the key held in the 'AtomicState'.
--
-- If the state is still 'Nothing', a key is created and stored with 'generateAndStoreKey'.
--
-- NOTE(review): the read ('atomicGet') and the write (inside 'generateAndStoreKey') are separate operations.
-- Two callers that both observe 'Nothing' will each generate a key, and the last write wins.
key ::
  Members [AtomicState (Maybe JWK), Embed IO] r =>
  Sem r JWK
key =
  atomicGet >>= \case
    Just existing -> pure existing
    Nothing -> generateAndStoreKey

-- | Build 'JWTSettings' with 'defaultJWTSettings' from the key returned by 'key'.
--
-- A key is generated and stored first if none is present.
settings ::
  Members [AtomicState (Maybe JWK), Embed IO] r =>
  Sem r JWTSettings
settings = do
  signingKey <- key
  pure (defaultJWTSettings signingKey)

-- | Turn the result of 'makeJWT' into an 'AuthToken'.
--
-- On success, the bytes are decoded as UTF-8 into the token text. On failure, the 'JOSE.Error' is rendered
-- with 'show' and thrown as an @'Error' 'Text'@.
authToken ::
  Member (Error Text) r =>
  Either JOSE.Error LByteString ->
  Sem r AuthToken
authToken =
  either (throw . show) (pure . AuthToken . decodeUtf8)

-- | Interpret 'Jwt' by storing the key in 'AtomicState', generating it on the fly if absent.
--
-- Generates 'Ed25519' keys.
--
-- Errors originating from the token generator are critical: they are thrown as @'Error' 'Text'@ by 'authToken'.
--
-- 'Key' returns the same stored key that 'Settings' and 'MakeToken' use (obtained through 'key'), matching the
-- behaviour of 'interpretJwtPersistent'.
interpretJwtState ::
  Members [GenJwk, AtomicState (Maybe JWK), Error Text, Embed IO] r =>
  ToJWT a =>
  InterpreterFor (Jwt a) r
interpretJwtState =
  interpret \case
    Key ->
      -- Return the stored signing key rather than calling 'genJwk'.
      -- 'genJwk' would yield an unrelated fresh key that is never stored and cannot verify tokens
      -- produced by 'MakeToken'.
      key
    Settings ->
      settings
    MakeToken a -> do
      sett <- settings
      authToken =<< embed (makeJWT a sett Nothing)

-- | Interpret 'Jwt' by storing the key in 'AtomicState' in memory.
--
-- The interpreter stack has three layers:
--
-- * 'interpretJwtState' on top;
-- * 'interpretGenJwk' for key generation;
-- * an in-memory 'AtomicState' that starts out as 'Nothing'.
--
-- The key is therefore created on first use and kept for the scope of this interpreter.
interpretJwt ::
  forall a r .
  Members [Error Text, Embed IO] r =>
  ToJWT a =>
  InterpreterFor (Jwt a) r
interpretJwt =
  interpretAtomic Nothing
  . interpretGenJwk
  . interpretJwtState
  . raiseUnder2

-- | Build 'JWTSettings' with 'defaultJWTSettings' from the key held in a non-optional 'AtomicState'.
settingsPersistent ::
  Member (AtomicState JWK) r =>
  Sem r JWTSettings
settingsPersistent =
  fmap defaultJWTSettings atomicGet

-- | Interpret 'Jwt' by storing the key in 'AtomicState', requiring the key to be present from the start.
--
-- This is intended to be used with a database backing the 'AtomicState', with the key generated when the app
-- starts. No key is generated here. Whatever interprets the 'AtomicState' is responsible for providing one;
-- 'interpretJwtDb' uses 'Ed25519' keys from 'interpretGenJwk'.
--
-- Errors of the underlying 'AtomicState' are passed on to the caller as @e@ via 'restop'.
-- Errors originating from the token generator are critical: they are thrown as @'Error' 'Text'@ by 'authToken'.
interpretJwtPersistent ::
  forall a e r .
  Members [AtomicState JWK !! e, Error Text, Embed IO] r =>
  ToJWT a =>
  InterpreterFor (Jwt a !! e) r
interpretJwtPersistent =
  interpretResumable \case
    Key ->
      restop atomicGet
    Settings ->
      restop settingsPersistent
    MakeToken payload ->
      restop settingsPersistent >>= \ sett ->
        authToken =<< embed (makeJWT payload sett Nothing)

-- | Interpret 'Jwt' using 'interpretJwtPersistent' and interpret 'AtomicState' as a PostgreSQL table using
-- @polysemy-hasql@, generating the JWK when it is not found in the database.
--
-- The key is stored JSON-encoded in a column named @payload@ of the table @jwk@. New keys are requested through
-- 'genJwk' and produced by 'interpretGenJwk'.
interpretJwtDb ::
  forall a r .
  Members [Database !! DbError, Error InitDbError, Error Text, Log, Mask, Resource, Race, Embed IO] r =>
  ToJWT a =>
  InterpreterFor (Jwt a !! DbError) r
interpretJwtDb =
  interpretGenJwk
  . interpretTable jwkTable
  . interpretAtomicStateDb jwkTable genJwk
  . interpretJwtPersistent
  . insertAt @1
  where
    -- Schema of the single-row key table.
    jwkTable :: TableSchema JWK
    jwkTable =
      tableSchema (tableName "jwk" (named @"payload" Sqel.json))