{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module LLVM.Extra.Multi.Value.Vector (
   cons,
   fst, snd,
   fst3, snd3, thd3,
   zip, zip3,
   unzip, unzip3,

   swap,
   mapFst, mapSnd,
   mapFst3, mapSnd3, mapThd3,

   extract, insert,
   replicate,
   iterate,
   dissect,
   dissect1,
   select,
   cmp,
   take, takeRev,

   NativeInteger,
   NativeFloating,
   fromIntegral,
   truncateToInt,
   splitFractionToInt,
   ) where

import qualified LLVM.Extra.Multi.Vector.Instance as Inst
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Private as MultiValue
import qualified LLVM.Extra.ScalarOrVector as SoV
import LLVM.Extra.Multi.Vector.Instance (MVVector)

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum


import qualified Data.NonEmpty as NonEmpty
import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int (Int8, Int16, Int32, Int64, Int)

import Prelude (Float, Double, Bool, fmap, (.))


-- | Turn a Haskell-side 'LLVM.Vector' into a multi-value vector.
--   The conversion goes through 'MultiVector.cons' and then
--   'Inst.toMultiValue'.
cons ::
   (TypeNum.Positive n, MultiVector.C a) =>
   LLVM.Vector n a -> MVVector n a
cons vec = Inst.toMultiValue (MultiVector.cons vec)

-- | Select the first component of every pair in a vector of pairs.
--   This only rearranges the representation and emits no LLVM code.
fst :: MVVector n (a,b) -> MVVector n a
fst = MultiValue.lift1 (\(first,_) -> first)

-- | Select the second component of every pair in a vector of pairs.
--   This only rearranges the representation and emits no LLVM code.
snd :: MVVector n (a,b) -> MVVector n b
snd = MultiValue.lift1 (\(_,second) -> second)

-- | Exchange the two components of a vector of pairs.
--   This uses 'TupleHT.swap' on the underlying representation.
swap :: MVVector n (a,b) -> MVVector n (b,a)
swap pairs = MultiValue.lift1 TupleHT.swap pairs

-- | Transform the first component of a vector of pairs and keep the
--   second component as it is.
--   The vector is split with 'unzip', the first part is mapped, and the
--   parts are recombined with 'zip'.
mapFst ::
   (MVVector n a0 -> MVVector n a1) ->
   MVVector n (a0,b) -> MVVector n (a1,b)
mapFst f pairs =
   let (firsts, seconds) = unzip pairs
   in  zip (f firsts) seconds

-- | Transform the second component of a vector of pairs and keep the
--   first component as it is.
--   The vector is split with 'unzip', the second part is mapped, and the
--   parts are recombined with 'zip'.
mapSnd ::
   (MVVector n b0 -> MVVector n b1) ->
   MVVector n (a,b0) -> MVVector n (a,b1)
mapSnd f pairs =
   let (firsts, seconds) = unzip pairs
   in  zip firsts (f seconds)


-- | Select the first component of every triple in a vector of triples.
fst3 :: MVVector n (a,b,c) -> MVVector n a
fst3 = MultiValue.lift1 (\(first,_,_) -> first)

-- | Select the second component of every triple in a vector of triples.
snd3 :: MVVector n (a,b,c) -> MVVector n b
snd3 = MultiValue.lift1 (\(_,second,_) -> second)

-- | Select the third component of every triple in a vector of triples.
thd3 :: MVVector n (a,b,c) -> MVVector n c
thd3 = MultiValue.lift1 (\(_,_,third) -> third)

-- | Transform the first component of a vector of triples and keep the
--   other two components as they are.
--   The vector is split with 'unzip3' and recombined with 'zip3'.
mapFst3 ::
   (MVVector n a0 -> MVVector n a1) ->
   MVVector n (a0,b,c) -> MVVector n (a1,b,c)
mapFst3 f triples =
   let (firsts, seconds, thirds) = unzip3 triples
   in  zip3 (f firsts) seconds thirds

-- | Transform the second component of a vector of triples and keep the
--   other two components as they are.
--   The vector is split with 'unzip3' and recombined with 'zip3'.
mapSnd3 ::
   (MVVector n b0 -> MVVector n b1) ->
   MVVector n (a,b0,c) -> MVVector n (a,b1,c)
mapSnd3 f triples =
   let (firsts, seconds, thirds) = unzip3 triples
   in  zip3 firsts (f seconds) thirds

-- | Transform the third component of a vector of triples and keep the
--   other two components as they are.
--   The vector is split with 'unzip3' and recombined with 'zip3'.
mapThd3 ::
   (MVVector n c0 -> MVVector n c1) ->
   MVVector n (a,b,c0) -> MVVector n (a,b,c1)
mapThd3 f triples =
   let (firsts, seconds, thirds) = unzip3 triples
   in  zip3 firsts seconds (f thirds)


-- | Combine two vectors of equal length @n@ into one vector of pairs.
--   Only the representations are paired; no LLVM code is emitted.
zip :: MVVector n a -> MVVector n b -> MVVector n (a,b)
zip xs ys =
   case (xs, ys) of
      (MultiValue.Cons x, MultiValue.Cons y) -> MultiValue.Cons (x,y)

-- | Combine three vectors of equal length @n@ into one vector of triples.
--   Only the representations are grouped; no LLVM code is emitted.
zip3 :: MVVector n a -> MVVector n b -> MVVector n c -> MVVector n (a,b,c)
zip3 xs ys zs =
   case (xs, ys, zs) of
      (MultiValue.Cons x, MultiValue.Cons y, MultiValue.Cons z) ->
         MultiValue.Cons (x,y,z)

-- | Split a vector of pairs into two vectors; this is the inverse of 'zip'.
unzip :: MVVector n (a,b) -> (MVVector n a, MVVector n b)
unzip pairs =
   case pairs of
      MultiValue.Cons (xs, ys) -> (MultiValue.Cons xs, MultiValue.Cons ys)

-- | Split a vector of triples into three vectors; this is the inverse
--   of 'zip3'.
unzip3 :: MVVector n (a,b,c) -> (MVVector n a, MVVector n b, MVVector n c)
unzip3 triples =
   case triples of
      MultiValue.Cons (xs, ys, zs) ->
         (MultiValue.Cons xs, MultiValue.Cons ys, MultiValue.Cons zs)


-- | Read one element of a vector at a run-time index @k@.
--   The vector is converted with 'Inst.fromMultiValue' and the work is
--   delegated to 'MultiVector.extract'.
--   NOTE(review): behaviour for @k >= n@ is whatever 'MultiVector.extract'
--   does; confirm before relying on it.
extract ::
   (TypeNum.Positive n, MultiVector.C a) =>
   LLVM.Value Word32 -> MVVector n a ->
   LLVM.CodeGenFunction r (MultiValue.T a)
extract k = MultiVector.extract k . Inst.fromMultiValue

-- | Replace the element at run-time index @k@ with the value @a@ and
--   return the updated vector.
--   The work is delegated to 'MultiVector.insert' through
--   'Inst.liftMultiValueM'.
insert ::
   (TypeNum.Positive n, MultiVector.C a) =>
   LLVM.Value Word32 -> MultiValue.T a ->
   MVVector n a -> LLVM.CodeGenFunction r (MVVector n a)
insert k a vec = Inst.liftMultiValueM (MultiVector.insert k a) vec


-- | Build an @n@-element vector from a single value using
--   'MultiVector.replicate'. Presumably every lane receives the value.
replicate ::
   (TypeNum.Positive n, MultiVector.C a) =>
   MultiValue.T a -> LLVM.CodeGenFunction r (MVVector n a)
replicate x = fmap Inst.toMultiValue (MultiVector.replicate x)

-- | Build an @n@-element vector from a start value and a code-generating
--   step function, using 'MultiVector.iterate'.
--   Presumably the lanes are @x, f x, f (f x), ...@; see
--   'MultiVector.iterate'.
iterate ::
   (TypeNum.Positive n, MultiVector.C a) =>
   (MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T a)) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (MVVector n a)
