module Data.Tensor.TypeLevel
(
(:~)(..), Vec(..), Axis(..), (!),
Vector(..), VectorRing(..),
contract,
Vec0, Vec1, Vec2, Vec3, Vec4
) where
import qualified Algebra.Additive as Additive
import qualified Algebra.Ring as Ring
import Control.Monad.Failure
import System.IO.Unsafe
import Control.Applicative (Applicative(..), (<$>))
import Control.Monad hiding
(mapM_, sequence_, forM_, msum, mapM, sequence, forM)
import Data.Foldable
import Data.Traversable
import NumericPrelude hiding
(Monad, Functor, (*>),
(>>=), (>>), return, fail, fmap, mapM, mapM_, sequence, sequence_,
(=<<), foldl, foldl1, foldr, foldr1, and, or, any, all, sum, product,
concat, concatMap, maximum, minimum, elem, notElem)
import qualified NumericPrelude as Prelude
infixl 9 !

-- | Component access operator: @vec ! axis@ is the component of @vec@
-- along @axis@.  It is 'component' with its arguments swapped, so an
-- out-of-range axis behaves exactly as it does for 'component'.
(!) :: Vector v => v a -> Axis v -> a
(!) vec axis = component axis vec
-- | The zero-dimensional vector: a container with no components at all.
-- Higher-dimensional vectors are obtained by stacking ':~' on top of it.
data Vec a = Vec deriving (Eq, Ord, Show, Read)
-- | Extends a container @n@ by one more component.  In @xs :~ x@ the
-- existing components live in @xs@ and @x@ is the new, highest-indexed
-- component.  The operator is left-associative, so @Vec :~ a :~ b@ parses
-- as @(Vec :~ a) :~ b@.
data (n :: * -> *) :~ a = n a :~ a
  deriving (Eq, Show, Read)

infixl 3 :~
-- | Lexicographic ordering that looks at the /last/ component first and
-- only consults the remaining components when the last ones are equal.
instance (Ord (n a), Ord a) => Ord (n :~ a) where
  compare (xs :~ x) (ys :~ y) =
    case compare x y of
      EQ      -> compare xs ys
      unequal -> unequal
-- | Folding is derived from the 'Traversable' instance; 'Vec' holds no
-- elements, so every fold sees an empty container.
instance Foldable Vec where
  foldMap f = foldMapDefault f
-- | Mapping is derived from the 'Traversable' instance.
instance Functor Vec where
  fmap f = fmapDefault f
-- | There is nothing to visit, so the traversal is just 'pure' of the
-- empty vector.
instance Traversable Vec where
  traverse _ v = case v of
    Vec -> pure Vec
-- | Both 'pure' and '<*>' ignore their arguments and produce the empty
-- 'Vec'.
instance Applicative Vec where
  pure = const Vec
  (<*>) _ _ = Vec
-- | Folding is derived from the 'Traversable' instance: the inner
-- components are visited first, then the last component.
instance (Traversable n) => Foldable ((:~) n) where
  foldMap f = foldMapDefault f
-- | Mapping is derived from the 'Traversable' instance.
instance (Traversable n) => Functor ((:~) n) where
  fmap f = fmapDefault f
-- | Traverses the inner components first and the last component after
-- them, then reassembles the results with ':~'.
instance (Traversable n) => Traversable ((:~) n) where
  traverse f (initial :~ final) = rebuild <*> f final
    where
      rebuild = (:~) <$> traverse f initial
-- | Component-wise application.  'pure' fills the outer component and
-- delegates the rest to the inner 'pure'.  '<*>' pairs each function with
-- the argument in the same position.
instance (Applicative n, Traversable n) => Applicative ((:~) n) where
  pure x = pure x :~ x
  (fs :~ g) <*> (xs :~ y) = (fs <*> xs) :~ g y
-- | An index selecting one component (axis) of a vector of type @v@.
-- The phantom parameter @v@ ties an index to a particular vector type.
--
-- The kind of @v@ is spelled out explicitly because the definition no
-- longer carries a @Vector v =>@ datatype context (DatatypeContexts is
-- deprecated and rejected by default in modern GHC).  Without the kind,
-- @v@ would default to kind @*@ and @Axis Vec@ would not kind-check.
newtype Axis (v :: * -> *) = Axis {axisIndex :: Int}
  deriving (Eq, Ord, Show, Read)
-- | Containers whose components are addressed by 'Axis' indices.
class (Traversable v) => Vector v where
  -- | Look up the component along an axis.  An out-of-range axis is
  -- reported through the 'Failure' mechanism.
  componentF :: (Failure StringException f)
             => Axis v  -- ^ axis to look up
             -> v a     -- ^ vector to read from
             -> f a     -- ^ the component, or a failure for a bad axis

  -- | Unchecked variant of 'componentF'.  It instantiates the failure
  -- context at 'IO' and runs it with 'unsafePerformFailure'.  A bad axis
  -- presumably surfaces as an exception when the result is forced; this
  -- depends on the @Failure e IO@ instance, so confirm there.
  component :: Axis v -> v a -> a
  component axis = unsafePerformFailure . componentF axis

  -- | Number of components in the vector.
  dimension :: v a -> Int

  -- | Build a vector by computing every component from its axis.
  compose :: (Axis v -> a) -> v a
