-- |
-- Tensors: multidimensional arrays ('NArray') whose indices are tagged as
-- 'Contra' or 'Co' (see 'Variant').
--
-- Besides the names defined here, the whole of "Numeric.LinearAlgebra.Array"
-- is re-exported.
module Numeric.LinearAlgebra.Tensor (
-- the tensor type and its index tag
Tensor, Variant(..),
-- construction from dimensions and coordinates
listTensor,
-- construction from a list of tensors plus a new index
superindex, subindex,
-- construction from plain lists of Doubles
vector, covector, transf,
-- changing the index tags of an existing array
switch, cov, contrav, forget,
-- re-export of the generic array interface
module Numeric.LinearAlgebra.Array
) where
import Numeric.LinearAlgebra.Array.Internal
import Numeric.LinearAlgebra.HMatrix hiding (vector)
import Numeric.LinearAlgebra.Array
import Data.List(intersperse)
-- | An 'NArray' whose index type is 'Variant', with elements of type @t@.
type Tensor t = NArray Variant t
-- | Tag attached to every index of a 'Tensor': either 'Contra' or 'Co'.
--
-- Under the 'Compat' instance defined in this module, two indices are
-- compatible only when their dimensions agree and their tags differ.
data Variant = Contra | Co deriving (Eq,Show)
-- | Compatibility rules for 'Variant'-tagged indices.
--
-- * 'compat' holds when both indices have the same dimension and different
--   tags.
-- * 'opos' returns the same index (dimension and name untouched) with its tag
--   flipped by 'flipV'.
instance Compat Variant where
    compat a b = sameDim && differentTag
      where
        sameDim      = iDim a == iDim b
        differentTag = iType a /= iType b

    opos (Idx tag dim name) = Idx (flipV tag) dim name
-- | Render an index as its name, a separator and its dimension:
-- @name_dim@ for a 'Co' index and @name^dim@ for a 'Contra' index.
instance Show (Idx Variant) where
    show (Idx tag dim name) =
        case tag of
            Co     -> render "_"
            Contra -> render "^"
      where
        -- name, then the tag-specific separator, then the dimension
        render sep = concat [name, sep, show dim]
-- | Textual rendering of a tensor.
--
-- * No dimensions: @scalar @ followed by the element at position 0 of the
--   coordinates.
-- * Order 1: @subindex @ or @superindex @ (according to the tag of the first
--   name in 'namesR'), the shown index name, a space, and the coordinate list.
-- * Otherwise: the same header followed by the comma-separated renderings of
--   the 'parts' along that index, wrapped in square brackets.
instance (Coord t) => Show (Tensor t) where
    show t
        | null (dims t) = "scalar " ++ show (coords t `atIndex` 0)
        | order t == 1  = header ++ " " ++ show (toList (coords t))
        | otherwise     = header ++ " [" ++ body ++ "]"
      where
        -- Only demanded in the two non-scalar branches.
        -- NOTE(review): relies on namesR being non-empty whenever dims is
        -- non-empty -- confirm against the NArray internals.
        firstName = head (namesR t)
        header    = label (typeOf firstName t) ++ show firstName
        body      = concat (intersperse ", " (map show (parts t firstName)))
        label tag = case tag of
            Co     -> "subindex "
            Contra -> "superindex "
-- | Exchange the two index tags: 'Co' becomes 'Contra' and vice versa.
flipV :: Variant -> Variant
flipV tag = case tag of
    Co     -> Contra
    Contra -> Co
-- | Build a tensor from a list of signed dimensions and a flat list of
-- coordinates.
--
-- The sign of each entry of the first argument selects the index tag: a
-- strictly positive entry gives a 'Contra' index, anything else (zero
-- included) gives 'Co'. The index dimension is the absolute value of the
-- entry, and indices are named @\"1\"@, @\"2\"@, ... in order.
--
-- The coordinate list is extended with an endless supply of zeros before being
-- handed to '|>' together with the product of the dimensions.
-- NOTE(review): presumably '|>' keeps only that many leading elements, so a
-- short coordinate list is zero-padded and a long one truncated -- confirm
-- against the hmatrix documentation.
listTensor :: Coord t
           => [Int] -- ^ signed dimensions, one per index
           -> [t]   -- ^ coordinates
           -> Tensor t
listTensor ds cs = mkNArray indices (total |> padded)
  where
    sizes   = map abs ds
    total   = product sizes
    padded  = cs ++ repeat 0
    indices =
        [ Idx (tagFor d) size (show k)
        | (d, size, k) <- zip3 ds sizes [1 :: Int ..]
        ]
    tagFor d
        | d > 0     = Contra
        | otherwise = Co
-- | 'newIndex' specialised to the 'Contra' tag.
-- NOTE(review): presumably the given tensors become the parts along a new
-- index with the given name -- confirm against 'newIndex'.
superindex :: Coord t => Name -> [Tensor t] -> Tensor t
superindex name pieces = newIndex Contra name pieces
-- | 'newIndex' specialised to the 'Co' tag.
-- NOTE(review): presumably the given tensors become the parts along a new
-- index with the given name -- confirm against 'newIndex'.
subindex :: Coord t => Name -> [Tensor t] -> Tensor t
subindex name pieces = newIndex Co name pieces
-- | Apply 'flipV' to the index tags of a tensor through 'mapTypes', turning
-- 'Co' into 'Contra' and 'Contra' into 'Co'.
switch :: Tensor t -> Tensor t
switch t = mapTypes flipV t
-- | Turn any 'NArray' into a 'Tensor' by mapping its index types, whatever
-- they were, to 'Co'.
cov :: NArray i t -> Tensor t
cov arr = mapTypes (\_ -> Co) arr
-- | Turn any 'NArray' into a 'Tensor' by mapping its index types, whatever
-- they were, to 'Contra'.
contrav :: NArray i t -> Tensor t
contrav arr = mapTypes (\_ -> Contra) arr
-- | Discard the index tags of an 'NArray': every index type is mapped to
-- 'None', yielding a plain 'Array'.
forget :: NArray i t -> Array t
forget arr = mapTypes (\_ -> None) arr
-- | Tensor built from a list of coordinates: the list is converted with
-- 'fromList' and wrapped by 'fromVector' using the 'Contra' tag.
vector :: [Double] -> Tensor Double
vector xs = fromVector Contra (fromList xs)
-- | Tensor built from a list of coordinates: the list is converted with
-- 'fromList' and wrapped by 'fromVector' using the 'Co' tag.
covector :: [Double] -> Tensor Double
covector xs = fromVector Co (fromList xs)
-- | Tensor built from a list of rows: the rows are converted with 'fromLists'
-- and wrapped by 'fromMatrix', passing 'Contra' and 'Co' (in that order) as
-- the tags of its two indices.
transf :: [[Double]] -> Tensor Double
transf rows = fromMatrix Contra Co (fromLists rows)