iterate step start = fmap Inst.toMultiValue (MultiVector.iterate step start)

-- | Turn an @n@-element vector into an @m@-element vector with
--   'MultiVector.take'.
--   Presumably this keeps the leading elements; the behaviour for
--   @m > n@ is defined by 'MultiVector.take'.
take ::
   (TypeNum.Positive n, TypeNum.Positive m, MultiVector.C a) =>
   MVVector n a -> LLVM.CodeGenFunction r (MVVector m a)
take vec = Inst.liftMultiValueM MultiVector.take vec

-- | Turn an @n@-element vector into an @m@-element vector with
--   'MultiVector.takeRev'.
--   Presumably this is the counterpart of 'take' working from the end
--   of the vector; confirm against 'MultiVector.takeRev'.
takeRev ::
   (TypeNum.Positive n, TypeNum.Positive m, MultiVector.C a) =>
   MVVector n a -> LLVM.CodeGenFunction r (MVVector m a)
takeRev vec = Inst.liftMultiValueM MultiVector.takeRev vec


-- | Break a vector into a list of scalar multi-values using
--   'MultiVector.dissect'.
dissect ::
   (TypeNum.Positive n, MultiVector.C a) =>
   MVVector n a -> LLVM.CodeGenFunction r [MultiValue.T a]
dissect vec = MultiVector.dissect (Inst.fromMultiValue vec)

-- | Like 'dissect', but the result is a non-empty list
--   ('NonEmpty.T'), as produced by 'MultiVector.dissect1'.
dissect1 ::
   (TypeNum.Positive n, MultiVector.C a) =>
   MVVector n a -> LLVM.CodeGenFunction r (NonEmpty.T [] (MultiValue.T a))
dissect1 vec = MultiVector.dissect1 (Inst.fromMultiValue vec)

