{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module LLVM.Extra.Multi.Value.Storable (
   -- * Basic class
   C(load, store),
   storeNext,
   modify,

   -- * Classes for tuples and vectors
   Tuple(..),
   Vector(..),
   TupleVector(..),

   -- * Standard method implementations
   loadTraversable,
   loadApplicative,
   storeFoldable,

   -- * Pointer handling
   Storable.advancePtr,
   Storable.incrementPtr,
   Storable.decrementPtr,

   -- * Loops over Storable arrays
   Array.arrayLoop,
   Array.arrayLoop2,
   Array.arrayLoopMaybeCont,
   Array.arrayLoopMaybeCont2,
   ) where

import qualified LLVM.Extra.Storable.Private as Storable
import qualified LLVM.Extra.Storable.Array as Array
import LLVM.Extra.Storable.Private
         (BytePtr, advancePtrStatic, incPtrState, incrementPtr, update,
          castFromBytePtr, castToBytePtr,
          runElements, elementOffset, castElementPtr,
          assemblePrimitive, disassemblePrimitive, proxyFromElement3)

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.ArithmeticPrivate as A

import qualified LLVM.ExecutionEngine as EE
import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value)

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.Reader as MR
import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative.HT as App
import qualified Control.Functor.HT as FuncHT
import Control.Monad (foldM, replicateM, replicateM_, (<=<))
import Control.Applicative (Applicative, pure, (<$>))

import qualified Foreign.Storable.Record.Tuple as StoreTuple
import qualified Foreign.Storable as Store
import Foreign.Ptr (Ptr)

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Orphans ()
import Data.Tuple.HT (uncurry3)
import Data.Complex (Complex)
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int  (Int8,  Int16,  Int32,  Int64)
import Data.Bool8 (Bool8)



{- |
Types with both a 'Store.Storable' and a 'MultiValue.C' instance
for which LLVM code can be generated
that reads a value from, or writes a value to, an address in memory.
-}
class (Store.Storable a, MultiValue.C a) => C a where
   {-
   Not all Storable types have a compatible LLVM type,
   let alone a single LLVM type that is compatible on all platforms.
   That is why loading and storing are class methods
   rather than one generic function.
   -}
   -- | Generate code that reads the value stored at the given address.
   load :: Value (Ptr a) -> CodeGenFunction r (MultiValue.T a)
   -- | Generate code that writes the given value to the given address.
   store :: MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r ()

{- |
Write a value to the given address
and return the result of applying 'incrementPtr' to that address
(presumably the address of the following element).
-}
storeNext ::
   (C a, Value (Ptr a) ~ ptr) => MultiValue.T a -> ptr -> CodeGenFunction r ptr
storeNext a ptr = do
   store a ptr
   incrementPtr ptr

{- |
Read the value at the given address, pass it through the given action
and write the result back to the same address.
-}
modify ::
   (C a, MultiValue.T a ~ al) =>
   (al -> CodeGenFunction r al) ->
   Value (Ptr a) -> CodeGenFunction r ()
modify f ptr = do
   old <- load ptr
   new <- f old
   store new ptr


{- |
Values wrapped in 'EE.Stored' are accessed with plain
'LLVM.load' and 'LLVM.store'
after the pointer has been converted by 'castFromStoredPtr'.
-}
instance
   (EE.Marshal a, LLVM.IsConst a, LLVM.IsFirstClass a) =>
      C (EE.Stored a) where
   load ptr = do
      nativePtr <- castFromStoredPtr ptr
      MultiValue.Cons <$> LLVM.load nativePtr
   store (MultiValue.Cons a) ptr = do
      nativePtr <- castFromStoredPtr ptr
      LLVM.store a nativePtr

{- |
Bitcast a Haskell-side pointer to an 'EE.Stored' value
into an LLVM pointer to the wrapped type.
-}
castFromStoredPtr ::
   (LLVM.IsType a) =>
   Value (Ptr (EE.Stored a)) -> CodeGenFunction r (Value (LLVM.Ptr a))
castFromStoredPtr ptr = LLVM.bitcast ptr