-- | 'Vec' has no components: every lookup fails, its dimension is zero,
-- and 'compose' never calls the component function.
instance Vector Vec where
  componentF axis Vec = failureString outOfBound
    where
      outOfBound = "axis out of bound: " ++ show axis
  dimension = const 0
  compose = const Vec
-- | @v :~ x@ has one more component than @v@.  The extra component @x@
-- sits on the highest axis, whose index is @dimension v@.
instance (Vector v) => Vector ((:~) v) where
  componentF (Axis i) (v :~ x)
    -- The last axis of @v :~ x@ is @dimension (v :~ x) - 1@, which equals
    -- @dimension v@.
    | i == dimension v = return x
    -- Any other axis is looked up in the inner vector.  An axis that
    -- matches no level falls through to the 'Vec' instance, which fails.
    | True             = componentF (Axis i) v
  dimension (v :~ _) = 1 + dimension v
  compose f = xs :~ f (Axis (dimension xs))
    where
      -- Re-tag each inner axis as an outer axis.  The inner vector's axes
      -- are the first @dimension xs@ axes of the outer one.
      xs = compose (\(Axis i) -> f (Axis i))
-- | Component-wise additive structure, built with 'compose' and '!'.
-- 'compose' for 'Vec' never calls its argument, so every operation just
-- yields the empty 'Vec'.
instance (Additive.C a) => Additive.C (Vec a) where
  zero     = compose $ const Additive.zero
  x + y    = compose (\i -> x!i + y!i)
  -- Subtraction must define the class method '-'.  Writing @xy@ here
  -- would declare a non-existent method and not compile.
  x - y    = compose (\i -> x!i - y!i)
  negate x = compose (\i -> negate $ x!i)
-- | Component-wise additive structure, built with 'compose' and '!'.
-- Each result component is computed from the matching components of the
-- operands.
instance (Vector v, Additive.C a) => Additive.C ((:~) v a) where
  zero     = compose $ const Additive.zero
  x + y    = compose (\i -> x!i + y!i)
  -- Subtraction must define the class method '-'.  Writing @xy@ here
  -- would declare a non-existent method and not compile.
  x - y    = compose (\i -> x!i - y!i)
  negate x = compose (\i -> negate $ x!i)
-- | Sum a function over every axis of the vector type @v@.  The function
-- is evaluated on each axis via 'compose'.  The results are then added
-- with a left fold that starts from 'Additive.zero'.
contract :: (Vector v, Additive.C a) => (Axis v -> a) -> a
contract f = foldl (+) Additive.zero components
  where
    components = compose f
-- | Vectors whose components come from a ring, so that unit vectors
-- exist.
class (Vector v, Ring.C a) => VectorRing v a where
  -- | The unit vector along an axis.  An out-of-range axis is reported
  -- through the 'Failure' mechanism.
  unitVectorF :: (Failure StringException f) => Axis v -> f (v a)

  -- | Unchecked variant of 'unitVectorF', run via 'unsafePerformFailure'.
  unitVector :: Axis v -> v a
  unitVector axis = unsafePerformFailure (unitVectorF axis)
-- | 'Vec' has no axes, so every unit-vector request fails.
instance (Ring.C a) => VectorRing Vec a where
  unitVectorF axis = failureString msg
    where
      msg = "axis out of bound: " ++ show axis
-- | Unit vectors of @v :~ a@.
--
-- * The last axis yields 'Additive.zero' in the inner vector and
--   'Ring.one' in the outer cell.
-- * A lower axis is delegated to the inner vector's unit vector, padded
--   with an outer 'Additive.zero'.
-- * Axes outside @[0, d)@ fail.
instance (Ring.C a, VectorRing v a, Additive.C (v a))
    => VectorRing ((:~) v) a where
  unitVectorF axis@(Axis i) = ret
    where
      -- The zero vector of this type, used to obtain the dimension @d@.
      z = Additive.zero
      d = dimension z
      ret
        | i < 0 || i >= d     = failureString $ "axis out of bound: " ++ show axis
        | i == d - 1          = return $ Additive.zero :~ Ring.one
        | 0 <= i && i < d - 1 = liftM (:~ Additive.zero) $ unitVectorF (Axis i)
        -- Unreachable given the guards above.  It is kept anyway: this use
        -- of @z@ is what ties @z@'s type to the result type.  Presumably
        -- the monomorphism restriction requires this; verify before
        -- removing.
        | True                = return z
-- | Zero-dimensional vectors.
type Vec0 = Vec
-- | One-dimensional vectors.
type Vec1 = (:~) Vec
-- | Two-dimensional vectors.
type Vec2 = (:~) Vec1
-- | Three-dimensional vectors.
type Vec3 = (:~) Vec2
-- | Four-dimensional vectors.
type Vec4 = (:~) Vec3
-- | Extract the result of an 'IO' computation purely.
--
-- It is used by the default 'component' and 'unitVector' methods, which
-- instantiate their 'Failure' context at 'IO'.  Those actions only return
-- a value or signal a failure, so no other side effects are duplicated or
-- reordered.  A failure presumably becomes an exception raised when the
-- result is forced; confirm against the @Failure e IO@ instance.
unsafePerformFailure :: IO a -> a
unsafePerformFailure action = unsafePerformIO action