module LLVM.Core.Vector (MkVector(..), vector, cyclicVector, ) where
import qualified LLVM.ExecutionEngine.Target as Target
import qualified LLVM.Core.UnaryVector as UnaryVector
import qualified LLVM.Util.Proxy as Proxy
import LLVM.Core.Type (IsPrimitive, unsafeTypeRef)
import LLVM.Core.Data (Vector(Vector), FixedList)
import qualified Type.Data.Num.Decimal.Proof as DecProof
import qualified Type.Data.Num.Decimal.Number as Dec
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Decimal.Literal (D2, D4, D8)
import qualified Foreign.Storable.Traversable as Store
import Foreign.Storable (Storable(..))
import Control.Applicative (Applicative, pure, liftA2, (<*>))
import Control.Functor.HT (unzip)
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.NonEmpty as NonEmpty
import qualified Data.Empty as Empty
import Data.Traversable (Traversable, foldMapDefault)
import Data.Foldable (Foldable, foldMap)
import Data.NonEmpty ((!:))
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (replicate, map, head, unzip, zipWith, uncurry)
-- | Conversion between LLVM 'Vector's of statically known size @n@ and
-- ordinary homogeneous Haskell tuples of the same length.
--
-- Each instance fixes a concrete size (see the 'D2', 'D4' and 'D8'
-- instances below) and names the matching tuple type through 'Tuple'.
class (Dec.Positive n, IsPrimitive a) => MkVector n a where
   -- | The tuple type with @n@ components, each of type @a@.
   type Tuple n a :: *

   -- | Pack a tuple into a vector, keeping the component order.
   toVector :: Tuple n a -> Vector n a

   -- | Unpack a vector into a tuple, keeping the element order.
   fromVector :: Vector n a -> Tuple n a
-- | Two-element vectors correspond to pairs.
instance (IsPrimitive a) => MkVector D2 a where
   type Tuple D2 a = (a,a)

   toVector (x0, x1) =
      vector (x0 !: x1 !: Empty.Cons)

   -- The pair constructor is exactly the curried 2-ary function needed.
   fromVector = uncurry (,)
-- | Four-element vectors correspond to 4-tuples.
instance (IsPrimitive a) => MkVector D4 a where
   type Tuple D4 a = (a,a,a,a)

   toVector (x0, x1, x2, x3) =
      vector (x0 !: x1 !: x2 !: x3 !: Empty.Cons)

   -- The 4-tuple constructor is exactly the curried 4-ary function needed.
   fromVector = uncurry (,,,)
-- | Eight-element vectors correspond to 8-tuples.
instance (IsPrimitive a) => MkVector D8 a where
   type Tuple D8 a = (a,a,a,a,a,a,a,a)

   toVector (x0, x1, x2, x3, x4, x5, x6, x7) =
      vector
         (x0 !: x1 !: x2 !: x3 !: x4 !: x5 !: x6 !: x7 !: Empty.Cons)

   -- The 8-tuple constructor is exactly the curried 8-ary function needed.
   fromVector = uncurry (,,,,,,,)
-- | The first element of a vector.
--
-- The 'Dec.Positive' constraint supplies a 'DecProof.UnaryPos' proof.
-- Matching on that proof is what permits 'UnaryVector.head' on the
-- unary-indexed view of the vector.
head :: (Dec.Positive n) => Vector n a -> a
head =
   withPosDict1 $ \DecProof.UnaryPos vec ->
      UnaryVector.head (unaryFromDecimalVector vec)
-- | View a decimal-indexed 'Vector' as a unary-indexed 'UnaryVector.T'
-- by rewrapping its underlying fixed-length list.
unaryFromDecimalVector :: Vector n a -> UnaryVector.T (Dec.ToUnary n) a
unaryFromDecimalVector vec =
   case vec of
      Vector elems -> UnaryVector.fromFixedList elems
-- | Inverse of 'unaryFromDecimalVector'. It extracts the fixed-length
-- list of a unary-indexed vector and wraps it as a decimal-indexed 'Vector'.
decimalFromUnaryVector :: UnaryVector.T (Dec.ToUnary n) a -> Vector n a
decimalFromUnaryVector uvec = Vector (UnaryVector.toFixedList uvec)
-- | Curried function type whose arity is given as a type-level decimal @n@.
-- It takes that many arguments of type @e@ and yields @r@.
-- It is the decimal-indexed counterpart of 'UnaryVector.Curried'.
type Curried n e r = UnaryVector.Curried (Dec.ToUnary n) e r
-- | Feed the elements of a vector, in order, to a curried function with
-- one parameter per element.
uncurry ::
   (Dec.Natural n) =>
   Curried n a b -> Vector n a -> b
uncurry fun =
   withNatDict1 $ \DecProof.UnaryNat vec ->
      UnaryVector.uncurry fun (unaryFromDecimalVector vec)
