-- | Time-windowed message authentication tokens.
--
-- A 'Factory' carries a secret that 'setTime' derives from an initial secret
-- and the start of the current validity window. Tokens are produced by
-- HMAC-ing a serialized value with that secret ('authenticate'); 'getNext'
-- does so for an incrementing counter. The @...IO@ functions operate on a
-- 'Factory' stored in an @IORef@, and 'startRefreshThread' forks a thread
-- that re-derives the secret whenever the validity window has elapsed.
module Crypto.MAC.TOTP.Factory
( Factory (..)
, initialize
, initializeIO
, initGrace
, epochEq
, authenticate
, authenticateBS
, roundTime
, setTime
, validUntil
, shouldRefresh
, refresh
, refreshIO
, tryRefreshEvery
, startRefreshThread
, getNext
, getMessages
, getNextIO
, getMessagesIO
) where
import Crypto.Hash hiding (hmac)
import Crypto.MAC.HMAC
import qualified Data.ByteString as BS
import Data.ByteString (ByteString)
import Data.Int
import Data.Word
import Data.Bits
import System.Posix.Time
import System.Posix.Types (EpochTime)
import Foreign.C.Types (CTime (..))
import Control.Concurrent
import Data.IORef
import Data.Serialize
-- | Orphan 'Integral' instance for 'CTime'.
--
-- Every method unwraps both operands, delegates to the same method on the
-- wrapped integer and re-wraps the result. It exists so that epoch times can
-- be divided directly (see 'roundTime').
instance Integral CTime where
  toInteger (CTime secs) = toInteger secs

  quotRem (CTime n) (CTime d) =
    case n `quotRem` d of
      (q, r) -> (CTime q, CTime r)

  divMod (CTime n) (CTime d) =
    case n `divMod` d of
      (q, r) -> (CTime q, CTime r)

  quot (CTime n) (CTime d) = CTime (quot n d)
  rem (CTime n) (CTime d) = CTime (rem n d)
  div (CTime n) (CTime d) = CTime (div n d)
  mod (CTime n) (CTime d) = CTime (mod n d)
-- | Everything needed to issue time-windowed authentication tokens.
data Factory = Factory
  { secret :: ByteString
    -- ^ HMAC key currently in use; 'setTime' derives it from 'secretInit'
    -- and the window start.
  , secretInit :: ByteString
    -- ^ Initial secret from which every window's 'secret' is derived.
  , count :: Int64
    -- ^ Counter hashed by 'getNext'; reset to 0 on every 'refresh'.
  , validSeconds :: CTime
    -- ^ Length of one validity window, in seconds.
  , refreshEpoch :: EpochTime
    -- ^ Start of the current validity window (a multiple of 'validSeconds').
  , hashMethod :: ByteString -> ByteString
    -- ^ Hash function, used both for key derivation and inside the HMAC.
  , blockSize :: Int
    -- ^ Block size handed to @hmac@ together with 'hashMethod'.
  , prefix :: ByteString -> ByteString
    -- ^ Post-processing applied to every HMAC output ('initialize' sets it
    -- to @BS.take tokenBytes@).
  }
-- | Build a 'Factory' that has not been keyed yet.
--
-- Arguments, in order: the hash function, the block size passed to @hmac@,
-- the number of leading bytes kept from every HMAC output, the initial
-- secret, and the validity window length in seconds.
--
-- The result has an empty 'secret' and a 'refreshEpoch' of 0; apply
-- 'refresh' or 'setTime' before issuing tokens. Calls 'error' when the
-- window length is smaller than 1.
initialize :: (ByteString -> ByteString) -> Int -> Int -> ByteString -> CTime -> Factory
initialize hashFn blockLen tokenBytes seed window
  | window < 1 = error "validSeconds must be >= 1"
  | otherwise =
      Factory
        { secret = BS.empty
        , secretInit = seed
        , count = 0
        , validSeconds = window
        , refreshEpoch = 0
        , hashMethod = hashFn
        , blockSize = blockLen
        , prefix = BS.take tokenBytes
        }
-- | Like 'initialize', but immediately applies 'refresh' with the current
-- epoch time so the returned 'Factory' is keyed for the present window.
initializeIO :: (ByteString -> ByteString) -> Int -> Int -> ByteString -> CTime -> IO (Factory)
initializeIO hashFn blockLen tokenBytes seed window =
  fmap keyFor epochTime
  where
    blank = initialize hashFn blockLen tokenBytes seed window
    keyFor now = refresh now blank
