module Crypto.MAC.TOTP.Factory
( Factory (..)
, initialize
, initializeIO
, initGrace
, epochEq
, authenticate
, authenticateBS
, roundTime
, setTime
, validUntil
, shouldRefresh
, refresh
, refreshIO
, tryRefreshEvery
, startRefreshThread
, getNext
, getMessages
, getNextIO
, getMessagesIO
) where

import Crypto.Hash hiding (hmac)
import Crypto.MAC.HMAC
import qualified Data.ByteString as BS
import Data.ByteString (ByteString)
import Data.Int
import Data.Word
import Data.Bits
import System.Posix.Time
import System.Posix.Types (EpochTime)
import Foreign.C.Types (CTime (..))
import Control.Concurrent
import Data.IORef
import Data.Serialize

-- | Orphan 'Integral' instance for 'CTime', so that time values can be
-- divided and rounded directly (as 'roundTime' does).
--
-- Every method unwraps both operands, applies the same operation of the
-- wrapped integer type, and rewraps the result.
--
-- NOTE(review): this is an orphan instance; it will clash if the package
-- defining 'CTime' ever ships its own 'Integral' instance.
instance Integral CTime where
  toInteger (CTime t) = toInteger t

  CTime n `quot` CTime d = CTime (n `quot` d)
  CTime n `rem` CTime d = CTime (n `rem` d)
  CTime n `div` CTime d = CTime (n `div` d)
  CTime n `mod` CTime d = CTime (n `mod` d)

  quotRem (CTime n) (CTime d) =
    case n `quotRem` d of
      (q, r) -> (CTime q, CTime r)
  divMod (CTime n) (CTime d) =
    case n `divMod` d of
      (q, r) -> (CTime q, CTime r)

-- | State for producing time-windowed, counter-based message tokens.
--
-- A factory carries a per-window 'secret' derived from 'secretInit' and the
-- start of the current window ('refreshEpoch'), plus a message counter that
-- 'refresh' resets to 0 whenever it re-keys the factory.
data Factory = Factory
  { secret :: ByteString
    -- ^ HMAC key for the current window; 'initialize' leaves it empty and
    -- 'setTime' fills it in.
  , secretInit :: ByteString
    -- ^ Secret from which each window's 'secret' is derived.
  , count :: !Int64
    -- ^ Counter value authenticated by the next token. Strict so that
    -- repeated increments do not build a chain of unevaluated additions.
  , validSeconds :: CTime
    -- ^ Length of one window, in seconds.
  , refreshEpoch :: EpochTime
    -- ^ Start of the current window (a multiple of 'validSeconds').
  , hashMethod :: ByteString -> ByteString
    -- ^ Hash function used for the HMAC and for deriving 'secret'.
  , blockSize :: Int
    -- ^ Block size handed to 'hmac' together with 'hashMethod'.
  , prefix :: ByteString -> ByteString
    -- ^ Applied to the full HMAC output to produce a token
    -- ('initialize' sets it to @BS.take tokenBytes@).
  }

-- | Create a factory that is not yet bound to any time window.
--
-- The result has an empty 'secret', a 'refreshEpoch' of 0 and a counter of
-- 0; it is meant to be passed through 'refresh' (as 'initializeIO' does)
-- before it is used to produce tokens.
--
-- Calls 'error' when @validSeconds@ is less than 1.
initialize :: (ByteString -> ByteString) -- ^ hash function
           -> Int                        -- ^ block size handed to 'hmac'
           -> Int                        -- ^ leading HMAC bytes kept per token
           -> ByteString                 -- ^ secret each window key is derived from
           -> CTime                      -- ^ window length in seconds
           -> Factory
initialize hashMethod blockSize tokenBytes secretInit validSeconds
  | validSeconds < 1 = error "validSeconds must be >= 1"
  | otherwise =
      Factory
        { secret = BS.empty
        , secretInit = secretInit
        , count = 0
        , validSeconds = validSeconds
        , refreshEpoch = 0
        , hashMethod = hashMethod
        , blockSize = blockSize
        , prefix = BS.take tokenBytes
        }

-- | Like 'initialize', but immediately runs the new factory through
-- 'refresh' using the current system time from 'epochTime'.
initializeIO :: (ByteString -> ByteString) -> Int -> Int -> ByteString -> CTime -> IO Factory
initializeIO hashMethod blockSize tokenBytes secretInit validSeconds =
  (`refresh` unkeyed) <$> epochTime
  where
    unkeyed = initialize hashMethod blockSize tokenBytes secretInit validSeconds

-- | Build the factory for the window lying @graceSeconds@ windows before the
-- window of the given factory, with its counter reset to 0.
--
-- Despite its name, @graceSeconds@ counts whole windows: it is multiplied by
-- 'validSeconds' before being subtracted from 'refreshEpoch'.
--
-- The target window is keyed with 'setTime' directly rather than via
-- 'refresh' on a fresh factory. That route only re-keys when the target
-- time is at least 'validSeconds'. Otherwise it would hand back a factory
-- whose HMAC key is the empty string, for example when the base factory
-- was never refreshed.
initGrace :: Factory -> CTime -> Factory
initGrace f graceSeconds = (setTime time f) { count = 0 }
  where
    time = refreshEpoch f - graceSeconds * validSeconds f