-- | Hand the 'DecProof.UnaryNat' proof for @n@ to a vector-producing
-- continuation. The proof is obtained from the 'Dec.Natural' constraint.
withNatDict ::
   (Dec.Natural n) =>
   (DecProof.UnaryNat n -> Vector n a) -> Vector n a
withNatDict = ($ DecProof.unaryNat)
-- | Like 'withNatDict', but the continuation also consumes a vector and
-- may return any type.
withNatDict1 ::
   (Dec.Natural n) =>
   (DecProof.UnaryNat n -> Vector n a -> b) -> Vector n a -> b
withNatDict1 = ($ DecProof.unaryNat)
-- | Like 'withNatDict1', but hands over the 'DecProof.UnaryPos' proof
-- obtained from a 'Dec.Positive' constraint.
withPosDict1 ::
   (Dec.Positive n) =>
   (DecProof.UnaryPos n -> Vector n a -> b) -> Vector n a -> b
withPosDict1 = ($ DecProof.unaryPos)
-- | Build a decimal-indexed vector from a unary-indexed one that is
-- polymorphic in its unary size @m@.
--
-- Inside the argument, @m@ equals @Dec.ToUnary n@ and is a
-- 'Unary.Natural'. Matching the 'DecProof.UnaryNat' proof discharges
-- those constraints.
withUnaryDecVector ::
   (Dec.Natural n) =>
   (forall m. (Dec.ToUnary n ~ m, Unary.Natural m) => UnaryVector.T m a) ->
   Vector n a
withUnaryDecVector uvec =
   withNatDict $ \DecProof.UnaryNat ->
      decimalFromUnaryVector uvec
-- | Memory representation of a 'Vector'.
--
-- Size and alignment are looked up for the LLVM vector type in
-- 'ourTargetData'. Reading and writing go through
-- "Foreign.Storable.Traversable", using the element type's 'Storable'
-- instance.
-- NOTE(review): this assumes the element-wise layout matches LLVM's vector
-- layout, and that the store size (rather than the ABI/alloc size) is the
-- intended array stride. Confirm for non-power-of-two @n@ and sub-byte
-- elements.
instance (Storable a, Dec.Positive n, IsPrimitive a) => Storable (Vector n a) where
   sizeOf =
      Target.storeSizeOfType ourTargetData . unsafeTypeRef . Proxy.fromValue
   alignment =
      Target.abiAlignmentOfType ourTargetData . unsafeTypeRef . Proxy.fromValue
   peek = Store.peekApplicative
   poke = Store.poke
