{-# LANGUAGE Safe #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Math.Tensor.LinearAlgebra.Scalar
( Lin(..)
, Poly(..)
, singletonPoly
, polyMap
, getVars
, shiftVars
, normalize
) where
import qualified Data.IntMap.Strict as IM
( IntMap
, singleton
, null
, keys
, map
, filter
, mapKeysMonotonic
, unionWith
, findMin
)
-- | Linear part of a polynomial: a map from 'Int' variable labels to the
-- coefficient attached to each variable.
newtype Lin a = Lin (IM.IntMap a)
  deriving (Eq, Ord, Show)
-- | Polynomial of degree at most one in 'Int'-labelled variables.
--
-- The constructor order is significant for the derived 'Ord' instance.
data Poly a
  = -- | A constant polynomial.
    Const !a
  | -- | A constant term together with a linear part.
    Affine !a !(Lin a)
  | -- | Marker for a result that cannot be represented as 'Const' or
    -- 'Affine' (for example, what the 'Num' instance in this module
    -- returns for the product of two 'Affine' values).
    NotSupported
  deriving (Eq, Ord, Show)
-- | Build an 'Affine' polynomial with a single variable.
--
-- The coefficient is stored as given; no check for a zero coefficient is
-- performed (there is no 'Eq' constraint to check it with).
singletonPoly
  :: a       -- ^ constant term
  -> Int     -- ^ variable label
  -> a       -- ^ coefficient of that variable
  -> Poly a
singletonPoly constant var coeff =
  Affine constant (Lin (IM.singleton var coeff))
-- | Apply a function to every coefficient of a polynomial: the constant
-- term and, for 'Affine', each entry of the linear part.
-- 'NotSupported' is passed through unchanged.
--
-- The set of variable labels is never altered, even if the function maps
-- some coefficient to zero.
polyMap :: (a -> b) -> Poly a -> Poly b
polyMap f poly = case poly of
  Const c               -> Const (f c)
  Affine c (Lin coeffs) -> Affine (f c) (Lin (IM.map f coeffs))
  NotSupported          -> NotSupported
-- | Arithmetic on polynomials of degree at most one.
--
-- * Sums stay within 'Const' / 'Affine'. When two 'Affine' values are added,
--   coefficients that sum to zero are dropped, and the result collapses to
--   'Const' if no variable is left.
-- * Products are supported only when at least one factor is a 'Const';
--   multiplying an 'Affine' by a zero constant yields @Const 0@.
-- * 'abs' and 'signum' are supported only on 'Const'.
-- * Every other combination yields 'NotSupported', which is absorbing for
--   both '+' and '*'.
instance (Num a, Eq a) => Num (Poly a) where
  Const x + Const y = Const (x + y)
  Const x + Affine y ls = Affine (x + y) ls
  Affine x ls + Const y = Affine (x + y) ls
  Affine x (Lin ls) + Affine y (Lin rs) =
    -- Merge both linear parts, discarding coefficients that cancelled out.
    let merged = IM.filter (/= 0) (IM.unionWith (+) ls rs)
    in if IM.null merged
         then Const (x + y)
         else Affine (x + y) (Lin merged)
  NotSupported + _ = NotSupported
  _ + NotSupported = NotSupported

  Const x * Const y = Const (x * y)
  Const k * Affine c (Lin ls) =
    -- Scaling by zero removes every variable, so fall back to a constant.
    if k == 0
      then Const 0
      else Affine (k * c) (Lin (IM.map (k *) ls))
  Affine c (Lin ls) * Const k =
    if k == 0
      then Const 0
      else Affine (c * k) (Lin (IM.map (* k) ls))
  _ * _ = NotSupported

  negate p = polyMap negate p

  abs p = case p of
    Const x -> Const (abs x)
    _       -> NotSupported

  signum p = case p of
    Const x -> Const (signum x)
    _       -> NotSupported

  fromInteger n = Const (fromInteger n)
-- | Variable labels occurring in a polynomial.
--
-- For 'Affine' these are the keys of the linear part (in the ascending
-- order produced by 'IM.keys'); 'Const' and 'NotSupported' have none.
getVars :: Poly a -> [Int]
getVars poly = case poly of
  Affine _ (Lin coeffs) -> IM.keys coeffs
  _                     -> []
-- | Add a fixed offset to every variable label of a polynomial.
--
-- Only 'Affine' carries variables; 'Const' and 'NotSupported' are returned
-- unchanged. The keys are remapped with 'IM.mapKeysMonotonic', which relies
-- on the shifted labels keeping their relative order.
shiftVars :: Int -> Poly a -> Poly a
shiftVars offset poly = case poly of
  Affine c (Lin coeffs) ->
    Affine c (Lin (IM.mapKeysMonotonic (+ offset) coeffs))
  other -> other
-- | Normalize a polynomial by dividing it through by the coefficient of
-- its smallest variable label.
--
-- * @Const _@ normalizes to @Const 1@, whatever the constant is.
-- * 'NotSupported' is passed through unchanged.
-- * @Affine c lin@ has its constant term and every coefficient divided by
--   the coefficient stored under the minimal key of @lin@, so that leading
--   coefficient becomes 1.
-- * An 'Affine' whose linear part is empty contains no variables and is
--   therefore treated like a constant, yielding @Const 1@. Without this
--   guard 'IM.findMin' would raise an error on the empty map.
--
-- No check is made that the leading coefficient is non-zero; the division
-- behaves however @(/)@ does for the coefficient type.
normalize :: Fractional a => Poly a -> Poly a
normalize (Const _) = Const 1
normalize NotSupported = NotSupported
normalize (Affine a (Lin lin))
  | IM.null lin = Const 1
  | otherwise   = Affine (a / v) $ Lin $ IM.map (/ v) lin
  where
    -- Bound lazily: only demanded in the non-empty branch above.
    (_, v) = IM.findMin lin