module Gamgee.Effects.TOTP
    ( -- * Effects
      TOTP (..)

      -- * Actions
    , getSecret
    , getTOTP

      -- * Interpretations
    , runTOTP
    ) where

import qualified Crypto.Hash.Algorithms     as HashAlgos
import qualified Crypto.OTP                 as OTP
import qualified Data.ByteArray.Encoding    as Encoding
import qualified Data.Time.Clock.POSIX      as Clock
import           Gamgee.Effects.Crypto      (Crypto)
import qualified Gamgee.Effects.Crypto      as Crypto
import qualified Gamgee.Effects.Error       as Err
import           Gamgee.Effects.SecretInput (SecretInput)
import qualified Gamgee.Token               as Token
import           Polysemy                   (Member, Members, Sem)
import qualified Polysemy                   as P
import qualified Polysemy.Error             as P
import           Relude
import qualified Text.Printf                as Printf


----------------------------------------------------------------------------------------------------
-- Effects
----------------------------------------------------------------------------------------------------

-- | Effect for working with time-based one-time-password tokens. Both
-- actions take the 'Token.TokenSpec' describing the token and yield 'Text'.
data TOTP m a where
  -- | Fetch the secret associated with the given token specification.
  GetSecret :: Token.TokenSpec -> TOTP m Text
  -- | Compute the one-time password of the given token for the supplied
  -- POSIX time.
  GetTOTP   :: Token.TokenSpec -> Clock.POSIXTime -> TOTP m Text

-- Template Haskell: generates the 'getSecret' and 'getTOTP' actions exported
-- by this module from the constructors above.
P.makeSem ''TOTP


----------------------------------------------------------------------------------------------------
-- Interpret TOTP
----------------------------------------------------------------------------------------------------

-- | Interpret the 'TOTP' effect in terms of 'SecretInput', 'Crypto' and
-- 'P.Error'.
--
-- * 'GetSecret' returns the decrypted secret text of the token.
-- * 'GetTOTP' decodes the secret into raw key bytes and computes the
--   one-time password for the supplied time.
--
-- Both actions go through 'retrieveKeyAndSecret', so either one fails with an
-- 'Err.EffError' when the decrypted secret is not valid Base32.
runTOTP :: Members [SecretInput Text, Crypto, P.Error Err.EffError] r => Sem (TOTP : r) a -> Sem r a
runTOTP = P.interpret $ \case
  GetSecret spec    -> snd <$> retrieveKeyAndSecret spec
  GetTOTP spec time -> retrieveKeyAndSecret spec >>= computeTOTP spec time . fst

-- | Decrypt the secret of a token and decode it from Base32.
--
-- Returns the decoded key bytes paired with the decrypted secret text (the
-- text is returned as decrypted, i.e. still Base32-encoded).
--
-- Throws 'Err.SecretDecryptError', carrying the decoder's message, when the
-- decrypted text is not valid Base32.
retrieveKeyAndSecret :: Members [SecretInput Text, Crypto, P.Error Err.EffError] r
                     => Token.TokenSpec
                     -> Sem r (ByteString, Text)
retrieveKeyAndSecret spec = do
  secret <- Crypto.decryptSecret spec
  -- The annotation fixes the input byte-array type for 'convertFromBase'.
  case Encoding.convertFromBase Encoding.Base32 (encodeUtf8 secret :: ByteString) of
    Left msg  -> P.throw $ Err.SecretDecryptError $ toText msg
    Right key -> return (key, secret)

-- | Compute the one-time password of a token at a given time.
--
-- The hash algorithm, period and digit count are read from the
-- 'Token.TokenSpec'; @key@ is the raw (already Base32-decoded) shared secret.
-- The result is the code rendered as decimal text, zero-padded to the
-- configured number of digits.
--
-- Throws 'Err.InvalidTokenPeriod' when 'OTP.mkTOTPParams' rejects the
-- parameters.
computeTOTP :: Member (P.Error Err.EffError) r => Token.TokenSpec -> Clock.POSIXTime -> ByteString -> Sem r Text
computeTOTP spec time key =
  -- Each branch instantiates the helpers at a different hash type, hence the
  -- repetition instead of selecting the algorithm as a value.
  case Token.tokenAlgorithm spec of
    Token.AlgorithmSHA1   -> makeOTP <$> makeParams HashAlgos.SHA1
    Token.AlgorithmSHA256 -> makeOTP <$> makeParams HashAlgos.SHA256
    Token.AlgorithmSHA512 -> makeOTP <$> makeParams HashAlgos.SHA512

  where
    period :: Token.TokenPeriod
    period = Token.tokenPeriod spec

    -- Translate the token's digit setting into the OTP library's type.
    digits :: OTP.OTPDigits
    digits = case Token.tokenDigits spec of
               Token.Digits6 -> OTP.OTP6
               Token.Digits8 -> OTP.OTP8

    -- Build TOTP parameters for the given hash algorithm with start time 0
    -- and no clock skew. Any rejection by 'OTP.mkTOTPParams' is reported as
    -- 'Err.InvalidTokenPeriod'; the library's own message is discarded.
    makeParams :: (Member (P.Error Err.EffError) r, HashAlgos.HashAlgorithm h) => h -> Sem r (OTP.TOTPParams h)
    makeParams alg = either (const (P.throw $ Err.InvalidTokenPeriod period))
                            return $
                            OTP.mkTOTPParams alg 0 (Token.unTokenPeriod period) digits OTP.NoSkew

    -- Compute the code for @time@ (floored to whole seconds) and render it.
    makeOTP :: (HashAlgos.HashAlgorithm h) => OTP.TOTPParams h -> Text
    makeOTP p = format $ OTP.totp p key $ floor time

    -- Render the numeric code as text, reduced modulo 10^digits and
    -- left-padded with zeros to the configured width.
    format :: OTP.OTP -> Text
    format otp =
      let (base, size) = case Token.tokenDigits spec of
                           Token.Digits6 -> (1000000, "6")
                           Token.Digits8 -> (100000000, "8")
      in fromString $ Printf.printf ("%0" ++ size ++ "d") (otp `mod` base)