{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module LLVM.Extra.Multi.Vector.Memory where

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Vector.Instance as Inst
import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import LLVM.Extra.MemoryPrivate (decomposeFromLoad, composeFromStore, )

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal ((:*:), )

import Foreign.Ptr (Ptr, )

import Control.Applicative (liftA2, liftA3, )

import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)


-- | Memory representation of a multi-vector with @n@ lanes of element
-- type @a@.
--
-- The associated type 'Struct' is the LLVM type used to hold such a
-- vector in memory.  An instance must define at least one of 'load' or
-- 'decompose', and at least one of 'store' or 'compose'; the other method
-- of each pair falls back to a default expressed in terms of its partner.
class
   (TypeNum.Positive n, MultiVector.C a, LLVM.IsSized (Struct n a)) =>
      C n a where
   {-# MINIMAL (load|decompose), (store|compose) #-}
   -- LLVM type of the in-memory representation.
   type Struct n a :: *

   -- | Read a multi-vector through a pointer to its memory representation.
   -- The default loads the whole 'Struct' value and then 'decompose's it.
   load :: Value (Ptr (Struct n a)) -> CodeGenFunction r (MultiVector.T n a)
   load ptr = LLVM.load ptr >>= decompose

   -- | Write a multi-vector through a pointer to its memory representation.
   -- The default 'compose's the value first and then stores it.
   store :: MultiVector.T n a -> Value (Ptr (Struct n a)) -> CodeGenFunction r ()
   store vec ptr = compose vec >>= \s -> LLVM.store s ptr

   -- | Turn an LLVM value of the memory type 'Struct' into a multi-vector.
   -- The default is built from 'load' via 'decomposeFromLoad'.
   decompose :: Value (Struct n a) -> CodeGenFunction r (MultiVector.T n a)
   decompose = decomposeFromLoad load

   -- | Turn a multi-vector into an LLVM value of the memory type 'Struct'.
   -- The default is built from 'store' via 'composeFromStore'.
   compose :: MultiVector.T n a -> CodeGenFunction r (Value (Struct n a))
   compose = composeFromStore store

-- | 'Word8' lanes are kept as the primitive @'LLVM.Vector' n 'Word8'@, so
-- conversion only wraps/unwraps with 'MultiVector.consPrim' and
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D8@ constraint presumably asserts a positive
-- total bit width for the vector — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D8)) =>
      C n Word8 where
   type Struct n Word8 = LLVM.Vector n Word8
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Word16' lanes are kept as the primitive @'LLVM.Vector' n 'Word16'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D16@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D16)) =>
      C n Word16 where
   type Struct n Word16 = LLVM.Vector n Word16
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Word32' lanes are kept as the primitive @'LLVM.Vector' n 'Word32'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D32@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D32)) =>
      C n Word32 where
   type Struct n Word32 = LLVM.Vector n Word32
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Word64' lanes are kept as the primitive @'LLVM.Vector' n 'Word64'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D64@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D64)) =>
      C n Word64 where
   type Struct n Word64 = LLVM.Vector n Word64
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Int8' lanes are kept as the primitive @'LLVM.Vector' n 'Int8'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D8@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D8)) =>
      C n Int8 where
   type Struct n Int8 = LLVM.Vector n Int8
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Int16' lanes are kept as the primitive @'LLVM.Vector' n 'Int16'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D16@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D16)) =>
      C n Int16 where
   type Struct n Int16 = LLVM.Vector n Int16
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Int32' lanes are kept as the primitive @'LLVM.Vector' n 'Int32'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D32@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D32)) =>
      C n Int32 where
   type Struct n Int32 = LLVM.Vector n Int32
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Int64' lanes are kept as the primitive @'LLVM.Vector' n 'Int64'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D64@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D64)) =>
      C n Int64 where
   type Struct n Int64 = LLVM.Vector n Int64
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Float' lanes are kept as the primitive @'LLVM.Vector' n 'Float'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D32@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D32)) =>
      C n Float where
   type Struct n Float = LLVM.Vector n Float
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | 'Double' lanes are kept as the primitive @'LLVM.Vector' n 'Double'@;
-- every method just wraps/unwraps with 'MultiVector.consPrim' /
-- 'MultiVector.deconsPrim'.
-- NOTE(review): the @n :*: D64@ constraint presumably asserts a positive
-- total bit width — confirm.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D64)) =>
      C n Double where
   type Struct n Double = LLVM.Vector n Double
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store vec ptr = LLVM.store (MultiVector.deconsPrim vec) ptr
   decompose v = return (MultiVector.consPrim v)
   compose vec = return (MultiVector.deconsPrim vec)

