{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{- |
A newtype 'T' around 'LLVM.Vector' together with its code-generation
counterpart 'Value'.

The module provides "LLVM.Extra.Multi.Value" class instances for 'T'
and a set of element and whole-vector operations.
Almost everything here converts to "LLVM.Extra.Multi.Vector"
(or to the plain 'LLVM.Vector' multi-value),
delegates to the corresponding library function and converts back.
-}
module Synthesizer.LLVM.Frame.SerialVector.Code (
   T(Cons), Value,
   size,
   fromOrdinary, toOrdinary,
   fromMultiVector, toMultiVector,
   extract, insert, modify,
   assemble, dissect,
   assemble1, dissect1,
   upsample, subsample, last,
   reverse, shiftUp, shiftUpMultiZero, shiftDown,
   cumulate, iterate, scale,
   ) where

import qualified LLVM.Extra.Multi.Vector.Instance as MultiVectorInst
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value.Vector as MultiValueVec
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Foreign.Storable as Store
import Foreign.Storable (Storable)
import Foreign.Ptr (castPtr)

import Control.Applicative ((<$>))

import qualified Data.NonEmpty as NonEmpty
import Data.Word (Word32)
import Data.Tuple.HT (mapSnd)

import Prelude as P hiding (last, reverse, iterate)


-- | Host-side vector of @n@ elements of type @a@,
-- a newtype around 'LLVM.Vector'.
-- 'Eq' and 'Num' are taken over from the wrapped vector.
newtype T n a = Cons (LLVM.Vector n a)
   deriving (Eq, Num)

-- | The code-generation-time value belonging to 'T'.
type Value n a = MultiValue.T (T n a)


-- | Uses the same representation as "LLVM.Extra.Multi.Vector"
-- and forwards every method to the instance for the plain 'LLVM.Vector'.
instance
   (TypeNum.Positive n, MultiVector.C a) =>
      MultiValue.C (T n a) where
   type Repr (T n a) = MultiVector.Repr n a
   cons (Cons vec) = fromOrdinary (MultiValue.cons vec)
   undef = fromOrdinary MultiValue.undef
   zero = fromOrdinary MultiValue.zero
   phi block x = fromOrdinary <$> MultiValue.phi block (toOrdinary x)
   addPhi block x y =
      MultiValue.addPhi block (toOrdinary x) (toOrdinary y)

-- | Marshals exactly like the wrapped 'LLVM.Vector'.
instance (Marshal.Vector n a) => Marshal.C (T n a) where
   pack (Cons vec) = Marshal.pack vec
   unpack struct = Cons (Marshal.unpack struct)

-- | Same memory layout as the wrapped 'LLVM.Vector':
-- size and alignment are those of the vector,
-- and the pointer is merely cast for reading and writing.
instance
   (TypeNum.Positive n, Storable a) =>
      Storable (T n a) where
   sizeOf (Cons vec) = Store.sizeOf vec
   alignment (Cons vec) = Store.alignment vec
   poke dst (Cons vec) = Store.poke (castPtr dst) vec
   peek src = fmap Cons (Store.peek (castPtr src))

-- | Loads and stores through a bitcast pointer
-- using the instance for the plain 'LLVM.Vector'.
instance
   (TypeNum.Positive n, Storable.Vector a, MultiVector.C a) =>
      Storable.C (T n a) where
   load ptr = fromOrdinary <$> (LLVM.bitcast ptr >>= Storable.load)
   store x ptr = LLVM.bitcast ptr >>= Storable.store (toOrdinary x)


instance
   (TypeNum.Positive n, MultiVector.IntegerConstant a) =>
      MultiValue.IntegerConstant (T n a) where
   fromInteger' k = fromMultiVector (MultiVector.fromInteger' k)

instance
   (TypeNum.Positive n, MultiVector.RationalConstant a) =>
      MultiValue.RationalConstant (T n a) where
   fromRational' q = fromMultiVector (MultiVector.fromRational' q)

instance
   (TypeNum.Positive n, MultiVector.Additive a) =>
      MultiValue.Additive (T n a) where
   add x y = lift2 MultiVector.add x y
   sub x y = lift2 MultiVector.sub x y
   neg x = lift1 MultiVector.neg x

instance
   (TypeNum.Positive n, MultiVector.PseudoRing a) =>
      MultiValue.PseudoRing (T n a) where
   mul x y = lift2 MultiVector.mul x y

-- | Combine a scalar with a vector by means of 'MultiVector.scale'.
scale ::
   (TypeNum.Positive n, MultiVector.PseudoRing a) =>
   MultiValue.T a -> Value n a -> LLVM.CodeGenFunction r (Value n a)
scale factor = lift1 (MultiVector.scale factor)

instance
   (TypeNum.Positive n, MultiVector.Real a) =>
      MultiValue.Real (T n a) where
   min x y = lift2 MultiVector.min x y
   max x y = lift2 MultiVector.max x y
   abs x = lift1 MultiVector.abs x
   signum x = lift1 MultiVector.signum x

