{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module LLVM.Extra.Multi.Value.Array where

import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value.Private as MultiValue
import LLVM.Extra.Multi.Value.Private (Repr)

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum
import qualified Type.Data.Num.Decimal.Number as Dec
import Type.Base.Proxy (Proxy(Proxy))

import Control.Applicative (Applicative(pure, (<*>)))

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Functor.Identity (Identity(Identity, runIdentity))
import Data.Functor ((<$>))

import Prelude2010
import Prelude ()



-- | A Haskell list of elements tagged with a type-level length @n@.
--
-- Nothing in this declaration checks that the wrapped list has exactly
-- @n@ elements; callers constructing an 'Array' directly must ensure it.
newtype Array n a = Array [a]
   deriving (Show, Eq)

-- | Applies a function to every element; the type-level length @n@ is
-- carried through unchanged.
instance (Dec.Integer n) => Functor (Array n) where
   fmap g (Array elems) = Array (g <$> elems)

-- | Position-wise applicative structure.
--
-- 'pure' builds a list of @n@ copies of its argument, where @n@ is read
-- from the type-level length with 'Dec.integralFromProxy'; the proxy is
-- obtained through 'withArraySize' so that it matches the result type.
-- Function application pairs functions and arguments with 'zipWith', so
-- the result is as long as the shorter of the two wrapped lists.
instance (Dec.Integer n) => Applicative (Array n) where
   Array gs <*> Array args = Array (zipWith ($) gs args)
   pure x =
      runIdentity . withArraySize $
         Identity . Array . flip replicate x . Dec.integralFromProxy

-- | Folds over the wrapped list, in list order, by delegating to the
-- list instance of 'Fold.foldMap'.
instance (Dec.Integer n) => Fold.Foldable (Array n) where
   foldMap g = Fold.foldMap g . elements
      where elements (Array xs) = xs

-- | Traverses the wrapped list in order and rewraps the result in 'Array'.
instance (Dec.Integer n) => Trav.Traversable (Array n) where
   traverse g (Array elems) = fmap Array (Trav.traverse g elems)

-- | Calls the given generator with a 'Proxy' whose type index is the
-- length @n@ of the generated 'Array'.
--
-- This ties the proxy's type to the result type without needing
-- @ScopedTypeVariables@; the 'Applicative' instance's 'pure' uses it
-- to learn @n@.
withArraySize :: (Proxy n -> gen (Array n a)) -> gen (Array n a)
withArraySize = ($ Proxy)


-- | An @Array n a@ is represented as a single LLVM value of an LLVM array
-- type whose elements are @'Marshal.Struct' a@.
--
-- 'MultiValue.cons' packs every element with 'Marshal.pack' before handing
-- the LLVM array to 'MultiValue.consPrimitive'; the remaining methods
-- delegate directly to the corresponding @*Primitive@ helpers.
instance (TypeNum.Natural n, Marshal.C a) => MultiValue.C (Array n a) where
   type Repr (Array n a) = LLVM.Value (LLVM.Array n (Marshal.Struct a))
   cons (Array elems) =
      MultiValue.consPrimitive (LLVM.Array (fmap Marshal.pack elems))
   undef = MultiValue.undefPrimitive
   zero = MultiValue.zeroPrimitive
   addPhi = MultiValue.addPhiPrimitive
   phi = MultiValue.phiPrimitive

-- | Marshals an 'Array' element by element: 'Marshal.pack' maps every
-- element into an 'LLVM.Array', and 'Marshal.unpack' maps every element
-- back, preserving list order.
--
-- NOTE(review): the @Dec.Natural (n :*: SizeOf ...)@ constraint presumably
-- ensures the total byte size is a known natural — confirm in Marshal.C.
instance
   (TypeNum.Natural n, Marshal.C a,
    Dec.Natural (n Dec.:*: LLVM.SizeOf (Marshal.Struct a))) =>
      Marshal.C (Array n a) where
   unpack (LLVM.Array structs) = Array (Marshal.unpack <$> structs)
   pack (Array elems) = LLVM.Array (Marshal.pack <$> elems)

-- | Reads the element at index @i@ out of an array multi-value.
--
-- The element, stored as a @'Marshal.Struct' a@, is fetched with
-- 'LLVM.extractvalue' and then converted with 'Memory.decompose' into the
-- element's multi-value representation.
--
-- NOTE(review): 'LLVM.ArrayIndex' presumably restricts which index
-- types\/values are valid for length @n@ — confirm in llvm-tf.
extractArrayValue ::
   (TypeNum.Natural n, LLVM.ArrayIndex n i, Marshal.C a) =>
   i -> MultiValue.T (Array n a) ->
   LLVM.CodeGenFunction r (MultiValue.T a)
extractArrayValue i (MultiValue.Cons arr) = do
   elemStruct <- LLVM.extractvalue arr i
   MultiValue.Cons <$> Memory.decompose elemStruct

-- | Writes an element at index @i@ of an array multi-value.
--
-- The element is first converted with 'Memory.compose' into its struct
-- form. That struct is then placed into the array with 'LLVM.insertvalue',
-- whose resulting array value is returned.
insertArrayValue ::
   (TypeNum.Natural n, LLVM.ArrayIndex n i, Marshal.C a) =>
   i -> MultiValue.T a -> MultiValue.T (Array n a) ->
   LLVM.CodeGenFunction r (MultiValue.T (Array n a))
insertArrayValue i (MultiValue.Cons elemRepr) (MultiValue.Cons arr) = do
   elemStruct <- Memory.compose elemRepr
   MultiValue.Cons <$> LLVM.insertvalue arr elemStruct i