{-# LANGUAGE DeriveDataTypeable #-}
module System.IO.Streams.SHA where

import Control.Exception (Exception, throwIO)
import Data.Char (toLower)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef, writeIORef)
import Data.Typeable (Typeable)

import Data.Binary.Get
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as C
import Data.Digest.Pure.SHA
import qualified System.IO.Streams as S
import System.IO.Streams.Internal (InputStream (..))

-- | Pass a stream through unchanged while feeding every chunk to an
-- incremental SHA-1 decoder; see 'shaInput' for when the returned digest
-- action should be run.
sha1Input :: InputStream ByteString -> IO (InputStream ByteString, IO (Digest SHA1State))
sha1Input source = shaInput sha1Incremental completeSha1Incremental source

-- | Pass a stream through unchanged while feeding every chunk to an
-- incremental SHA-224 decoder; see 'shaInput' for when the returned digest
-- action should be run.
sha224Input :: InputStream ByteString -> IO (InputStream ByteString, IO (Digest SHA256State))
sha224Input source = shaInput sha224Incremental completeSha224Incremental source

-- | Pass a stream through unchanged while feeding every chunk to an
-- incremental SHA-256 decoder; see 'shaInput' for when the returned digest
-- action should be run.
sha256Input :: InputStream ByteString -> IO (InputStream ByteString, IO (Digest SHA256State))
sha256Input source = shaInput sha256Incremental completeSha256Incremental source

-- | Pass a stream through unchanged while feeding every chunk to an
-- incremental SHA-384 decoder; see 'shaInput' for when the returned digest
-- action should be run.
sha384Input :: InputStream ByteString -> IO (InputStream ByteString, IO (Digest SHA512State))
sha384Input source = shaInput sha384Incremental completeSha384Incremental source

-- | Pass a stream through unchanged while feeding every chunk to an
-- incremental SHA-512 decoder; see 'shaInput' for when the returned digest
-- action should be run.
sha512Input :: InputStream ByteString -> IO (InputStream ByteString, IO (Digest SHA512State))
sha512Input source = shaInput sha512Incremental completeSha512Incremental source

-- | SHA-1 verifying stream that takes the expected digest as a 'Digest'
-- value; it is rendered with 'showDigest' and checked by 'checkedShaInput'.
checkedSha1Input :: Digest SHA1State -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha1Input expected =
  checkedShaInput sha1Incremental completeSha1Incremental (showDigest expected)

-- | SHA-224 verifying stream that takes the expected digest as a 'Digest'
-- value; it is rendered with 'showDigest' and checked by 'checkedShaInput'.
checkedSha224Input :: Digest SHA256State -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha224Input expected =
  checkedShaInput sha224Incremental completeSha224Incremental (showDigest expected)

-- | SHA-256 verifying stream that takes the expected digest as a 'Digest'
-- value; it is rendered with 'showDigest' and checked by 'checkedShaInput'.
checkedSha256Input :: Digest SHA256State -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha256Input expected =
  checkedShaInput sha256Incremental completeSha256Incremental (showDigest expected)

-- | SHA-384 verifying stream that takes the expected digest as a 'Digest'
-- value; it is rendered with 'showDigest' and checked by 'checkedShaInput'.
checkedSha384Input :: Digest SHA512State -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha384Input expected =
  checkedShaInput sha384Incremental completeSha384Incremental (showDigest expected)

-- | SHA-512 verifying stream that takes the expected digest as a 'Digest'
-- value; it is rendered with 'showDigest' and checked by 'checkedShaInput'.
checkedSha512Input :: Digest SHA512State -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha512Input expected =
  checkedShaInput sha512Incremental completeSha512Incremental (showDigest expected)

-- | Pass a stream through unchanged and, at end-of-stream, throw
-- 'UnmatchedSHAException' if the SHA-1 digest of the bytes read does not
-- match the given digest string (compared with 'showDigest' output).
checkedSha1Input' :: String -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha1Input' expected source =
  checkedShaInput sha1Incremental completeSha1Incremental expected source

-- | Pass a stream through unchanged and, at end-of-stream, throw
-- 'UnmatchedSHAException' if the SHA-224 digest of the bytes read does not
-- match the given digest string (compared with 'showDigest' output).
checkedSha224Input' :: String -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha224Input' expected source =
  checkedShaInput sha224Incremental completeSha224Incremental expected source

-- | Pass a stream through unchanged and, at end-of-stream, throw
-- 'UnmatchedSHAException' if the SHA-256 digest of the bytes read does not
-- match the given digest string (compared with 'showDigest' output).
checkedSha256Input' :: String -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha256Input' expected source =
  checkedShaInput sha256Incremental completeSha256Incremental expected source

