{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.Multi.Class where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class

import qualified LLVM.Core as LLVM

import qualified Types.Data.Num as TypeNum


-- | Type constructors that a generic operation can be specialised to.
-- The instances in this module are 'MultiValue.T' and @MultiVector.T n@.
class C value where
   -- | Type-level size associated with the container:
   -- 'TypeNum.D1' for 'MultiValue.T' and @n@ for @MultiVector.T n@.
   type Size value :: *
   -- | Pick one of two implementations according to the instance:
   -- the first argument is the 'MultiValue.T' variant,
   -- the second one the 'MultiVector.T' variant.
   -- The wrapper @f@ lets both variants share one result shape.
   switch :: f MultiValue.T -> f (MultiVector.T (Size value)) -> f value

-- | For 'MultiValue.T' the size is 'TypeNum.D1'
-- and 'switch' yields its first argument.
instance C MultiValue.T where
   type Size MultiValue.T = TypeNum.D1
   -- keep the first alternative, drop the second one
   switch = const

-- | For @MultiVector.T n@ the size is @n@
-- and 'switch' yields its second argument.
instance TypeNum.PositiveT n => C (MultiVector.T n) where
   type Size (MultiVector.T n) = n
   -- drop the first alternative, keep the second one
   switch _ = id


-- | Wrapper that moves the container type to the last type parameter,
-- so that a plain @value a@ can be passed through 'switch'.
newtype Undef a value =
   Undef {getUndef :: value a}

-- | Generic undefined value: resolves to 'MultiValue.undef' or
-- 'MultiVector.undef' depending on the 'C' instance chosen for @value@.
undef ::
   (C value, Size value ~ n, TypeNum.PositiveT n,
    Class.MakeValueTuple a, MultiVector.C a) =>
   value a
undef =
   -- Both alternatives are wrapped in 'Undef' for 'switch';
   -- the selected one is unwrapped by the pattern match.
   case switch (Undef MultiValue.undef) (Undef MultiVector.undef) of
      Undef result -> result


-- | Wrapper around a binary code-generating operation on @value a@.
-- The container type is the last type parameter,
-- so that the operation can be passed through 'switch'.
newtype Add r a value =
   Add {
      runAdd :: value a -> value a -> LLVM.CodeGenFunction r (value a)
   }

-- | Generic addition: depending on the 'C' instance chosen for @value@,
-- the operands are unwrapped from 'MultiValue.Cons' or 'MultiVector.Cons',
-- added with 'A.add', and the result is wrapped again
-- in the same constructor.
add ::
   (C value,
    A.Additive al, al ~ Class.ValueTuple a,
    A.Additive vl, vl ~ MultiVector.Vector n a, n ~ Size value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
add =
   runAdd
      (switch
         -- 'MultiValue.T' alternative
         (Add (\(MultiValue.Cons lhs) (MultiValue.Cons rhs) ->
            fmap MultiValue.Cons (A.add lhs rhs)))
         -- 'MultiVector.T' alternative
         (Add (\(MultiVector.Cons lhs) (MultiVector.Cons rhs) ->
            fmap MultiVector.Cons (A.add lhs rhs))))