module Multilinear.Generic (
Tensor(..), (!),
isScalar, isSimple, isFiniteTensor,
dot, _elemByElem, contractionErr, tensorIndex, _standardize
) where
import Control.DeepSeq
import Data.Foldable
import Data.List
import Data.Maybe
import qualified Data.Vector as Boxed
import qualified Data.Vector.Unboxed as Unboxed
import GHC.Generics
import Multilinear.Class as Multilinear
import qualified Multilinear.Index as Index
import qualified Multilinear.Index.Finite as Finite
-- | Error text used when two tensors cannot be combined
--   (used in this module by 'zipT' and 'contractionErr').
incompatibleTypes :: String
incompatibleTypes = "Incompatible tensor types!"
-- | Error text used when an index-based operation is applied to a 'Scalar'.
scalarIndices :: String
scalarIndices = "Scalar has no indices!"
-- | Error text used when a requested index name is not present in a tensor.
indexNotFound :: String
indexNotFound = "This tensor has not such index!"
-- | Tensor with elements of type @a@, stored in one of three shapes.
data Tensor a where
    -- | A bare value that carries no index.
    Scalar :: {
        -- | The wrapped value.
        scalarVal :: a
    } -> Tensor a
    -- | A tensor with a single index whose components are kept in an unboxed vector.
    SimpleFinite :: {
        -- | The index of this tensor.
        tensorFiniteIndex :: Finite.Index,
        -- | Components along that index.
        tensorScalars :: Unboxed.Vector a
    } -> Tensor a
    -- | A tensor whose components along its outermost index are themselves tensors.
    FiniteTensor :: {
        -- | The outermost index of this tensor.
        tensorFiniteIndex :: Finite.Index,
        -- | Sub-tensors along the outermost index.
        tensorsFinite :: Boxed.Vector (Tensor a)
    } -> Tensor a
  deriving (Eq, Generic)
{-# INLINE isScalar #-}
-- | 'True' exactly when the tensor was built with the 'Scalar' constructor.
isScalar :: Unboxed.Unbox a => Tensor a -> Bool
isScalar (Scalar _) = True
isScalar _          = False
{-# INLINE isSimple #-}
-- | 'True' exactly when the tensor was built with the 'SimpleFinite' constructor.
isSimple :: Unboxed.Unbox a => Tensor a -> Bool
isSimple (SimpleFinite _ _) = True
isSimple _                  = False
{-# INLINE isFiniteTensor #-}
-- | 'True' exactly when the tensor was built with the 'FiniteTensor' constructor.
isFiniteTensor :: Unboxed.Unbox a => Tensor a -> Bool
isFiniteTensor (FiniteTensor _ _) = True
isFiniteTensor _                  = False
{-# INLINE tensorIndex #-}
-- | Outermost index of a tensor, converted with 'Index.toTIndex'.
--   Calling it on a 'Scalar' raises an error carrying 'scalarIndices'.
tensorIndex :: Unboxed.Unbox a => Tensor a -> Index.TIndex
tensorIndex (Scalar _)           = error scalarIndices
tensorIndex (SimpleFinite idx _) = Index.toTIndex idx
tensorIndex (FiniteTensor idx _) = Index.toTIndex idx
{-# INLINE firstTensor #-}
-- | First sub-tensor of a 'FiniteTensor'; any other tensor is returned unchanged.
firstTensor :: Unboxed.Unbox a => Tensor a -> Tensor a
firstTensor (FiniteTensor _ subs) = Boxed.head subs
firstTensor other                 = other
{-# INLINE (!) #-}
-- | Select the component at position @i@ along the outermost index.
--
--   * For a 'SimpleFinite' tensor the selected value is wrapped in a 'Scalar'.
--   * For a 'FiniteTensor' the selected sub-tensor is returned.
--   * A 'Scalar' has no index, so an error carrying 'scalarIndices' is raised.
--
--   An error is raised when @i@ is negative or not smaller than the size
--   of the outermost index.
(!) :: Unboxed.Unbox a => Tensor a      -- ^ tensor to select from
                       -> Int           -- ^ zero-based position along the outermost index
                       -> Tensor a
t ! i = case t of
    Scalar _ -> error scalarIndices
    SimpleFinite ind ts
        | outOfRange ind -> boundsErr ind
        | otherwise      -> Scalar $ ts Unboxed.! i
    FiniteTensor ind ts
        | outOfRange ind -> boundsErr ind
        | otherwise      -> ts Boxed.! i
  where
    -- Negative positions are rejected here as well, so the caller gets this
    -- module's message rather than the vector library's own bounds error.
    outOfRange ind = i < 0 || i >= Finite.indexSize ind
    boundsErr ind = error
        ("Index " ++ show ind ++ ": position " ++ show i ++ " out of bounds!")
-- | No methods are given here, so the class defaults are used
--   ('Tensor' derives 'Generic').
instance NFData a => NFData (Tensor a)
-- | Walk over all indices of the tensor (right fold) and apply @('<<<|')@
--   with the index name for every index that is contravariant;
--   other indices leave the accumulated tensor untouched.
_standardize :: (Num a, Unboxed.Unbox a, NFData a) => Tensor a -> Tensor a
_standardize tens = foldr' moveIfContravariant tens (indices tens)
  where
    moveIfContravariant idx acc
        | Index.isContravariant idx = acc <<<| Index.indexName idx
        | otherwise                 = acc
-- | Textual rendering of a tensor. The tensor is first passed through
--   '_standardize'. A 'Scalar' is shown as its value; the other shapes are
--   shown as the index followed by @"S: "@ ('SimpleFinite') or @"T: "@
--   ('FiniteTensor') and the elements: one per line prefixed with @" |"@ for
--   a contravariant index, or a bracketed comma-separated list otherwise.
instance (
    Unboxed.Unbox a, Show a, Num a, NFData a
  ) => Show (Tensor a) where
    show = show' . _standardize
      where
        show' x = case x of
            Scalar v -> show v
            SimpleFinite index ts ->
                show index ++ "S: " ++ render index (show <$> Unboxed.toList ts)
            FiniteTensor index ts ->
                show index ++ "T: " ++ render index (show <$> Boxed.toList ts)
        -- Lay out already-rendered elements. Unlike folding with (++) and then
        -- taking 'tail', this is total on an empty element list and does not
        -- re-copy the accumulated string for every element.
        render index shown = case index of
            Finite.Contravariant _ _ -> concatMap ("\n |" ++) shown
            _ -> "[" ++ Data.List.intercalate "," shown ++ "]"
-- | Ordering of tensors, defined through '(<=)' only.
--   Two scalars compare by value; a 'Scalar' is @<=@ anything else and
--   nothing else is @<=@ a 'Scalar'. Tensors of the same shape compare by
--   their element vectors (indices are ignored); a 'SimpleFinite' is @<=@
--   a 'FiniteTensor' but not the other way round.
instance (Ord a, Unboxed.Unbox a) => Ord (Tensor a) where
    {-# INLINE (<=) #-}
    lhs <= rhs = case (lhs, rhs) of
        (Scalar x1, Scalar x2)                   -> x1 <= x2
        (Scalar _, _)                            -> True
        (_, Scalar _)                            -> False
        (SimpleFinite _ ts1, SimpleFinite _ ts2) -> ts1 <= ts2
        (FiniteTensor _ ts1, FiniteTensor _ ts2) -> ts1 <= ts2
        (FiniteTensor _ _, SimpleFinite _ _)     -> False
        (SimpleFinite _ _, FiniteTensor _ _)     -> True
{-# INLINE _mergeScalars #-}
-- | Collapse a 'FiniteTensor' whose first element is a 'Scalar' into a
--   'SimpleFinite' holding the scalar values; otherwise apply the same
--   transformation to every sub-tensor. Non-'FiniteTensor' inputs are
--   returned unchanged.
_mergeScalars :: Unboxed.Unbox a => Tensor a -> Tensor a
_mergeScalars (FiniteTensor outer subs) =
    case subs Boxed.! 0 of
        Scalar _ ->
            let valueAt k = scalarVal (subs Boxed.! k)
            in SimpleFinite outer (Unboxed.generate (Boxed.length subs) valueAt)
        _ -> FiniteTensor outer (_mergeScalars <$> subs)
_mergeScalars other = other
{-# INLINE _elemByElem' #-}
-- | Recursive worker that combines two tensors.
--
--   * Two scalars are combined directly with @f@; a scalar and a tensor are
--     combined by mapping @f@ over the tensor's elements.
--   * When the outermost indices of both tensors have the same name,
--     the whole pair is handed to @op@.
--   * Otherwise the function recurses into the elements of one of the
--     tensors and wraps the results in a 'FiniteTensor'.
_elemByElem' :: (Num a, Unboxed.Unbox a, NFData a)
    => Tensor a                             -- ^ first argument
    -> Tensor a                             -- ^ second argument
    -> (a -> a -> a)                        -- ^ operation on single elements
    -> (Tensor a -> Tensor a -> Tensor a)   -- ^ operation for tensors whose outermost index names match
    -> Tensor a
-- Scalar cases: no index to align, apply f directly or map it over the tensor.
_elemByElem' (Scalar x1) (Scalar x2) f _ = Scalar $ f x1 x2
_elemByElem' (Scalar x) t f _ = (x `f`) `Multilinear.map` t
_elemByElem' t (Scalar x) f _ = (`f` x) `Multilinear.map` t
-- Two single-index tensors: with different index names, every element of the
-- first is combined with the whole second tensor.
_elemByElem' t1@(SimpleFinite index1 v1) t2@(SimpleFinite index2 _) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = FiniteTensor index1 $
        Boxed.generate (Unboxed.length v1)
            (\i -> (\x -> f x `Multilinear.map` t2) (v1 Unboxed.! i))
-- Two nested tensors: if the first tensor's outermost index name occurs
-- somewhere in the second, descend into the second; otherwise into the first.
_elemByElem' t1@(FiniteTensor index1 v1) t2@(FiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | Index.indexName index1 `Data.List.elem` indicesNames t2 =
        FiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
    | otherwise = FiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
-- Mixed shapes: descend into the nested ('FiniteTensor') argument.
_elemByElem' t1@(SimpleFinite index1 _) t2@(FiniteTensor index2 v2) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = FiniteTensor index2 $ (\x -> _elemByElem' t1 x f op) <$> v2
_elemByElem' t1@(FiniteTensor index1 v1) t2@(SimpleFinite index2 _) f op
    | Index.indexName index1 == Index.indexName index2 = op t1 t2
    | otherwise = FiniteTensor index1 $ (\x -> _elemByElem' x t2 f op) <$> v1
{-# INLINE _elemByElem #-}
-- | Combine two tensors. The index names present in both tensors are
--   collected (in the order they occur in the first tensor), @('|>>>')@ is
--   applied to each tensor for every such name, both results are fully
--   evaluated, and the pair is then passed to '_elemByElem''. The result is
--   post-processed with '_mergeScalars'.
_elemByElem :: (Num a, Unboxed.Unbox a, NFData a)
    => Tensor a                             -- ^ first argument
    -> Tensor a                             -- ^ second argument
    -> (a -> a -> a)                        -- ^ operation on single elements
    -> (Tensor a -> Tensor a -> Tensor a)   -- ^ operation for tensors with matching outermost index names
    -> Tensor a
_elemByElem t1 t2 f op =
    lhs `deepseq` rhs `deepseq` _mergeScalars (_elemByElem' lhs rhs f op)
  where
    shared = [ name | name <- indicesNames t1, name `Data.List.elem` indicesNames t2 ]
    lhs    = foldl' (|>>>) t1 shared
    rhs    = foldl' (|>>>) t2 shared
{-# INLINE zipT #-}
-- | Combine two tensors position by position along their outermost index.
--   The indices must be equal, otherwise an error carrying
--   'incompatibleTypes' is raised. Which of the four functions is used
--   depends on the shapes of the two arguments. Any combination not covered
--   by the four shape pairs (i.e. one involving a 'Scalar') raises an error
--   carrying 'scalarIndices'.
zipT :: (Num a, Unboxed.Unbox a, NFData a)
    => (Tensor a -> Tensor a -> Tensor a)   -- ^ used for two 'FiniteTensor's
    -> (Tensor a -> a -> Tensor a)          -- ^ used for 'FiniteTensor' with 'SimpleFinite'
    -> (a -> Tensor a -> Tensor a)          -- ^ used for 'SimpleFinite' with 'FiniteTensor'
    -> (a -> a -> a)                        -- ^ used for two 'SimpleFinite's
    -> Tensor a
    -> Tensor a
    -> Tensor a
zipT _ _ _ ss (SimpleFinite ixL xs) (SimpleFinite ixR ys)
    | ixL == ixR = SimpleFinite ixL (Unboxed.zipWith ss xs ys)
    | otherwise  = error incompatibleTypes
zipT tt _ _ _ (FiniteTensor ixL xs) (FiniteTensor ixR ys)
    | ixL == ixR = FiniteTensor ixL (Boxed.zipWith tt xs ys)
    | otherwise  = error incompatibleTypes
zipT _ ts _ _ (FiniteTensor ixL xs) (SimpleFinite ixR ys)
    | ixL == ixR =
        let combineAt k = ts (xs Boxed.! k) (ys Unboxed.! k)
        in FiniteTensor ixL (Boxed.generate (Boxed.length xs) combineAt)
    | otherwise  = error incompatibleTypes
zipT _ _ st _ (SimpleFinite ixL xs) (FiniteTensor ixR ys)
    | ixL == ixR =
        let combineAt k = st (xs Unboxed.! k) (ys Boxed.! k)
        in FiniteTensor ixL (Boxed.generate (Unboxed.length xs) combineAt)
    | otherwise  = error incompatibleTypes
zipT _ _ _ _ _ _ = error scalarIndices
{-# INLINE dot #-}
-- | Contract a tensor with a covariant outermost index against a tensor of
--   the same shape with a contravariant outermost index of equal size:
--   the result is the sum of the position-wise products. Every other
--   combination ends in 'contractionErr'.
dot :: (Num a, Unboxed.Unbox a, NFData a)
    => Tensor a     -- ^ left operand
    -> Tensor a     -- ^ right operand
    -> Tensor a
dot lhs rhs = case (lhs, rhs) of
    (SimpleFinite ixL@(Finite.Covariant nL _) xs, SimpleFinite ixR@(Finite.Contravariant nR _) ys)
        | nL == nR  -> Scalar (Unboxed.sum (Unboxed.zipWith (*) xs ys))
        | otherwise -> contractionErr (Index.toTIndex ixL) (Index.toTIndex ixR)
    (FiniteTensor ixL@(Finite.Covariant nL _) xs, FiniteTensor ixR@(Finite.Contravariant nR _) ys)
        | nL == nR  -> Boxed.sum (Boxed.zipWith (*) xs ys)
        | otherwise -> contractionErr (Index.toTIndex ixL) (Index.toTIndex ixR)
    _ -> contractionErr (tensorIndex lhs) (tensorIndex rhs)
-- | Raise an error reporting that two indices cannot be contracted.
--   The message names both indices.
contractionErr :: Index.TIndex      -- ^ index of the first tensor
               -> Index.TIndex      -- ^ index of the second tensor
               -> Tensor a
contractionErr ixL ixR = error (concat
    [ "Tensor product: ", incompatibleTypes
    , " - index1 is ", show ixL
    , " and index2 is ", show ixR
    ])
-- | Numeric operations on tensors. Addition and subtraction go through
--   '_elemByElem' with a 'zipT'-based combiner for matching indices;
--   multiplication uses 'dot' for matching indices. 'abs' and 'signum' are
--   applied to every element, and integer literals become 'Scalar's.
instance (Unboxed.Unbox a, Num a, NFData a) => Num (Tensor a) where
    {-# INLINE (+) #-}
    lhs + rhs = _elemByElem lhs rhs (+) (zipT (+) (.+) (+.) (+))
    {-# INLINE (-) #-}
    lhs - rhs = _elemByElem lhs rhs (-) (zipT (-) (.-) (-.) (-))
    {-# INLINE (*) #-}
    lhs * rhs = _elemByElem lhs rhs (*) dot
    {-# INLINE abs #-}
    abs t = Multilinear.map abs t
    {-# INLINE signum #-}
    signum t = Multilinear.map signum t
    {-# INLINE fromInteger #-}
    fromInteger n = Scalar (fromInteger n)
-- | Division is defined only when at least one operand is a 'Scalar':
--   the scalar is then combined with every element of the other operand.
--   Dividing two non-scalar tensors raises an error. Rational literals
--   become 'Scalar's.
instance (Unboxed.Unbox a, Fractional a, NFData a) => Fractional (Tensor a) where
    {-# INLINE (/) #-}
    num / den = case (num, den) of
        (Scalar x1, Scalar x2) -> Scalar (x1 / x2)
        (Scalar x1, t2)        -> Multilinear.map (x1 /) t2
        (t1, Scalar x2)        -> Multilinear.map (/ x2) t1
        _                      -> error "TODO"
    {-# INLINE fromRational #-}
    fromRational q = Scalar (fromRational q)
-- | Floating-point functions on tensors; every function is applied
--   element by element. 'pi' is a 'Scalar'.
instance (Unboxed.Unbox a, Floating a, NFData a) => Floating (Tensor a) where
    {-# INLINE pi #-}
    pi = Scalar pi
    {-# INLINE exp #-}
    exp t = exp `Multilinear.map` t
    {-# INLINE log #-}
    log t = log `Multilinear.map` t
    {-# INLINE sin #-}
    sin t = sin `Multilinear.map` t
    {-# INLINE cos #-}
    cos t = cos `Multilinear.map` t
    -- 'tan' and 'tanh' are given explicitly: the class defaults are written
    -- as a quotient of two tensors, and tensor-by-tensor division is not
    -- implemented by this module's 'Fractional' instance.
    {-# INLINE tan #-}
    tan t = tan `Multilinear.map` t
    {-# INLINE asin #-}
    asin t = asin `Multilinear.map` t
    {-# INLINE acos #-}
    acos t = acos `Multilinear.map` t
    {-# INLINE atan #-}
    atan t = atan `Multilinear.map` t
    {-# INLINE sinh #-}
    sinh t = sinh `Multilinear.map` t
    {-# INLINE cosh #-}
    cosh t = cosh `Multilinear.map` t
    {-# INLINE tanh #-}
    tanh t = tanh `Multilinear.map` t
    {-# INLINE asinh #-}
    asinh t = asinh `Multilinear.map` t
    {-# INLINE acosh #-}
    acosh t = acosh `Multilinear.map` t
    {-# INLINE atanh #-}
    atanh t = atanh `Multilinear.map` t
-- | 'Multilinear' operations for 'Tensor': arithmetic with plain values,
--   tensor-with-tensor arithmetic, index inspection, index renaming,
--   raising/lowering, transposition, index shifting and element-wise mapping.
instance (Unboxed.Unbox a, Num a, NFData a) => Multilinear Tensor a where
    -- Tensor combined with a plain value: the operation is applied to every element.
    {-# INLINE (.+) #-}
    t .+ x = (+x) `Multilinear.map` t
    {-# INLINE (.-) #-}
    t .- x = (\p -> p - x) `Multilinear.map` t
    {-# INLINE (.*) #-}
    t .* x = (*x) `Multilinear.map` t
    {-# INLINE (+.) #-}
    x +. t = (x+) `Multilinear.map` t
    {-# INLINE (-.) #-}
    x -. t = (x-) `Multilinear.map` t
    {-# INLINE (*.) #-}
    x *. t = (x*) `Multilinear.map` t

    -- Tensor combined with tensor. These mirror (+), (-) and (*) of the
    -- 'Num' instance: the element function and the combiner for matching
    -- indices must describe the same operation.
    {-# INLINE (.+.) #-}
    t1 .+. t2 = _elemByElem t1 t2 (+) $ zipT (+) (.+) (+.) (+)
    {-# INLINE (.-.) #-}
    t1 .-. t2 = _elemByElem t1 t2 (-) $ zipT (-) (.-) (-.) (-)
    {-# INLINE (.*.) #-}
    t1 .*. t2 = _elemByElem t1 t2 (*) dot

    -- List of indices from the outermost inwards; the nested structure is
    -- read off the first sub-tensor.
    {-# INLINE indices #-}
    indices x = case x of
        Scalar _ -> []
        FiniteTensor i ts -> Index.toTIndex i : indices (head $ toList ts)
        SimpleFinite i _ -> [Index.toTIndex i]

    -- Pair (number of contravariant indices, number of covariant indices);
    -- indifferent indices are not counted.
    {-# INLINE order #-}
    order x = case x of
        Scalar _ -> (0,0)
        SimpleFinite index _ -> case index of
            Finite.Contravariant _ _ -> (1,0)
            Finite.Covariant _ _ -> (0,1)
            Finite.Indifferent _ _ -> (0,0)
        _ -> let (cnvr, covr) = order $ firstTensor x
             in case tensorIndex x of
                Index.Contravariant _ _ -> (cnvr+1,covr)
                Index.Covariant _ _ -> (cnvr,covr+1)
                Index.Indifferent _ _ -> (cnvr,covr)

    -- Size of the index with the given name. The search descends through the
    -- first sub-tensor; a 'Scalar' raises 'scalarIndices', and a non-matching
    -- 'SimpleFinite' raises 'indexNotFound'.
    {-# INLINE size #-}
    size t iname = case t of
        Scalar _ -> error scalarIndices
        SimpleFinite index _ ->
            if Index.indexName index == iname
            then Finite.indexSize index
            else error indexNotFound
        FiniteTensor index _ ->
            if Index.indexName index == iname
            then Finite.indexSize index
            else size (firstTensor t) iname

    -- Rename indices: contravariant indices take names from the first list,
    -- covariant ones from the second, one character each, outermost first.
    -- Any case not matched by the first five equations returns the tensor
    -- unchanged.
    {-# INLINE ($|) #-}
    Scalar x $| _ = Scalar x
    SimpleFinite (Finite.Contravariant isize _) ts $| (u:_, _) = SimpleFinite (Finite.Contravariant isize [u]) ts
    SimpleFinite (Finite.Covariant isize _) ts $| (_, d:_) = SimpleFinite (Finite.Covariant isize [d]) ts
    FiniteTensor (Finite.Contravariant isize _) ts $| (u:us, ds) = FiniteTensor (Finite.Contravariant isize [u]) $ ($| (us,ds)) <$> ts
    FiniteTensor (Finite.Covariant isize _) ts $| (us, d:ds) = FiniteTensor (Finite.Covariant isize [d]) $ ($| (us,ds)) <$> ts
    t $| _ = t

    -- Replace the index named n by a contravariant index of the same size and name.
    {-# INLINE (/\) #-}
    Scalar x /\ _ = Scalar x
    FiniteTensor index ts /\ n
        | Index.indexName index == n =
            FiniteTensor (Finite.Contravariant (Finite.indexSize index) n) $ (/\ n) <$> ts
        | otherwise =
            FiniteTensor index $ (/\ n) <$> ts
    t1@(SimpleFinite index ts) /\ n
        | Index.indexName index == n =
            SimpleFinite (Finite.Contravariant (Finite.indexSize index) n) ts
        | otherwise = t1

    -- Replace the index named n by a covariant index of the same size and name.
    {-# INLINE (\/) #-}
    Scalar x \/ _ = Scalar x
    FiniteTensor index ts \/ n
        | Index.indexName index == n =
            FiniteTensor (Finite.Covariant (Finite.indexSize index) n) $ (\/ n) <$> ts
        | otherwise =
            FiniteTensor index $ (\/ n) <$> ts
    t1@(SimpleFinite index ts) \/ n
        | Index.indexName index == n =
            SimpleFinite (Finite.Covariant (Finite.indexSize index) n) ts
        | otherwise = t1

    -- Swap covariant and contravariant at every level; indifferent indices
    -- and scalars stay as they are.
    {-# INLINE transpose #-}
    transpose (Scalar x) = Scalar x
    transpose (FiniteTensor (Finite.Covariant count name) ts) =
        FiniteTensor (Finite.Contravariant count name) (Multilinear.transpose <$> ts)
    transpose (FiniteTensor (Finite.Contravariant count name) ts) =
        FiniteTensor (Finite.Covariant count name) (Multilinear.transpose <$> ts)
    transpose (FiniteTensor (Finite.Indifferent count name) ts) =
        FiniteTensor (Finite.Indifferent count name) (Multilinear.transpose <$> ts)
    transpose (SimpleFinite (Finite.Covariant count name) ts) =
        SimpleFinite (Finite.Contravariant count name) ts
    transpose (SimpleFinite (Finite.Contravariant count name) ts) =
        SimpleFinite (Finite.Covariant count name) ts
    transpose (SimpleFinite (Finite.Indifferent count name) ts) =
        SimpleFinite (Finite.Indifferent count name) ts

    -- Move the index named ind one level deeper. Scalars and single-index
    -- tensors are returned unchanged. If the outermost index is not ind, the
    -- shift is applied to the sub-tensors; if it is, the outermost index is
    -- exchanged with the next one by transposing the nested elements.
    Scalar x `shiftRight` _ = Scalar x
    t1@(SimpleFinite _ _) `shiftRight` _ = t1
    t1@(FiniteTensor index1 ts1) `shiftRight` ind
        | Data.List.length (indicesNames t1) > 1 && Index.indexName index1 /= ind =
            FiniteTensor index1 $ (|>> ind) <$> ts1
        | Data.List.length (indicesNames t1) > 1 && Index.indexName index1 == ind =
            let index2 = tensorFiniteIndex (ts1 Boxed.! 0)
                -- Elements of every sub-tensor as a boxed vector of tensors;
                -- unboxed values are wrapped in 'Scalar' first.
                rows = if isSimple (ts1 Boxed.! 0)
                    then (\un -> Boxed.generate (Unboxed.length un) (\i -> Scalar $ un Unboxed.! i)) <$>
                        (tensorScalars <$> ts1)
                    else tensorsFinite <$> ts1
                rowLists = Boxed.toList <$> Boxed.toList rows
                transposedList = Data.List.transpose rowLists
                transposed = Boxed.fromList <$> Boxed.fromList transposedList
            in _mergeScalars $ FiniteTensor index2 $ FiniteTensor index1 <$> transposed
        | otherwise = t1

    -- Apply a function to every element, keeping the structure and indices.
    {-# INLINE map #-}
    map f x = case x of
        Scalar v -> Scalar $ f v
        SimpleFinite index ts -> SimpleFinite index (f `Unboxed.map` ts)
        FiniteTensor index ts -> FiniteTensor index $ Multilinear.map f <$> ts
-- | Element access by index names. The second argument pairs index names
--   (one character each) with positions. At every level, if the name of the
--   outermost index is among the given names, the component at the paired
--   position is selected with '(!)'; otherwise that level is kept and, for
--   a 'FiniteTensor', the lookup continues in every sub-tensor.
instance (Unboxed.Unbox a, Num a, NFData a) => Accessible Tensor a where
    {-# INLINE el #-}
    el (Scalar x) _ = Scalar x
    el t1@(SimpleFinite index1 _) (inds,vals) =
        case Data.List.find names (zip inds vals) of
            Just (_, pos) -> t1 ! pos
            Nothing       -> t1
      where
        names (n,_) = [n] == Index.indexName index1
    el t1@(FiniteTensor index1 v1) (inds,vals) =
        case Data.List.find names pairs of
            -- The matched name is consumed; the remaining pairs go on to the
            -- selected sub-tensor.
            Just (_, pos) -> el (t1 ! pos) (fst <$> others, snd <$> others)
            Nothing       -> FiniteTensor index1 $ (\t -> el t (inds,vals)) <$> v1
      where
        pairs = zip inds vals
        names (n,_) = [n] == Index.indexName index1
        others = Data.List.filter (\(n,_) -> [n] /= Index.indexName index1) pairs