initGrace :: Factory -> CTime -> Factory
initGrace (Factory _ secretInit _ validSeconds refreshEpoch hashMethod blockSize prefix) graceSeconds =
let time = refreshEpoch graceSeconds * validSeconds in
refresh time $ Factory { secret = BS.empty
, secretInit
, count = 0
, validSeconds
, refreshEpoch = 0
, hashMethod
, blockSize
, prefix
}
-- | @epochEq baseF n f@ is 'True' when the validity window of @f@ starts
-- exactly @n@ windows (of @baseF@'s length) before the window of @baseF@.
epochEq :: Factory -> CTime -> Factory -> Bool
epochEq baseF n f =
  refreshEpoch f == refreshEpoch baseF - n * validSeconds baseF
-- | Advance the token counter by one, leaving every other field untouched.
incr :: Factory -> Factory
incr f@(Factory { count = issued }) = f { count = issued + 1 }
-- | Compute the authentication token for any 'Serialize'-able message,
-- using 'encode' as the serializer for 'authenticateBS'.
authenticate :: Serialize b => Factory -> b -> ByteString
authenticate factory message = authenticateBS factory encode message
-- | Compute the authentication token for a message, given an explicit
-- function that turns the message into bytes.
--
-- The bytes are HMAC-ed with the factory's hash function, block size and
-- current 'secret'; the factory's 'prefix' function is then applied to the
-- HMAC output.
authenticateBS :: Factory -> (b -> ByteString) -> b -> ByteString
authenticateBS factory encodeFun message =
  finish fullTag
  where
    finish = prefix factory
    payload = encodeFun message
    fullTag = hmac (hashMethod factory) (blockSize factory) (secret factory) payload
-- | Token for the factory's current counter value.
hashCount :: Factory -> ByteString
hashCount f = f `authenticate` count f
-- | Round a time down to a multiple of the given period.
--
-- Uses 'div', so the result is floored for negative times as well.
roundTime :: CTime -> CTime -> CTime
roundTime t r = wholePeriods * r
  where
    wholePeriods = t `div` r
-- | Key the factory for the validity window containing the given time.
--
-- The time is rounded down to a multiple of 'validSeconds'; that window
-- start becomes 'refreshEpoch', and 'secret' becomes the hash of the initial
-- secret followed by the serialized window start. 'count' is not touched.
setTime :: CTime -> Factory -> Factory
setTime t f =
  f { refreshEpoch = windowStart, secret = derivedKey }
  where
    windowStart@(CTime rawStart) = roundTime t (validSeconds f)
    derivedKey = hashMethod f (secretInit f `BS.append` encode rawStart)
-- | First instant at which the current window is no longer valid: the
-- window start plus the window length.
validUntil :: Factory -> EpochTime
validUntil (Factory { refreshEpoch = start, validSeconds = window }) = start + window
-- | 'True' once the given time has reached or passed 'validUntil'.
shouldRefresh :: Factory -> EpochTime -> Bool
shouldRefresh f t = validUntil f <= t
-- | Re-key the factory for the window containing the given time if its
-- current window has expired, resetting 'count' to 0; otherwise return the
-- factory unchanged.
refresh :: EpochTime -> Factory -> Factory
refresh time factory
  | shouldRefresh factory time = rekeyed { count = 0 }
  | otherwise = factory
  where
    rekeyed = setTime time factory
-- | 'refresh' against the current epoch time.
refreshIO :: Factory -> IO (Factory)
refreshIO factory = fmap (`refresh` factory) epochTime
-- | Loop forever: sleep, then 'refresh' the shared factory against the
-- current epoch time.
--
-- @delay@ is handed to 'threadDelay', i.e. it is in microseconds. This
-- function never returns; run it in its own thread (see
-- 'startRefreshThread').
tryRefreshEvery :: Int
                -> IORef (Factory)
                -> IO ()
tryRefreshEvery delay factoryRef = do
  threadDelay delay
  time <- epochTime
  -- 'refresh' already leaves a still-valid factory untouched, so no extra
  -- expiry check is needed here. The strict variant evaluates the new
  -- factory right away instead of stacking one unevaluated update per
  -- iteration inside the IORef.
  atomicModifyIORef' factoryRef (\f -> (refresh time f, ()))
  tryRefreshEvery delay factoryRef
-- | Refresh the factory against the current time, store it in a new
-- 'IORef', and fork a thread running 'tryRefreshEvery' on that reference.
--
-- Returns the forked thread's id together with the shared reference.
startRefreshThread :: Int -> Factory -> IO (ThreadId, IORef (Factory))
startRefreshThread delay factory = do
  sharedRef <- newIORef . (`refresh` factory) =<< epochTime
  worker <- forkIO $ tryRefreshEvery delay sharedRef
  return (worker, sharedRef)
-- | Produce the token for the current counter value, together with the
-- factory whose counter has been advanced by one.
getNext :: Factory -> (Factory, ByteString)
getNext f = (advanced, token)
  where
    token = hashCount f
    advanced = incr f
-- | Produce @n@ consecutive tokens and the factory advanced past them.
--
-- The list is ordered newest first: its head is the token for the highest
-- counter value consumed. A non-positive @n@ yields no tokens and returns
-- the factory unchanged.
getMessages :: Int -> Factory -> (Factory, [ByteString])
getMessages n f
  | n <= 0 = (f, [])
  | otherwise =
      let (f', keys) = getMessages (n - 1) f
          (f'', key) = getNext f'
      in (f'', key : keys)
-- | Atomically take the next token from the shared factory, advancing its
-- counter.
--
-- Uses the strict 'atomicModifyIORef'' so that neither the stored factory
-- nor the returned token is left as an unevaluated thunk chained to earlier
-- states.
getNextIO :: IORef (Factory) -> IO (ByteString)
getNextIO factoryRef =
  atomicModifyIORef' factoryRef getNext
-- | Atomically take @n@ tokens from the shared factory (see 'getMessages'
-- for their order). Returns an empty list, without touching the reference,
-- when @n@ is smaller than 1.
--
-- Uses the strict 'atomicModifyIORef'' so the advanced factory is evaluated
-- when it is stored rather than accumulating as a thunk in the reference.
getMessagesIO :: IORef (Factory) -> Int -> IO [ByteString]
getMessagesIO factoryRef n
  | n < 1 = return []
  | otherwise = atomicModifyIORef' factoryRef (getMessages n)