{- |
'load' implementation for types
whose 'MultiValue.Repr' is a plain 'LLVM.Value':
bitcast the pointer, load through it and wrap the result.
-}
loadPrimitive ::
   (LLVM.Storable a, MultiValue.Repr a ~ LLVM.Value a) =>
   Value (Ptr a) -> CodeGenFunction r (MultiValue.T a)
loadPrimitive ptr = do
   nativePtr <- LLVM.bitcast ptr
   MultiValue.Cons <$> LLVM.load nativePtr

{- |
'store' implementation for types
whose 'MultiValue.Repr' is a plain 'LLVM.Value':
bitcast the pointer and store the unwrapped value through it.
-}
storePrimitive ::
   (LLVM.Storable a, MultiValue.Repr a ~ LLVM.Value a) =>
   MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r ()
storePrimitive (MultiValue.Cons a) ptr = do
   nativePtr <- LLVM.bitcast ptr
   LLVM.store a nativePtr

{-
Floating point and fixed-width integer types
all share 'loadPrimitive' and 'storePrimitive'.
-}

-- Floating point numbers

instance C Float where
   load = loadPrimitive
   store = storePrimitive

instance C Double where
   load = loadPrimitive
   store = storePrimitive

-- Unsigned integers

instance C Word where
   load = loadPrimitive
   store = storePrimitive

instance C Word8 where
   load = loadPrimitive
   store = storePrimitive

instance C Word16 where
   load = loadPrimitive
   store = storePrimitive

instance C Word32 where
   load = loadPrimitive
   store = storePrimitive

instance C Word64 where
   load = loadPrimitive
   store = storePrimitive

-- Signed integers

instance C Int where
   load = loadPrimitive
   store = storePrimitive

instance C Int8 where
   load = loadPrimitive
   store = storePrimitive

instance C Int16 where
   load = loadPrimitive
   store = storePrimitive

instance C Int32 where
   load = loadPrimitive
   store = storePrimitive

instance C Int64 where
   load = loadPrimitive
   store = storePrimitive

{- |
This implementation is not very efficient
because it adapts to @sizeOf Bool@ dynamically
instead of assuming a fixed width.
Unfortunately, LLVM-9's optimizer does not recognize the instruction pattern.
Better use 'Bool8' for booleans.

'load' combines all bytes of the 'Bool' with a bitwise or
and reports whether the result differs from zero.
'store' writes the sign-extended bit to every byte of the 'Bool'.
Note that the 'Bool8' instance zero-extends instead;
either way the stored value reads back as 'True' through 'load'.
-}
instance C Bool where
   load ptr = do
      firstByte <- castToBytePtr ptr
      let byteCount = Store.sizeOf (False :: Bool)
      -- read the bytes one after another, advancing the pointer in the state
      bytes <-
         MS.evalStateT
            (replicateM byteCount $ do
               bytePtr <- incPtrState
               MT.lift $ LLVM.load bytePtr)
            firstByte
      let zero = LLVM.valueOf 0
      combined <- foldM A.or zero bytes
      fmap MultiValue.Cons $ A.cmp LLVM.CmpNE combined zero
   store (MultiValue.Cons b) ptr = do
      firstByte <- castToBytePtr ptr
      byte <- LLVM.sext b
      let byteCount = Store.sizeOf (False :: Bool)
      -- write the same byte to every position of the Bool
      MS.evalStateT
         (replicateM_ byteCount $ do
            bytePtr <- incPtrState
            MT.lift $ LLVM.store byte bytePtr)
         firstByte

{- |
A 'Bool8' occupies a single byte:
'load' tests that byte against zero,
'store' writes the zero-extended bit.
-}
instance C Bool8 where
   load ptr = do
      bytePtr <- castToBytePtr ptr
      byte <- LLVM.load bytePtr
      MultiValue.Cons <$> A.cmp LLVM.CmpNE (LLVM.valueOf 0) byte
   store (MultiValue.Cons b) ptr = do
      byte <- LLVM.zext b
      bytePtr <- castToBytePtr ptr
      LLVM.store byte bytePtr

