{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module LLVM.Extra.Memory (
   C(load, store, decompose, compose), modify,
   Struct,
   Record, Element, element,
   loadRecord, storeRecord, decomposeRecord, composeRecord,
   loadNewtype, storeNewtype, decomposeNewtype, composeNewtype,
   ) where

import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Private as MultiValue
import qualified LLVM.Extra.Scalar as Scalar
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Struct as Struct
import qualified LLVM.Extra.Either as Either
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Core as LLVM
import LLVM.Core
   (getElementPtr0,
    extractvalue, insertvalue,
    Value, -- valueOf, Vector,
    IsType, IsSized,
    CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Decimal (d0, d1, d2, d3)
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.FixedLength as FixedLength
import qualified Data.Complex as Complex
import Data.Complex (Complex((:+)))
import Data.Tuple.HT (fst3, snd3, thd3, )
import Data.Word (Word)

import qualified Control.Applicative.HT as App
import Control.Monad (ap, (<=<))
import Control.Applicative (Applicative, pure, liftA2, liftA3, (<*>))

import Prelude2010 hiding (maybe, either, )
import Prelude ()


{- |
An implementation of both 'Tuple.Value' and 'Memory.C'
must ensure that @haskellValue@ is compatible
with @Stored (Struct haskellValue)@ (which we want to call @llvmStruct@).
That is, writing and reading @llvmStruct@ by LLVM
must be the same as accessing @haskellValue@ by 'Storable' methods.
ToDo: In future we may also require Storable constraint for @llvmStruct@.

We use an associated type family ('Struct') in order to let type inference work nicely.
-}
class
   (Tuple.Phi llvmValue, Tuple.Undefined llvmValue,
    IsType (Struct llvmValue), IsSized (Struct llvmValue)) =>
      C llvmValue where
   {- |
   LLVM type through which an @llvmValue@ is loaded, stored,
   decomposed and composed.
   -}
   type Struct llvmValue
   {- |
   Read an @llvmValue@ through a pointer to its 'Struct'.
   By default the whole struct is fetched with 'LLVM.load'
   and then taken apart with 'decompose'.
   -}
   load :: Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r llvmValue
   load ptr  =  LLVM.load ptr >>= decompose
   {- |
   Write an @llvmValue@ through a pointer to its 'Struct'.
   By default the value is assembled with 'compose'
   and the resulting struct is written with 'LLVM.store'.
   -}
   store :: llvmValue -> Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r ()
   store r ptr  =  compose r >>= \struct -> LLVM.store struct ptr
   {- |
   Take a struct value apart. Conceptually this is the same as

   > decompose struct = do
   >   ptr <- LLVM.alloca
   >   LLVM.store struct ptr
   >   Memory.load ptr

   except that 'LLVM.alloca' will blast your stack when used in a loop.
   -}
   decompose :: Value (Struct llvmValue) -> CodeGenFunction r llvmValue
   {- |
   Assemble a struct value. Conceptually this is the same as

   > compose value = do
   >   ptr <- LLVM.alloca
   >   Memory.store value ptr
   >   LLVM.load ptr

   except that 'LLVM.alloca' will blast your stack when used in a loop.
   -}
   compose :: llvmValue -> CodeGenFunction r (Value (Struct llvmValue))

{- |
Load the value behind @ptr@, run @f@ on it
and store the result back through the same pointer.
-}
modify ::
   (C llvmValue) =>
   (llvmValue -> CodeGenFunction r llvmValue) ->
   Value (LLVM.Ptr (Struct llvmValue)) -> CodeGenFunction r ()
modify f ptr =
   load ptr >>= f >>= \updated -> store updated ptr


{- |
The unit value maps to an empty LLVM struct.
Loading and storing emit no LLVM instructions.
-}
instance C () where
   type Struct () = LLVM.Struct ()
   load _ptr = return ()
   store _unit _ptr = return ()
   decompose _struct = return ()
   compose _unit = return $ LLVM.value (LLVM.constStruct ())


{- |
An 'Element' whose result type @x@ coincides with the record type @v@,
i.e. an accessor set that reads and writes a complete @v@.
-}
type Record r o v = Element r o v v

{- |
Bundle of four accessors for (a part of) an LLVM object of type @o@.
@r@ is the result type parameter of the surrounding 'CodeGenFunction',
@v@ is the Haskell-side value that supplies the data to be written
and @x@ is the value obtained by reading.
-}
data Element r o v x =
   Element {
      -- | read an @x@ through a pointer to the object
      loadElement :: Value (LLVM.Ptr o) -> CodeGenFunction r x,
      -- | write the data taken from @v@ through a pointer to the object
      storeElement :: Value (LLVM.Ptr o) -> v -> CodeGenFunction r (),
      -- | read an @x@ from an object value
      extractElement :: Value o -> CodeGenFunction r x,
      -- | return an object value updated with the data taken from @v@
      insertElement :: v -> Value o -> CodeGenFunction r (Value o)
         -- State.Monoid
   }

{- |
Accessors for the single member at index @n@ of the object @o@.
@field@ selects from the Haskell-side record
the part that is written to that member.
-}
element ::
   (C x, IsType o,
    LLVM.GetValue o n, LLVM.ValueType o n ~ Struct x,
    LLVM.GetElementPtr o (n, ()), LLVM.ElementPtrType o (n, ()) ~ Struct x) =>
   (v -> x) -> n -> Element r o v x
element field n =
   Element {
      -- address the member first, then use the member's own 'load'/'store'
      loadElement = \objPtr -> getElementPtr0 objPtr (n, ()) >>= load,
      storeElement = \objPtr rec -> getElementPtr0 objPtr (n, ()) >>= store (field rec),
      -- value-level access via extractvalue/insertvalue
      extractElement = \obj -> extractvalue obj n >>= decompose,
      insertElement = \rec obj -> compose (field rec) >>= \part -> insertvalue obj part n
   }

{- |
Mapping only affects the reading accessors;
the writing accessors are passed on unchanged.
-}
instance Functor (Element r o v) where
   fmap f m =
      Element
         (\ptr -> fmap f (loadElement m ptr))
         (storeElement m)
         (\obj -> fmap f (extractElement m obj))
         (insertElement m)

{- |
'pure' accesses nothing in the object.
'<*>' runs the accessors of the left operand first,
then those of the right operand, on the same object.
-}
instance Applicative (Element r o v) where
   pure x =
      Element
         (\ _ptr -> return x)
         (\ _ptr _v -> return ())
         (\ _o -> return x)
         (\ _v o -> return o)
   ef <*> ex =
      Element loadBoth storeBoth extractBoth insertBoth
      where
         loadBoth ptr = do
            g <- loadElement ef ptr
            a <- loadElement ex ptr
            return (g a)
         storeBoth ptr y = do
            storeElement ef ptr y
            storeElement ex ptr y
         extractBoth o = do
            g <- extractElement ef o
            a <- extractElement ex o
            return (g a)
         -- thread the object through both insertions
         insertBoth y o = do
            o1 <- insertElement ef y o
            insertElement ex y o1


{- |
Read a complete record through a pointer,
using the load accessor of the given 'Record'.
-}
loadRecord ::
   Record r o llvmValue ->
   Value (LLVM.Ptr o) -> CodeGenFunction r llvmValue
loadRecord record = loadElement record

{- |
Write a complete record through a pointer.
Same as 'storeElement' but with value and pointer swapped,
which matches the argument order of 'store'.
-}
storeRecord ::
   Record r o llvmValue ->
   llvmValue -> Value (LLVM.Ptr o) -> CodeGenFunction r ()
storeRecord record = flip (storeElement record)

{- |
Take an object value apart into a record,
using the extract accessor of the given 'Record'.
-}
decomposeRecord ::
   Record r o llvmValue ->
   Value o -> CodeGenFunction r llvmValue
decomposeRecord = extractElement

{- |
Build an object value from a record
by inserting the record's data into an undefined object.
-}
composeRecord ::
   (IsType o) =>
   Record r o llvmValue ->
   llvmValue -> CodeGenFunction r (Value o)
composeRecord record val =
   let blank = LLVM.value LLVM.undef
   in  insertElement record val blank



{- |
Record layout of a pair:
first component at index 0, second component at index 1.
-}
pair ::
   (C a, C b) =>
   Record r (LLVM.Struct (Struct a, (Struct b, ()))) (a, b)
pair =
   fmap (,) (element fst d0)
      <*> element snd d1

-- | Pairs are stored as a two-member LLVM struct described by 'pair'.
instance (C a, C b) => C (a, b) where
   type Struct (a, b) = LLVM.Struct (Struct a, (Struct b, ()))
   load ptr = loadRecord pair ptr
   store val ptr = storeRecord pair val ptr
   decompose struct = decomposeRecord pair struct
   compose val = composeRecord pair val


{- |
Record layout of a triple:
the components occupy the indices 0, 1 and 2 in order.
-}
triple ::
   (C a, C b, C c) =>
   Record r (LLVM.Struct (Struct a, (Struct b, (Struct c, ())))) (a, b, c)
triple =
   fmap (,,) (element fst3 d0)
      <*> element snd3 d1
      <*> element thd3 d2

-- | Triples are stored as a three-member LLVM struct described by 'triple'.
instance (C a, C b, C c) => C (a, b, c) where
   type Struct (a, b, c) =
           LLVM.Struct (Struct a, (Struct b, (Struct c, ())))
   load ptr = loadRecord triple ptr
   store val ptr = storeRecord triple val ptr
   decompose struct = decomposeRecord triple struct
   compose val = composeRecord triple val


{- |
Record layout of a quadruple:
the components occupy the indices 0 to 3 in order.
-}
quadruple ::
   (C a, C b, C c, C d) =>
   Record r
      (LLVM.Struct (Struct a, (Struct b, (Struct c, (Struct d, ())))))
      (a, b, c, d)
quadruple =
   fmap (,,,) (element sel0 d0)
      <*> element sel1 d1
      <*> element sel2 d2
      <*> element sel3 d3
   where
      -- component selectors for 4-tuples
      sel0 (x,_,_,_) = x
      sel1 (_,x,_,_) = x
      sel2 (_,_,x,_) = x
      sel3 (_,_,_,x) = x

-- | Quadruples are stored as a four-member LLVM struct described by 'quadruple'.
instance (C a, C b, C c, C d) => C (a, b, c, d) where
   type Struct (a, b, c, d) =
           LLVM.Struct (Struct a, (Struct b, (Struct c, (Struct d, ()))))
   load ptr = loadRecord quadruple ptr
   store val ptr = storeRecord quadruple val ptr
   decompose struct = decomposeRecord quadruple struct
   compose val = composeRecord quadruple val


{- |
Record layout of a complex number:
real part at index 0, imaginary part at index 1.
-}
complex ::
   (C a) =>
   Record r (LLVM.Struct (Struct a, (Struct a, ()))) (Complex a)
complex =
   fmap (:+) (element Complex.realPart d0)
      <*> element Complex.imagPart d1

-- | Complex numbers are stored as a two-member LLVM struct described by 'complex'.
instance (C a) => C (Complex a) where
   type Struct (Complex a) = LLVM.Struct (Struct a, (Struct a, ()))
   load ptr = loadRecord complex ptr
   store val ptr = storeRecord complex val ptr
   decompose struct = decomposeRecord complex struct
   compose val = composeRecord complex val


{- |
A fixed-length collection is represented by an LLVM array
whose elements have the 'Struct' type of the collection's elements.
'load' and 'store' are not overridden here,
so the class defaults based on 'decompose' and 'compose' apply.
-}
instance
   (Unary.Natural n, C a,
    TypeNum.Natural (TypeNum.FromUnary n),
    TypeNum.Natural (TypeNum.FromUnary n TypeNum.:*: LLVM.SizeOf (Struct a)),
    LLVM.IsFirstClass (Struct a)) =>
      C (FixedLength.T n a) where
   type Struct (FixedLength.T n a) =
            LLVM.Array (TypeNum.FromUnary n) (Struct a)
   -- pair each element with its position and fill an undefined array slot by slot
   compose xs =
      Fold.foldlM
         (\arr (x,k) -> do
            slot <- compose x
            LLVM.insertvalue arr slot k)
         (LLVM.value LLVM.undef)
         (FixedLength.zipWith (,) xs (iterateTrav (1+) (0::Word)))
   -- enumerate the positions and read the array slot at each of them
   decompose arr =
      Trav.mapM
         (\k -> LLVM.extractvalue arr k >>= decompose)
         (iterateTrav (1+) (0::Word))

{- |
Fill the container shape given by @pure ()@
with @a0@, @f a0@, @f (f a0)@, ... in traversal order.
-}
iterateTrav :: (Applicative t, Trav.Traversable t) => (a -> a) -> a -> t a
iterateTrav f a0 =
   snd (Trav.mapAccumL step a0 (pure ()))
   where
      -- emit the current value and advance the accumulator
      step a () = (f a, a)


{- |
Record layout of an optional value:
the 'Maybe.isJust' flag at index 0, the 'Maybe.fromJust' payload at index 1.
-}
maybe ::
   (C a) =>
   Record r (LLVM.Struct (Bool, (Struct a, ()))) (Maybe.T a)
maybe =
   fmap Maybe.Cons (element Maybe.isJust d0)
      <*> element Maybe.fromJust d1

-- | Optional values are stored as a flag-plus-payload struct described by 'maybe'.
instance (C a) => C (Maybe.T a) where
   type Struct (Maybe.T a) = LLVM.Struct (Bool, (Struct a, ()))
   load ptr = loadRecord maybe ptr
   store val ptr = storeRecord maybe val ptr
   decompose struct = decomposeRecord maybe struct
   compose val = composeRecord maybe val


{- |
Record layout of a sum value:
the 'Either.isRight' flag at index 0,
the 'Either.fromLeft' payload at index 1
and the 'Either.fromRight' payload at index 2.
-}
either ::
   (C a, C b) =>
   Record r (LLVM.Struct (Bool, (Struct a, (Struct b, ())))) (Either.T a b)
either =
   fmap Either.Cons (element Either.isRight d0)
      <*> element Either.fromLeft d1
      <*> element Either.fromRight d2

-- | Sum values are stored as a flag-plus-two-payloads struct described by 'either'.
instance (C a, C b) => C (Either.T a b) where
   type Struct (Either.T a b) = LLVM.Struct (Bool, (Struct a, (Struct b, ())))
   load ptr = loadRecord either ptr
   store val ptr = storeRecord either val ptr
   decompose struct = decomposeRecord either struct
   compose val = composeRecord either val



{- |
A 'Scalar.T' uses the memory representation of the value it wraps;
every method wraps or unwraps and delegates to the instance for @a@.
-}
instance (C a) => C (Scalar.T a) where
   type Struct (Scalar.T a) = Struct a
   load ptr = fmap Scalar.Cons (load ptr)
   store s ptr = store (Scalar.decons s) ptr
   decompose struct = fmap Scalar.Cons (decompose struct)
   compose s = compose (Scalar.decons s)


{- |
A plain LLVM value is its own memory representation:
'load' and 'store' are the LLVM primitives,
'decompose' and 'compose' return their argument unchanged.
-}
instance (IsSized a) => C (Value a) where
   type Struct (Value a) = a
   load ptr = LLVM.load ptr
   store val ptr = LLVM.store val ptr
   decompose val = return val
   compose val = return val


{- |
Apply 'Struct' to every member of a nested-pair list of types:
@(a, as)@ becomes @(Struct a, StructStruct as)@ and @()@ stays @()@.
-}
type family StructStruct s
type instance StructStruct (a,as) = (Struct a, StructStruct as)
type instance StructStruct () = ()

{- |
A 'Struct.T' is represented by an LLVM struct
whose member types are given by 'StructStruct'.
The members are converted one by one via 'ConvertStruct', starting at index 0.
'load' and 'store' are not overridden here, so the class defaults apply.
-}
instance
   (Struct.Phi s, Struct.Undefined s,
    LLVM.StructFields (StructStruct s),
    ConvertStruct (StructStruct s) TypeNum.D0 s) =>
      C (Struct.T s) where
   type Struct (Struct.T s) = LLVM.Struct (StructStruct s)
   decompose struct = fmap Struct.Cons (decomposeFields d0 struct)
   compose (Struct.Cons fields) = composeFields d0 fields

{- |
Conversion between an LLVM struct value with member types @s@
and a nested-pair list @rem@ of values
that covers the struct members from index @i@ onwards.
-}
class ConvertStruct s i rem where
   -- | read the members from index @i@ onwards out of the struct value
   decomposeFields ::
      Proxy i -> Value (LLVM.Struct s) -> CodeGenFunction r rem
   -- | build a struct value whose members from index @i@ onwards are set from @rem@
   composeFields ::
      Proxy i -> rem -> CodeGenFunction r (Value (LLVM.Struct s))

{- |
Inductive case: handle the member at index @i@
and recurse for the remaining members starting at the successor index.
-}
instance
   (TypeNum.Natural i, LLVM.GetField s i, LLVM.FieldType s i ~ Struct a, C a,
    ConvertStruct s (TypeNum.Succ i) rem) =>
      ConvertStruct s i (a,rem) where
   decomposeFields i sm =
      liftA2 (,)
         (LLVM.extractvalue sm i >>= decompose)
         (decomposeFields (decSucc i) sm)
   -- the members after @i@ are composed first, then member @i@ is inserted
   composeFields i (a,as) =
      composeFields (decSucc i) as >>= \rest ->
      compose a >>= \member ->
      LLVM.insertvalue rest member i

-- | Turn the proxy of a type-level number into the proxy of its successor.
decSucc :: Proxy n -> Proxy (TypeNum.Succ n)
decSucc n =
   case n of
      Proxy -> Proxy

{- |
Base case: no members are left.
Composition starts from an undefined struct value.
-}
instance (LLVM.StructFields s) => ConvertStruct s i () where
   decomposeFields _index _struct = return ()
   composeFields _index _unit = return $ LLVM.value LLVM.undef



{- |
A 'MultiValue.T' uses the memory representation of its 'MultiValue.Repr'.

The 'IsType' and 'IsSized' constraints are redundant
but required for this loopy instance.
-}
instance
   (IsType (Struct (MultiValue.Repr a)),
    IsSized (Struct (MultiValue.Repr a)),
    MultiValue.C a, C (MultiValue.Repr a)) =>
      C (MultiValue.T a) where
   type Struct (MultiValue.T a) = Struct (MultiValue.Repr a)
   load ptr = fmap MultiValue.Cons (load ptr)
   store (MultiValue.Cons repr) = store repr
   decompose struct = fmap MultiValue.Cons (decompose struct)
   compose (MultiValue.Cons repr) = compose repr

{- |
A 'MultiVector.T' uses the memory representation of its 'MultiVector.Repr';
every method wraps or unwraps 'MultiVector.Cons'
and delegates to the instance for the representation.
-}
instance
   (IsType (Struct (MultiVector.Repr n a)),
    IsSized (Struct (MultiVector.Repr n a)),
    TypeNum.Positive n, MultiVector.C a, C (MultiVector.Repr n a)) =>
      C (MultiVector.T n a) where
   type Struct (MultiVector.T n a) = Struct (MultiVector.Repr n a)
   load ptr = fmap MultiVector.Cons (load ptr)
   store (MultiVector.Cons repr) = store repr
   decompose struct = fmap MultiVector.Cons (decompose struct)
   compose (MultiVector.Cons repr) = compose repr



{- |
'load' for a wrapper type:
load the wrapped value and apply the wrapping function to it.
-}
loadNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (LLVM.Ptr (Struct a)) -> CodeGenFunction r llvmValue
loadNewtype wrap =
   fmap wrap . load

{- |
'store' for a wrapper type:
unwrap the value and store the wrapped content.
-}
storeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> Value (LLVM.Ptr (Struct a)) -> CodeGenFunction r ()
storeNewtype unwrap y ptr =
   let content = unwrap y
   in  store content ptr

{- |
'decompose' for a wrapper type:
decompose to the wrapped value and apply the wrapping function to it.
-}
decomposeNewtype ::
   (C a) =>
   (a -> llvmValue) ->
   Value (Struct a) -> CodeGenFunction r llvmValue
decomposeNewtype wrap =
   fmap wrap . decompose

{- |
'compose' for a wrapper type:
unwrap the value and compose the wrapped content.
-}
composeNewtype ::
   (C a) =>
   (llvmValue -> a) ->
   llvmValue -> CodeGenFunction r (Value (Struct a))
composeNewtype unwrap =
   compose . unwrap