instance
   (TypeNum.Positive n, MultiVector.Fraction a) =>
      MultiValue.Fraction (T n a) where
   truncate x = lift1 MultiVector.truncate x
   fraction x = lift1 MultiVector.fraction x

instance
   (TypeNum.Positive n, MultiVector.Field a) =>
      MultiValue.Field (T n a) where
   fdiv x y = lift2 MultiVector.fdiv x y

instance
   (TypeNum.Positive n, MultiVector.Algebraic a) =>
      MultiValue.Algebraic (T n a) where
   sqrt x = lift1 MultiVector.sqrt x

instance
   (TypeNum.Positive n, MultiVector.Transcendental a) =>
      MultiValue.Transcendental (T n a) where
   pi = fromMultiVector <$> MultiVector.pi
   sin x = lift1 MultiVector.sin x
   cos x = lift1 MultiVector.cos x
   exp x = lift1 MultiVector.exp x
   log x = lift1 MultiVector.log x
   pow x y = lift2 MultiVector.pow x y


-- | Relates 'T' to the native LLVM vector type for integer elements.
-- The instance declares no methods of its own.
-- The vector sizes are unified with @n ~ m@ in the context
-- rather than by repeating @n@ in the instance head.
instance
   (TypeNum.Positive n, n ~ m,
    MultiVector.NativeInteger n a ar,
    MultiValue.NativeInteger a ar) =>
      MultiValueVec.NativeInteger (T n a) (LLVM.Vector m ar)

-- | Relates 'T' to the native LLVM vector type for floating point elements.
-- The instance declares no methods of its own.
instance
   (TypeNum.Positive n, n ~ m,
    MultiVector.NativeFloating n a ar,
    MultiValue.NativeFloating a ar) =>
      MultiValueVec.NativeFloating (T n a) (LLVM.Vector m ar)


-- | Turn a unary function on 'MultiVector.T' into one on 'Value'
-- by converting the argument and the functorial result.
lift1 ::
   (Functor f) =>
   (MultiVector.T n a -> f (MultiVector.T m b)) ->
   (Value n a -> f (Value m b))
lift1 op x = fmap fromMultiVector (op (toMultiVector x))

-- | Binary counterpart of 'lift1'.
lift2 ::
   (Functor f) =>
   (MultiVector.T n a -> MultiVector.T m b -> f (MultiVector.T k c)) ->
   (Value n a -> Value m b -> f (Value k c))
lift2 op x y =
   fmap fromMultiVector (op (toMultiVector x) (toMultiVector y))


-- | Read the element at the given index via 'MultiVector.extract'.
extract ::
   (TypeNum.Positive n, MultiVector.C x,
    MultiValue.T x ~ a, Value n x ~ v) =>
   LLVM.Value Word32 -> v -> LLVM.CodeGenFunction r a
extract ix vec = MultiVector.extract ix (toMultiVector vec)

-- | Replace the element at the given index via 'MultiVector.insert'.
insert ::
   (TypeNum.Positive n, MultiVector.C x,
    MultiValue.T x ~ a, Value n x ~ v) =>
   LLVM.Value Word32 -> a -> v -> LLVM.CodeGenFunction r v
insert ix x = lift1 (MultiVector.insert ix x)

-- | Extract the element at the given index, apply the action to it
-- and insert the result at the same index.
modify ::
   (TypeNum.Positive n, MultiVector.C x,
    MultiValue.T x ~ a, Value n x ~ v) =>
   LLVM.Value Word32 ->
   (a -> LLVM.CodeGenFunction r a) ->
   v -> LLVM.CodeGenFunction r v
modify ix update vec = do
   old <- extract ix vec
   new <- update old
   insert ix new vec

-- | Build a vector from a list of elements via 'MultiVector.assemble'.
assemble ::
   (TypeNum.Positive n, MultiVector.C a) =>
   [MultiValue.T a] -> LLVM.CodeGenFunction r (Value n a)
assemble elems = fromMultiVector <$> MultiVector.assemble elems

-- | Split a vector into a list of elements via 'MultiVector.dissect'.
dissect ::
   (TypeNum.Positive n, MultiVector.C a) =>
   Value n a -> LLVM.CodeGenFunction r [MultiValue.T a]
dissect vec = MultiVector.dissect (toMultiVector vec)

-- | Like 'assemble' but for a non-empty list.
assemble1 ::
   (TypeNum.Positive n, MultiVector.C a) =>
   NonEmpty.T [] (MultiValue.T a) -> LLVM.CodeGenFunction r (Value n a)
assemble1 elems = fromMultiVector <$> MultiVector.assemble1 elems

-- | Like 'dissect' but returns a non-empty list.
dissect1 ::
   (TypeNum.Positive n, MultiVector.C a) =>
   Value n a -> LLVM.CodeGenFunction r (NonEmpty.T [] (MultiValue.T a))