{- |
A complex number is handled as a container of its components
using the generic 'loadApplicative' and 'storeFoldable'.
-}
instance (C a) => C (Complex a) where
   load = loadApplicative
   store = storeFoldable



{- |
The 'StoreTuple.Tuple' wrapper forwards to the 'Tuple' class methods;
'MultiValue.cast' converts between the wrapped and the bare tuple value.
-}
instance (Tuple tuple) => C (StoreTuple.Tuple tuple) where
   load = fmap MultiValue.cast . loadTuple
   store wrapped = storeTuple (MultiValue.cast wrapped)

{- |
Tuple types that can be loaded from and stored to memory
when wrapped in 'StoreTuple.Tuple'.
The 'C' instance for 'StoreTuple.Tuple' is implemented via these methods.
-}
class (StoreTuple.Storable tuple, MultiValue.C tuple) => Tuple tuple where
   -- | Generate code that reads all tuple components from the given address.
   loadTuple ::
      Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r (MultiValue.T tuple)
   -- | Generate code that writes all tuple components to the given address.
   storeTuple ::
      MultiValue.T tuple ->
      Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r ()

{- |
Pairs: the components are accessed one after another
at the addresses computed by 'loadElement' and 'storeElement'.
-}
instance (C a, C b) => Tuple (a,b) where
   loadTuple ptr =
      -- split the pair proxy into one proxy per component
      case FuncHT.unzip (proxyFromElement3 ptr) of
         (pa, pb) ->
            runElements ptr $
               App.lift2 MultiValue.zip (loadElement pa) (loadElement pb)
   storeTuple = MultiValue.uncurry $ \a b ptr ->
      let (pa, pb) = FuncHT.unzip (proxyFromElement3 ptr)
      in runElements ptr $ do
            storeElement pa a
            storeElement pb b

{- |
Triples: the components are accessed one after another
at the addresses computed by 'loadElement' and 'storeElement'.
-}
instance (C a, C b, C c) => Tuple (a,b,c) where
   loadTuple ptr =
      -- split the triple proxy into one proxy per component
      case FuncHT.unzip3 (proxyFromElement3 ptr) of
         (pa, pb, pc) ->
            runElements ptr $
               App.lift3 MultiValue.zip3
                  (loadElement pa) (loadElement pb) (loadElement pc)
   storeTuple = MultiValue.uncurry3 $ \a b c ptr ->
      let (pa, pb, pc) = FuncHT.unzip3 (proxyFromElement3 ptr)
      in runElements ptr $ do
            storeElement pa a
            storeElement pb b
            storeElement pc c

{- |
Load one tuple component from the address that 'elementPtr' yields for it.
-}
loadElement ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) (MultiValue.T a)
loadElement proxy = do
   ptr <- elementPtr proxy
   MT.lift $ MT.lift $ load ptr

{- |
Store one tuple component at the address that 'elementPtr' yields for it.
-}
storeElement ::
   (C a) =>
   LP.Proxy a -> MultiValue.T a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) ()
storeElement proxy a = do
   ptr <- elementPtr proxy
   MT.lift $ MT.lift $ store a ptr

{- |
Compute the address of the next tuple component.
The base address comes from the reader environment,
the offset from 'elementOffset' running in the 'Int' state.
The resulting byte pointer is cast to a pointer to the component type.
-}
elementPtr ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr
      (MS.StateT Int (CodeGenFunction r)) (LLVM.Value (Ptr a))
elementPtr proxy = do
   base <- MR.ask
   offset <- MT.lift $ elementOffset proxy
   MT.lift $ MT.lift $
      LLVM.getElementPtr base (offset, ()) >>= castFromBytePtr


{- |
Vectors are transferred element by element
with 'loadApplicativeRepr' and 'storeFoldableRepr'.
'assembleVector' and 'disassembleVector' convert between
the vector of element representations and the 'MultiVector' representation.
-}
instance
   (TypeNum.Positive n, Vector a) =>
      C (LLVM.Vector n a) where
   load ptr = do
      elems <- loadApplicativeRepr ptr
      MultiValue.Cons <$> assembleVector (proxyFromElement3 ptr) elems
   store (MultiValue.Cons a) ptr = do
      elems <- disassembleVector (proxyFromElement3 ptr) a
      storeFoldableRepr elems ptr