-- | Pass a stream through unchanged and, at end-of-stream, throw
-- 'UnmatchedSHAException' if the SHA-384 digest of the bytes read does not
-- match the given digest string (compared with 'showDigest' output).
checkedSha384Input' :: String -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha384Input' expected source =
  checkedShaInput sha384Incremental completeSha384Incremental expected source

-- | Pass a stream through unchanged and, at end-of-stream, throw
-- 'UnmatchedSHAException' if the SHA-512 digest of the bytes read does not
-- match the given digest string (compared with 'showDigest' output).
checkedSha512Input' :: String -> InputStream ByteString -> IO (InputStream ByteString)
checkedSha512Input' expected source =
  checkedShaInput sha512Incremental completeSha512Incremental expected source

-- | A pair that is strict in both components: building a 'Pair' forces each
-- field to weak head normal form, so state kept in an 'IORef' as a 'Pair'
-- does not accumulate thunks in its fields.
data Pair a b
  = Pair !a !b

-- | Apply a two-argument function to both components of a 'Pair'; the
-- 'Pair' counterpart of 'uncurry'.
uncurry' :: (a -> b -> c) -> Pair a b -> c
uncurry' f p = case p of
  Pair x y -> f x y

-- | Wrap an input stream so that every chunk read through it is also fed to
-- an incremental SHA decoder (modelled on 'S.countInput').
--
-- The returned stream yields exactly the chunks of the source. The returned
-- action finalises the digest (via @end@) from the decoder state and the
-- total number of bytes read /at the moment it runs/, so run it only after the
-- stream has been read to exhaustion. (The original note says running it
-- earlier raises an error; NOTE(review): confirm against the SHA package's
-- @complete*Incremental@ functions.)
shaInput :: Decoder a -> (Decoder a -> Int -> Digest a)
  -> InputStream ByteString -> IO (InputStream ByteString, IO (Digest a))
shaInput initial finish source = do
  stateRef <- newIORef (Pair initial 0)
  wrapped <- S.makeInputStream (step stateRef)
  let finalize = do
        Pair decoder total <- readIORef stateRef
        return $! finish decoder total
  return $! (wrapped, finalize)
  where
    -- Forward one chunk from the source, updating decoder and byte count.
    step stateRef = S.read source >>= \next -> case next of
      Nothing    -> return Nothing
      Just chunk -> do
        modifyRef stateRef (feed chunk)
        return $! Just chunk

    -- Push a chunk into the decoder and add its length to the running total.
    feed chunk (Pair decoder total) =
      Pair (pushChunk decoder chunk) (total + C.length chunk)

-- | Wrap an input stream so that it yields exactly the same chunks, and
-- verify the SHA digest of everything read once the source is exhausted.
--
-- The expected digest string is compared with the 'showDigest' rendering of
-- the computed digest, ignoring letter case, so both upper- and lower-case
-- hex digests are accepted. On a mismatch, the read that would have reported
-- end-of-stream throws 'UnmatchedSHAException' instead. The check only
-- happens if the consumer actually reads until end-of-stream.
--
-- Hashing is delegated to 'shaInput' so the logic lives in one place.
checkedShaInput :: Decoder a -> (Decoder a -> Int -> Digest a)
  -> String -> InputStream ByteString -> IO (InputStream ByteString)
checkedShaInput increment end expected is = do
  (hashed, getDigest) <- shaInput increment end is
  S.makeInputStream $ do
    mbs <- S.read hashed
    case mbs of
      Just bs -> return $! Just bs
      Nothing -> do
        -- Source exhausted: every byte has been hashed, so verify now.
        actual <- getDigest
        if normalise expected == normalise (showDigest actual)
          then return Nothing
          else throwIO UnmatchedSHAException
  where
    -- Hex digests are case-insensitive; showDigest emits lower case.
    normalise = map toLower

-- | Strictly apply a function to the contents of an 'IORef'.
--
-- The new value is forced to WHNF before being written, so repeated updates
-- do not build a chain of thunks inside the reference. This is exactly
-- 'modifyIORef'' from "Data.IORef"; the local name is kept for the call sites
-- in this module (it was originally copied from System.IO.Streams.ByteString).
{-# INLINE modifyRef #-}
modifyRef :: IORef a -> (a -> a) -> IO ()
modifyRef = modifyIORef'

-- | Thrown by 'checkedShaInput' (and hence the @checkedSha*Input@ wrappers)
-- when the digest computed at end-of-stream differs from the expected one.
data UnmatchedSHAException
  = UnmatchedSHAException
  deriving (Typeable)

-- | Always renders as the same fixed message, without inspecting the value.
instance Show UnmatchedSHAException where
  showsPrec _ _ = showString "Unmatched SHA exception."

-- | Uses the default 'Exception' methods.
instance Exception UnmatchedSHAException