dissect1 vec = MultiVector.dissect1 (toMultiVector vec)


-- | Type-level size of the vector as a singleton.
-- The argument only serves to fix the type @n@.
sizeS :: TypeNum.Positive n => Value n a -> TypeNum.Singleton n
sizeS _ = TypeNum.singleton

-- | Number of elements, computed from the type @n@ alone.
size :: (TypeNum.Positive n, P.Integral i) => Value n a -> i
size vec = TypeNum.integralFromSingleton (sizeS vec)


-- | Extract the element at index @size - 1@.
last ::
   (TypeNum.Positive n, MultiVector.C a) =>
   Value n a -> LLVM.CodeGenFunction r (MultiValue.T a)
last vec = extract (LLVM.valueOf finalIndex) vec
   where
      finalIndex :: Word32
      finalIndex = size vec - 1

-- | Extract the element at index zero.
subsample ::
   (TypeNum.Positive n, MultiVector.C a) =>
   Value n a -> LLVM.CodeGenFunction r (MultiValue.T a)
subsample vec = extract (A.zero :: LLVM.Value Word32) vec

-- | Build a vector from a single element
-- by means of 'MultiValueVec.replicate'.
upsample ::
   (TypeNum.Positive n, MultiVector.C a) =>
   MultiValue.T a -> LLVM.CodeGenFunction r (Value n a)
upsample x = fromOrdinary <$> MultiValueVec.replicate x


-- | Apply 'MultiVector.reverse' to the vector.
reverse ::
   (TypeNum.Positive n, MultiVector.C a) =>
   Value n a -> LLVM.CodeGenFunction r (Value n a)
reverse = lift1 MultiVector.reverse

-- | Apply 'MultiVector.shiftUp' to a scalar and a vector.
-- Yields a scalar together with the new vector;
-- see 'MultiVector.shiftUp' for the exact element order.
shiftUp ::
   (TypeNum.Positive n, MultiVector.C x,
    MultiValue.T x ~ a, Value n x ~ v) =>
   a -> v -> LLVM.CodeGenFunction r (a, v)
shiftUp x vec =
   fmap (mapSnd fromMultiVector) (MultiVector.shiftUp x (toMultiVector vec))

-- | Apply 'MultiVector.shiftUpMultiZero' with the given count
-- to the vector.
shiftUpMultiZero ::
   (TypeNum.Positive n, MultiVector.C x, Value n x ~ v) =>
   Int -> v -> LLVM.CodeGenFunction r v
shiftUpMultiZero count = lift1 (MultiVector.shiftUpMultiZero count)

-- | Apply 'MultiVector.shiftDown' to a scalar and a vector.
-- Yields a scalar together with the new vector;
-- see 'MultiVector.shiftDown' for the exact element order.
shiftDown ::
   (TypeNum.Positive n, MultiVector.C x,
    MultiValue.T x ~ a, Value n x ~ v) =>
   a -> v -> LLVM.CodeGenFunction r (a, v)
shiftDown x vec =
   fmap (mapSnd fromMultiVector) (MultiVector.shiftDown x (toMultiVector vec))


-- | Build a vector from a step action and a start value
-- by means of 'MultiValueVec.iterate'.
iterate ::
   (TypeNum.Positive n, MultiVector.C a) =>
   (MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T a)) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (Value n a)
iterate step start = fromOrdinary <$> MultiValueVec.iterate step start

-- | Apply 'MultiVector.cumulate' to a scalar and a vector.
-- Yields a scalar together with a vector.
cumulate ::
   (TypeNum.Positive n, MultiVector.Additive a) =>
   MultiValue.T a -> Value n a ->
   LLVM.CodeGenFunction r (MultiValue.T a, Value n a)
cumulate start vec =
   fmap (mapSnd fromMultiVector) (MultiVector.cumulate start (toMultiVector vec))


-- | Reinterpret a multi-value of a plain 'LLVM.Vector' as a 'Value'
-- using 'MultiValue.cast'.
fromOrdinary :: MultiValue.T (LLVM.Vector n a) -> Value n a
fromOrdinary vec = MultiValue.cast vec

-- | Inverse direction of 'fromOrdinary', also a 'MultiValue.cast'.
toOrdinary :: Value n a -> MultiValue.T (LLVM.Vector n a)
toOrdinary vec = MultiValue.cast vec

-- | Convert from 'MultiVector.T' by way of the plain vector multi-value.
fromMultiVector :: MultiVector.T n a -> Value n a
fromMultiVector vec = fromOrdinary (MultiVectorInst.toMultiValue vec)

-- | Convert to 'MultiVector.T' by way of the plain vector multi-value.
toMultiVector :: Value n a -> MultiVector.T n a
toMultiVector vec = MultiVectorInst.fromMultiValue (toOrdinary vec)