{- |
Element types of vectors that can be loaded and stored.
The methods convert between a vector of element representations,
as obtained by loading the elements one by one,
and the 'MultiVector.Repr' of the whole vector.
The proxy argument only selects the element type.
-}
class (C a, MultiVector.C a) => Vector a where
   -- | Turn a vector of element representations into the vector representation.
   assembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> LLVM.Vector n (MultiValue.Repr a) ->
      CodeGenFunction r (MultiVector.Repr n a)
   -- | Turn the vector representation into a vector of element representations.
   disassembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> MultiVector.Repr n a ->
      CodeGenFunction r (LLVM.Vector n (MultiValue.Repr a))

{-
The instances for the following primitive element types are all identical:
both conversion directions are delegated to
'assemblePrimitive' and 'disassemblePrimitive'
imported from "LLVM.Extra.Storable.Private".
The 'LP.Proxy' argument is matched but carries no further information.
-}

-- Floating point elements

instance Vector Float where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Double where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

-- Unsigned integer elements

instance Vector Word where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word8 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word16 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word32 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word64 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

-- Signed integer elements

instance Vector Int where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int8 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int16 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int32 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int64 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

-- Boolean elements

instance Vector Bool where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Bool8 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive


{- |
Vectors of wrapped tuples are converted by the 'TupleVector' methods.
The proxy is unwrapped with 'StoreTuple.getTuple' before it is passed on.
-}
instance
   (Tuple tuple, TupleVector tuple) =>
      Vector (StoreTuple.Tuple tuple) where
   assembleVector proxy = deinterleave (fmap StoreTuple.getTuple proxy)
   disassembleVector proxy = interleave (fmap StoreTuple.getTuple proxy)


{- |
Tuple types that can serve as elements of loadable and storable vectors.
The 'Vector' instance for 'StoreTuple.Tuple' is implemented via these methods.
-}
class (MultiVector.C a) => TupleVector a where
   -- | Turn a vector of tuple representations into the vector representation.
   deinterleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> LLVM.Vector n (MultiValue.Repr a) ->
      CodeGenFunction r (MultiVector.Repr n a)
   -- | Turn the vector representation into a vector of tuple representations.
   interleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> MultiVector.Repr n a ->
      CodeGenFunction r (LLVM.Vector n (MultiValue.Repr a))

{- |
Pairs: a vector of pairs is split into two vectors
that are assembled separately, and vice versa.
-}
instance (Vector a, Vector b) => TupleVector (a,b) where
   deinterleave proxy pairs =
      case (FuncHT.unzip proxy, FuncHT.unzip pairs) of
         ((pa, pb), (xs, ys)) ->
            App.lift2 (,) (assembleVector pa xs) (assembleVector pb ys)
   interleave proxy (xs, ys) =
      case FuncHT.unzip proxy of
         (pa, pb) ->
            -- outer lift2 combines the code generators,
            -- inner lift2 pairs up the vector elements
            App.lift2 (App.lift2 (,))
               (disassembleVector pa xs)
               (disassembleVector pb ys)

{- |
Triples: a vector of triples is split into three vectors
that are assembled separately, and vice versa.
-}
instance (Vector a, Vector b, Vector c) => TupleVector (a,b,c) where
   deinterleave proxy triples =
      case (FuncHT.unzip3 proxy, FuncHT.unzip3 triples) of
         ((pa, pb, pc), (xs, ys, zs)) ->
            App.lift3 (,,)
               (assembleVector pa xs)
               (assembleVector pb ys)
               (assembleVector pc zs)
   interleave proxy (xs, ys, zs) =
      case FuncHT.unzip3 proxy of
         (pa, pb, pc) ->
            -- outer lift3 combines the code generators,
            -- inner lift3 groups the vector elements into triples
            App.lift3 (App.lift3 (,,))
               (disassembleVector pa xs)
               (disassembleVector pb ys)
               (disassembleVector pc zs)


