{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.Scalar where

import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Arithmetic as A

import qualified Control.Monad as Monad


{- |
The entire purpose of this datatype is to mark a type as scalar,
although it might also be interpreted as vector.
This way you can write generic operations for vectors
using the 'A.PseudoModule' class,
and specialise them to scalar types with respect to the 'A.PseudoRing' class.
From another perspective
you can consider the 'T' type constructor a marker
where the 'A.Scalar' type function
stops reducing nested vector types to scalar types.

'Cons' wraps a value and 'decons' unwraps it.
Since 'T' is a @newtype@, the wrapper has no runtime representation.
-}
newtype T a = Cons {decons :: a}

{- |
Lift a monadic function on plain values to one on 'T' values.
The argument is unwrapped and @f@ is run on it.
The result is then wrapped in 'Cons' again.
-}
liftM :: (Monad m) => (a -> m b) -> T a -> m (T b)
liftM f (Cons a) = Monad.liftM Cons $ f a

{- |
Binary version of 'liftM'.
Both arguments are unwrapped and @f@ is run on them.
The result is then wrapped in 'Cons'.
-}
liftM2 :: (Monad m) => (a -> b -> m c) -> T a -> T b -> m (T c)
liftM2 f (Cons a) (Cons b) = Monad.liftM Cons $ f a b


{- |
Opposite direction of 'liftM': turn a function on 'T' values
into one on bare values.
The argument is wrapped in 'Cons' and @f@ is run on it.
The result is then unwrapped with 'decons'.
-}
unliftM ::
   (Monad m) =>
   (T a -> m (T r)) ->
   a -> m r
unliftM f a =
   Monad.liftM decons $ f (Cons a)

{- |
Two-argument version of 'unliftM'.
Each argument is wrapped in 'Cons' before calling @f@.
The result is unwrapped with 'decons'.
-}
unliftM2 ::
   (Monad m) =>
   (T a -> T b -> m (T r)) ->
   a -> b -> m r
unliftM2 f a b =
   Monad.liftM decons $ f (Cons a) (Cons b)

{- |
Three-argument version of 'unliftM'.
Each argument is wrapped in 'Cons' before calling @f@.
The result is unwrapped with 'decons'.
-}
unliftM3 ::
   (Monad m) =>
   (T a -> T b -> T c -> m (T r)) ->
   a -> b -> c -> m r
unliftM3 f a b c =
   Monad.liftM decons $ f (Cons a) (Cons b) (Cons c)

{- |
Four-argument version of 'unliftM'.
Each argument is wrapped in 'Cons' before calling @f@.
The result is unwrapped with 'decons'.
-}
unliftM4 ::
   (Monad m) =>
   (T a -> T b -> T c -> T d -> m (T r)) ->
   a -> b -> c -> d -> m r
unliftM4 f a b c d =
   Monad.liftM decons $ f (Cons a) (Cons b) (Cons c) (Cons d)

{- |
Five-argument version of 'unliftM'.
Each argument is wrapped in 'Cons' before calling @f@.
The result is unwrapped with 'decons'.
-}
unliftM5 ::
   (Monad m) =>
   (T a -> T b -> T c -> T d -> T e -> m (T r)) ->
   a -> b -> c -> d -> e -> m r
unliftM5 f a b c d e =
   Monad.liftM decons $ f (Cons a) (Cons b) (Cons c) (Cons d) (Cons e)


{- |
The zero of @T a@ is the zero of the wrapped type, wrapped in 'Cons'.
-}
instance (Tuple.Zero a) => Tuple.Zero (T a) where
   zero = Cons Tuple.zero

{- |
The undefined value of @T a@ is that of the wrapped type,
wrapped in 'Cons'.
-}
instance (Tuple.Undefined a) => Tuple.Undefined (T a) where
   undef = Cons Tuple.undef

{- |
Both methods delegate to the 'Tuple.Phi' instance of the wrapped type.
'phi' unwraps its argument with 'decons', runs 'Tuple.phi',
and wraps the result in 'Cons'.
'addPhi' unwraps both values and forwards them to 'Tuple.addPhi'.
-}
instance (Tuple.Phi a) => Tuple.Phi (T a) where
   phi bb = fmap Cons . Tuple.phi bb . decons
   addPhi bb (Cons a) (Cons b) = Tuple.addPhi bb a b

{- |
Integer constants are built with the wrapped type's 'A.fromInteger''
and wrapped in 'Cons'.
-}
instance (A.IntegerConstant a) => A.IntegerConstant (T a) where
   fromInteger' = Cons . A.fromInteger'

{- |
Rational constants are built with the wrapped type's 'A.fromRational''
and wrapped in 'Cons'.
-}
instance (A.RationalConstant a) => A.RationalConstant (T a) where
   fromRational' = Cons . A.fromRational'

{- |
Additive structure is inherited from the wrapped type.
'A.add' and 'A.sub' are lifted with 'liftM2', and 'A.neg' with 'liftM'.
-}
instance (A.Additive a) => A.Additive (T a) where
   zero = Cons A.zero
   add = liftM2 A.add
   sub = liftM2 A.sub
   neg = liftM A.neg

{- |
Multiplication is the wrapped type's 'A.mul', lifted with 'liftM2'.
-}
instance (A.PseudoRing a) => A.PseudoRing (T a) where
   mul = liftM2 A.mul

{- |
Division is the wrapped type's 'A.fdiv', lifted with 'liftM2'.
-}
instance (A.Field a) => A.Field (T a) where
   fdiv = liftM2 A.fdiv

-- The scalar type of @T a@ is @T a@ itself.
-- This is the point where 'A.Scalar' stops reducing,
-- so 'T' marks a value as scalar.
type instance A.Scalar (T a) = T a

{- |
Because @'A.Scalar' (T a) = T a@, scaling a 'T' value by its scalar
is plain multiplication of the wrapped values ('A.mul', lifted with 'liftM2').
-}
instance (A.PseudoRing a) => A.PseudoModule (T a) where
   scale = liftM2 A.mul


{- |
All operations delegate to the wrapped type's 'A.Real' instance.
Binary operations are lifted with 'liftM2', unary ones with 'liftM'.
-}
instance (A.Real a) => A.Real (T a) where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

{- |
'A.truncate' and 'A.fraction' act on the wrapped value.
Both are lifted with 'liftM'.
-}
instance (A.Fraction a) => A.Fraction (T a) where
   truncate = liftM A.truncate
   fraction = liftM A.fraction

{- |
Square root is the wrapped type's 'A.sqrt', lifted with 'liftM'.
-}
instance (A.Algebraic a) => A.Algebraic (T a) where
   sqrt = liftM A.sqrt

{- |
All operations delegate to the wrapped type's 'A.Transcendental' instance.
The constant 'A.pi' is wrapped via 'fmap' 'Cons'.
Unary functions are lifted with 'liftM', and 'A.pow' with 'liftM2'.
-}
instance (A.Transcendental a) => A.Transcendental (T a) where
   pi = fmap Cons A.pi
   sin = liftM A.sin
   cos = liftM A.cos
   exp = liftM A.exp
   log = liftM A.log
   pow = liftM2 A.pow