-- | Choose between two vectors under a vector of 'Bool' conditions,
--   delegating to 'MultiVector.select'.
--   Presumably the choice is made per lane, as with LLVM's vector
--   @select@.
select ::
   (TypeNum.Positive n, MultiVector.Select a) =>
   MVVector n Bool ->
   MVVector n a -> MVVector n a ->
   LLVM.CodeGenFunction r (MVVector n a)
select conds whenTrue whenFalse =
   Inst.liftMultiValueM3 MultiVector.select conds whenTrue whenFalse

-- | Compare two @n@-element vectors with the given predicate, yielding an
--   @n@-element vector of 'Bool'. The work is delegated to
--   'MultiVector.cmp'.
cmp ::
   (TypeNum.Positive n, MultiVector.Comparison a) =>
   LLVM.CmpPredicate ->
   MVVector n a -> MVVector n a ->
   LLVM.CodeGenFunction r (MVVector n Bool)
cmp predicate xs ys = Inst.liftMultiValueM2 (MultiVector.cmp predicate) xs ys


{-
ToDo: Make this a superclass of MultiValue.NativeInteger.
Problem: we need MultiValue.Repr, which provokes an import cycle.
Maybe we should break the cycle using ConstraintKinds,
i.e. define the class NativeIntegerVec in MultiValue,
define the constraint synonym NativeInteger = MultiValue.NativeIntegerVec here,
and export only the constraint synonym MultiValueVec.NativeInteger.
-}
-- | Integer types @i@ whose multi-value representation is a single
--   @'LLVM.Value' ir@.
--   The superclass equality links @i@ to its LLVM-level type @ir@.
class
   (MultiValue.Repr i ~ LLVM.Value ir,
    LLVM.IsInteger ir, LLVM.CmpRet ir, SoV.IntegerConstant ir) =>
      NativeInteger i ir

-- Unsigned scalars: the Haskell type doubles as the LLVM type.
instance NativeInteger Word   Word
instance NativeInteger Word8  Word8
instance NativeInteger Word16 Word16
instance NativeInteger Word32 Word32
instance NativeInteger Word64 Word64

-- Signed scalars: the Haskell type doubles as the LLVM type.
instance NativeInteger Int   Int
instance NativeInteger Int8  Int8
instance NativeInteger Int16 Int16
instance NativeInteger Int32 Int32
instance NativeInteger Int64 Int64

-- Vectors of native integers.
-- The head uses separate lengths @n@ and @m@ and equates them in the
-- context. Instance selection therefore commits as soon as both
-- arguments are vectors, and the lengths are unified afterwards.
instance
   (n ~ m, TypeNum.Positive n,
    MultiValue.NativeInteger i ir,
    MultiVector.NativeInteger n i ir) =>
      NativeInteger (LLVM.Vector n i) (LLVM.Vector m ir)


-- | Floating-point types @a@ whose multi-value representation is a single
--   @'LLVM.Value' ar@.
--   The superclass equality links @a@ to its LLVM-level type @ar@.
class
   (MultiValue.Repr a ~ LLVM.Value ar,
    LLVM.IsFloating ar, LLVM.CmpRet ar, SoV.RationalConstant ar) =>
      NativeFloating a ar

-- Scalars: the Haskell type doubles as the LLVM type.
instance NativeFloating Float  Float
instance NativeFloating Double Double

-- Vectors of native floating-point numbers.
-- As with the integer case, the head uses separate lengths @n@ and @m@
-- and equates them in the context.
instance
   (n ~ m, TypeNum.Positive n,
    MultiValue.NativeFloating a ar,
    MultiVector.NativeFloating n a ar) =>
      NativeFloating (LLVM.Vector n a) (LLVM.Vector m ar)

-- | Convert integers to floating-point numbers of the same shape (scalar,
--   or vector of equal length) using 'LLVM.inttofp'.
--   NOTE(review): signed vs. unsigned conversion is chosen by
--   'LLVM.inttofp' based on @ir@; confirm against llvm-tf.
fromIntegral ::
   (NativeInteger i ir, NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   MultiValue.T i -> LLVM.CodeGenFunction r (MultiValue.T a)
fromIntegral x = MultiValue.liftM LLVM.inttofp x


-- | Convert floating-point numbers to integers of the same shape using
--   'LLVM.fptoint'.
--   Presumably this rounds toward zero, as the name suggests; confirm
--   against llvm-tf.
truncateToInt ::
   (NativeInteger i ir, NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T i)
truncateToInt x = MultiValue.liftM LLVM.fptoint x

-- | Split floating-point numbers into a pair of an integer part and a
--   floating-point part of the same shape, using 'SoV.splitFractionToInt'.
--   The exact rounding convention for negative inputs is defined there.
splitFractionToInt ::
   (NativeInteger i ir, NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T (i,a))
splitFractionToInt x = MultiValue.liftM SoV.splitFractionToInt x