{-
The instance Storable () is available since base-4.9/GHC-8.0.
For earlier versions we need Data.Orphans.
-}
{- |
The unit type carries no data,
so neither 'load' nor 'store' touches the pointer.
-}
instance C () where
   load _ = return (MultiValue.Cons ())
   store (MultiValue.Cons ()) _ = return ()


{- |
'load' implementation for containers
whose shape is obtained from 'NonEmptyC.repeat'.
The elements are loaded one after another with 'loadState',
starting at the address returned by 'castElementPtr'.
-}
loadTraversable ::
   (NonEmptyC.Repeat f, Trav.Traversable f,
    C a, MultiValue.Repr fa ~ f (MultiValue.Repr a)) =>
   Value (Ptr (f a)) -> CodeGenFunction r (MultiValue.T fa)
loadTraversable ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT
      (MultiValue.Cons <$> Trav.sequence (NonEmptyC.repeat loadState))
      elemPtr

{- |
'load' implementation for containers whose shape is obtained from 'pure'.
Wraps the result of 'loadApplicativeRepr' in 'MultiValue.Cons'.
-}
loadApplicative ::
   (Applicative f, Trav.Traversable f,
    C a, MultiValue.Repr fa ~ f (MultiValue.Repr a)) =>
   Value (Ptr (f a)) -> CodeGenFunction r (MultiValue.T fa)
loadApplicative ptr = MultiValue.Cons <$> loadApplicativeRepr ptr

{- |
Load the elements of a container one after another with 'loadState',
starting at the address returned by 'castElementPtr'.
The container shape is the one produced by 'pure'.
The result is the container of bare element representations.
-}
loadApplicativeRepr ::
   (Applicative f, Trav.Traversable f, C a) =>
   Value (Ptr (f a)) -> CodeGenFunction r (f (MultiValue.Repr a))
loadApplicativeRepr ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Trav.sequence (pure loadState)) elemPtr

{- |
Load one element from the pointer delivered by 'advancePtrState'
and return its bare representation.
-}
loadState ::
   (C a, MultiValue.Repr a ~ al) =>
   MS.StateT (Value (Ptr a)) (CodeGenFunction r) al
loadState = do
   ptr <- advancePtrState
   MT.lift $ (\(MultiValue.Cons repr) -> repr) <$> load ptr


{- |
'store' implementation for containers:
unwraps the value and hands its elements to 'storeFoldableRepr'.
-}
storeFoldable ::
   (Fold.Foldable f, C a, MultiValue.Repr fa ~ f (MultiValue.Repr a)) =>
    MultiValue.T fa -> Value (Ptr (f a)) -> CodeGenFunction r ()
storeFoldable (MultiValue.Cons xs) ptr = storeFoldableRepr xs ptr

{- |
Store the elements of a container one after another with 'storeState',
starting at the address returned by 'castElementPtr'.
-}
storeFoldableRepr ::
   (Fold.Foldable f, C a) =>
   f (MultiValue.Repr a) -> Value (Ptr (f a)) -> CodeGenFunction r ()
storeFoldableRepr xs ptr = do
   elemPtr <- castElementPtr ptr
   MS.evalStateT (Fold.mapM_ storeState xs) elemPtr

{- |
Store one element representation
at the pointer delivered by 'advancePtrState'.
-}
storeState ::
   (C a, MultiValue.Repr a ~ al) =>
   al -> MS.StateT (Value (Ptr a)) (CodeGenFunction r) ()
storeState a = do
   ptr <- advancePtrState
   MT.lift $ store (MultiValue.Cons a) ptr


{- |
State action built from 'update' and @advancePtrStatic 1@.
Judging from its use in 'loadState' and 'storeState',
it presumably yields the pointer to the current element
and leaves the pointer to the following element in the state;
confirm against the definition of 'update'
in "LLVM.Extra.Storable.Private".
-}
advancePtrState ::
   (C a, Value (Ptr a) ~ ptr) =>
   MS.StateT ptr (CodeGenFunction r) ptr
advancePtrState = update (advancePtrStatic 1)