-- | @epochEq baseF n f@ holds when @f@'s window starts exactly @n@ windows
-- (measured in @baseF@'s window length) before @baseF@'s window.
epochEq :: Factory -> CTime -> Factory -> Bool
epochEq baseF n f = refreshEpoch f == target
  where
    target = refreshEpoch baseF - n * validSeconds baseF

-- | Advance the message counter by one.
incr :: Factory -> Factory
incr f@(Factory {count = n}) = f {count = 1 + n}

-- | Token for a value, serialized with its 'Serialize' instance
-- (see 'authenticateBS').
authenticate :: Serialize b => Factory -> b -> ByteString
authenticate factory message = authenticateBS factory encode message

-- | Token for a value, serialized with a caller-supplied encoder.
--
-- The encoded message is HMAC'd under the current window's 'secret' using
-- the factory's 'hashMethod' and 'blockSize', and the result is passed
-- through the factory's 'prefix'.
authenticateBS :: Factory -> (b -> ByteString) -> b -> ByteString
authenticateBS factory encodeFun message = prefix factory mac
  where
    mac = hmac (hashMethod factory) (blockSize factory) (secret factory) payload
    payload = encodeFun message

-- | Token authenticating the factory's current counter value.
hashCount :: Factory -> ByteString
hashCount f = let n = count f in authenticate f n

-- | Round time @t@ to a multiple of @r@ using floor division; for a
-- positive @r@ this rounds down, including for negative @t@.
-- A zero @r@ makes the division fail.
roundTime :: CTime -> CTime -> CTime
roundTime t r = r * windows
  where
    windows = t `div` r

-- | Key the factory to the window containing time @t@.
--
-- 'refreshEpoch' becomes @t@ rounded via 'roundTime' to a multiple of
-- 'validSeconds'. 'secret' becomes 'hashMethod' applied to 'secretInit'
-- followed by the serialized window start. The counter is left untouched.
setTime :: CTime -> Factory -> Factory
setTime t f = f { refreshEpoch = windowStart, secret = windowSecret }
  where
    windowStart@(CTime rawStart) = roundTime t (validSeconds f)
    windowSecret = hashMethod f (secretInit f `BS.append` encode rawStart)

-- | End of the current window: the first time at which 'shouldRefresh'
-- reports that the factory needs re-keying.
validUntil :: Factory -> EpochTime
validUntil f = validSeconds f + refreshEpoch f

-- | Whether time @t@ lies at or past the end of the factory's window.
shouldRefresh :: Factory -> EpochTime -> Bool
shouldRefresh f t = validUntil f <= t

-- | Re-key the factory for @time@ and reset its counter when its window
-- has expired; otherwise return it unchanged.
refresh :: EpochTime -> Factory -> Factory
refresh time factory
  | expired = rekeyed { count = 0 }
  | otherwise = factory
  where
    expired = shouldRefresh factory time
    rekeyed = setTime time factory

-- | 'refresh' the factory against the current system time ('epochTime').
refreshIO :: Factory -> IO Factory
refreshIO factory = fmap (`refresh` factory) epochTime

-- | Loop forever: sleep for @delay@, then atomically 'refresh' the shared
-- factory against the current system time.
--
-- 'refresh' already returns the factory unchanged while its window is still
-- valid, so no separate expiry check is needed here. The strict
-- 'atomicModifyIORef'' evaluates each refreshed factory before storing it.
-- The lazy variant would instead stack one unevaluated 'refresh' per tick
-- in the reference until a reader forced it.
tryRefreshEvery :: Int -- ^ The delay according to Control.Concurrent.threadDelay before refresh attempts.
                -> IORef Factory -- ^ The current factory.
                -> IO ()
tryRefreshEvery delay factoryRef = loop
  where
    loop = do
      threadDelay delay
      time <- epochTime
      atomicModifyIORef' factoryRef (\f -> (refresh time f, ()))
      loop

-- | Refresh the factory once against the current time, store it in a new
-- 'IORef', and fork a thread running 'tryRefreshEvery' with @delay@ on
-- that reference.
--
-- Returns the forked thread's id together with the shared reference.
startRefreshThread :: Int -> Factory -> IO (ThreadId, IORef Factory)
startRefreshThread delay factory = do
  factoryRef <- newIORef . (`refresh` factory) =<< epochTime
  tid <- forkIO (tryRefreshEvery delay factoryRef)
  pure (tid, factoryRef)

-- | Token for the current counter value, paired with the factory whose
-- counter has been advanced by one.
getNext :: Factory -> (Factory, ByteString)
getNext f = (advanced, token)
  where
    advanced = incr f
    token = hashCount f

-- | Produce @n@ consecutive tokens by threading the factory through
-- 'getNext', returning the advanced factory along with them.
--
-- The list is ordered newest first: its head authenticates the highest
-- counter value. A non-positive @n@ yields no tokens and leaves the factory
-- unchanged. Matching only on 0 would make a negative @n@ recurse forever.
getMessages :: Int -> Factory -> (Factory, [ByteString])
getMessages n f
  | n <= 0 = (f, [])
  | otherwise =
      let (f', keys) = getMessages (n - 1) f
          (f'', key) = getNext f'
      in (f'', key : keys)

-- | Atomically take the next token from a shared factory (see 'getNext').
--
-- Uses the strict 'atomicModifyIORef''. The advanced factory is evaluated
-- to weak head normal form before it is stored, and the token is evaluated
-- to weak head normal form before it is returned. Unused tokens therefore
-- no longer leave unevaluated 'getNext' applications behind in the
-- reference.
getNextIO :: IORef Factory -> IO ByteString
getNextIO factoryRef =
    atomicModifyIORef' factoryRef getNext

-- | Atomically take @n@ tokens from a shared factory (see 'getMessages');
-- returns an empty list without touching the reference when @n@ < 1.
--
-- Uses the strict 'atomicModifyIORef'' so that the advanced factory is
-- evaluated to weak head normal form before it is stored, rather than
-- leaving an unevaluated 'getMessages' application in the reference.
getMessagesIO :: IORef Factory -> Int -> IO [ByteString]
getMessagesIO _ n | n < 1 = return []
getMessagesIO factoryRef n =
    atomicModifyIORef' factoryRef (getMessages n)