-- | A multi-vector of pairs is stored as an LLVM struct whose field 0
-- holds the memory representation of the first components and field 1
-- that of the second components.
-- 'load' and 'store' use the class defaults built on
-- 'decompose' / 'compose'.
instance (C n a, C n b) => C n (a,b) where
   type Struct n (a,b) = (LLVM.Struct (Struct n a, (Struct n b, ())))
   decompose pairStruct = do
      -- fields are read and decoded left to right
      firsts  <- decompose =<< LLVM.extractvalue pairStruct TypeNum.d0
      seconds <- decompose =<< LLVM.extractvalue pairStruct TypeNum.d1
      return (MultiVector.zip firsts seconds)
   compose pairs =
      case MultiVector.unzip pairs of
         (firsts, seconds) -> do
            -- encode both components first, then fill an undef struct
            fieldA <- compose firsts
            fieldB <- compose seconds
            partial <- LLVM.insertvalue (LLVM.value LLVM.undef) fieldA TypeNum.d0
            LLVM.insertvalue partial fieldB TypeNum.d1

-- | A multi-vector of triples is stored as an LLVM struct whose fields
-- 0, 1 and 2 hold the memory representations of the first, second and
-- third components respectively.
-- 'load' and 'store' use the class defaults built on
-- 'decompose' / 'compose'.
instance (C n a, C n b, C n c) => C n (a,b,c) where
   type Struct n (a,b,c) =
         (LLVM.Struct (Struct n a, (Struct n b, (Struct n c, ()))))
   decompose tripleStruct = do
      -- fields are read and decoded in index order
      xs <- decompose =<< LLVM.extractvalue tripleStruct TypeNum.d0
      ys <- decompose =<< LLVM.extractvalue tripleStruct TypeNum.d1
      zs <- decompose =<< LLVM.extractvalue tripleStruct TypeNum.d2
      return (MultiVector.zip3 xs ys zs)
   compose triples =
      case MultiVector.unzip3 triples of
         (xs, ys, zs) -> do
            -- encode all components first, then fill an undef struct
            fieldX <- compose xs
            fieldY <- compose ys
            fieldZ <- compose zs
            withX <- LLVM.insertvalue (LLVM.value LLVM.undef) fieldX TypeNum.d0
            withXY <- LLVM.insertvalue withX fieldY TypeNum.d1
            LLVM.insertvalue withXY fieldZ TypeNum.d2


-- Orphan instance: makes @'LLVM.Vector' n a@ usable through the
-- "LLVM.Extra.Multi.Value.Memory" interface.  Each method delegates to the
-- corresponding method of this module's class 'C', converting between
-- multi-values and multi-vectors with 'Inst.toMultiValue' and
-- 'Inst.fromMultiValue'.  The memory type is the same 'Struct'.
instance (C n a) => MultiMem.C (LLVM.Vector n a) where
   type Struct (LLVM.Vector n a) = Struct n a
   load ptr = fmap Inst.toMultiValue (load ptr)
   store val ptr = store (Inst.fromMultiValue val) ptr
   decompose s = fmap Inst.toMultiValue (decompose s)
   compose val = compose (Inst.fromMultiValue val)