-- | Target data layout used by the 'Storable' instance of 'Vector'.
--
-- It is obtained once through 'unsafePerformIO'. That is only sound on the
-- assumption that 'Target.getTargetData' returns an equivalent value on
-- every call. NOTE(review): confirm it has no observable side effects.
--
-- The NOINLINE pragma is required for a top-level 'unsafePerformIO' CAF.
-- Without it GHC may inline the definition at each use site and run the
-- IO action repeatedly.
{-# NOINLINE ourTargetData #-}
ourTargetData :: Target.TargetData
ourTargetData = unsafePerformIO Target.getTargetData
-- | Construct a vector from a fixed-length list. The list length
-- (a unary type-level number) must equal the decimal size @n@.
vector ::
   (Dec.Positive n) =>
   FixedList (Dec.ToUnary n) a -> Vector n a
vector elems = Vector elems
-- | Fill a vector of size @n@ from a non-empty list, using
-- 'UnaryVector.cyclicVector' on the unary-indexed representation.
--
-- The unary vector is passed directly (not via composition) because
-- 'withUnaryDecVector' takes a rank-2 argument.
cyclicVector :: (Dec.Positive n) => NonEmpty.T [] a -> Vector n a
cyclicVector elems =
   withUnaryDecVector (UnaryVector.cyclicVector elems)
-- | A vector with every element equal to the given value.
replicate :: (Dec.Positive n) => a -> Vector n a
replicate x =
   withUnaryDecVector (pure x)
-- | Element-wise mapping, delegated to the unary-indexed representation.
instance (Dec.Positive n) => Functor (Vector n) where
   fmap f vec =
      withUnaryDecVector (fmap f (unaryFromDecimalVector vec))
-- | Zip-like applicative. 'pure' broadcasts a value to every element, and
-- '<*>' applies functions to arguments position by position, via the
-- unary-indexed representation.
instance (Dec.Positive n) => Applicative (Vector n) where
   pure x = replicate x
   funs <*> args =
      withUnaryDecVector
         (unaryFromDecimalVector funs <*> unaryFromDecimalVector args)
-- | Folding is derived from the 'Traversable' instance.
instance (Dec.Positive n) => Foldable (Vector n) where
   foldMap f = Trav.foldMapDefault f
-- | Sequencing runs on the unary-indexed representation and converts the
-- result back to a decimal-indexed vector.
instance (Dec.Positive n) => Traversable (Vector n) where
   sequenceA =
      withNatDict1 $ \DecProof.UnaryNat vec ->
         fmap decimalFromUnaryVector
            (Trav.sequenceA (unaryFromDecimalVector vec))
-- | Two vectors are equal when all corresponding elements are equal.
instance (Eq a, Dec.Positive n) => Eq (Vector n a) where
   xs == ys =
      let pairwise = liftA2 (==) xs ys
      in  Fold.and pairwise
-- | Lexicographic ordering: the first element pair (in 'Foldable' order)
-- that is not 'EQ' decides the result.
instance (Ord a, Dec.Positive n) => Ord (Vector n a) where
   compare xs ys =
      -- 'mappend' on 'Ordering' keeps its left operand unless it is 'EQ'.
      Fold.foldr mappend EQ (liftA2 compare xs ys)
-- | Element-wise arithmetic. 'fromInteger' broadcasts the converted
-- literal to every element.
instance (Num a, Dec.Positive n) => Num (Vector n a) where
   (+) = liftA2 (+)
   -- Restored: the original line read "() = liftA2 ()", which is not a
   -- valid definition and left subtraction undefined.
   (-) = liftA2 (-)
   (*) = liftA2 (*)
   negate = fmap negate
   abs = fmap abs
   signum = fmap signum
   fromInteger = pure . fromInteger
-- | 'succ'/'pred' act element-wise and 'toEnum' broadcasts. A vector has
-- no single enumeration index, so 'fromEnum' is an error.
instance (Enum a, Dec.Positive n) => Enum (Vector n a) where
   toEnum k = pure (toEnum k)
   fromEnum = error "Vector fromEnum"
   succ vec = fmap succ vec
   pred vec = fmap pred vec
-- | Present only so that 'Integral' and 'RealFrac' instances can exist.
-- A vector has no single rational value, so 'toRational' is an error.
instance (Real a, Dec.Positive n) => Real (Vector n a) where
   toRational =
      error "Vector toRational"
-- | Element-wise integer division. The pair-returning methods compute the
-- pairs per element and then split them into two vectors.
-- 'toInteger' is an error, because a vector has no single integer value.
instance (Integral a, Dec.Positive n) => Integral (Vector n a) where
   xs `quot` ys = liftA2 quot xs ys
   xs `rem` ys = liftA2 rem xs ys
   xs `div` ys = liftA2 div xs ys
   xs `mod` ys = liftA2 mod xs ys
   quotRem xs ys = unzip (liftA2 quotRem xs ys)
   divMod xs ys = unzip (liftA2 divMod xs ys)
   toInteger = error "Vector toInteger"
-- | Element-wise division. 'fromRational' broadcasts the converted value.
instance (Fractional a, Dec.Positive n) => Fractional (Vector n a) where
   xs / ys = liftA2 (/) xs ys
   fromRational r = pure (fromRational r)
-- | Present so that 'RealFloat' can be instantiated. 'properFraction'
-- (and the class defaults built on it) is an error.
instance (RealFrac a, Dec.Positive n) => RealFrac (Vector n a) where
   properFraction =
      error "Vector properFraction"
-- | Element-wise transcendental functions. 'pi' is broadcast to every
-- element.
instance (Floating a, Dec.Positive n) => Floating (Vector n a) where
   -- constant
   pi = pure pi

   -- powers and logarithms
   exp = fmap exp
   log = fmap log
   sqrt = fmap sqrt
   (**) = liftA2 (**)
   logBase = liftA2 logBase

   -- circular functions
   sin = fmap sin
   cos = fmap cos
   tan = fmap tan
   asin = fmap asin
   acos = fmap acos
   atan = fmap atan

   -- hyperbolic functions
   sinh = fmap sinh
   cosh = fmap cosh
   tanh = fmap tanh
   asinh = fmap asinh
   acosh = fmap acosh
   atanh = fmap atanh
-- | Mostly a placeholder instance.
--
-- * 'floatRadix', 'floatDigits', 'floatRange' and 'isIEEE' are taken from
--   the first element.
-- * 'exponent' is always 0.
-- * 'scaleFloat' only accepts a zero shift, which returns the vector
--   unchanged.
-- * All other methods are errors.
instance (RealFloat a, Dec.Positive n) => RealFloat (Vector n a) where
   floatRadix vec = floatRadix (head vec)
   floatDigits vec = floatDigits (head vec)
   floatRange vec = floatRange (head vec)
   isIEEE vec = isIEEE (head vec)

   exponent _ = 0
   scaleFloat k vec
      | k == 0 = vec
      | otherwise = error "Vector scaleFloat"

   decodeFloat = error "Vector decodeFloat"
   encodeFloat = error "Vector encodeFloat"
   isNaN = error "Vector isNaN"
   isInfinite = error "Vector isInfinite"
   isDenormalized = error "Vector isDenormalized"
   isNegativeZero = error "Vector isNegativeZero"