{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Linear (
Linear,
integerToLinear,
constToLinear,
termToLinear,
linearToConst,
linearToTerm,
linearMultiply,
linearMult,
linearToList, linearToListEx,
getCoef,
) where
import qualified Data.Map as Map
import Data.Map(Map)
-- | A linear expression over terms of type @t@ with coefficients of type @v@.
--
-- The first field is the constant part. The map gives each term its
-- coefficient. The arithmetic in this module ('+', '-', 'linearMult') removes
-- terms whose coefficient becomes zero after the operation.
--
-- There is deliberately no datatype context (@data (Ord t, Num v) => ...@).
-- Such contexts need the deprecated @DatatypeContexts@ extension, which GHC
-- disables by default. The @Ord t@ / @Num v@ constraints are instead carried
-- by the functions that actually need them.
data Linear t v = Linear v (Map t v)
  deriving (Eq, Ord, Show)
-- | The expression consisting of a single term with coefficient 1 and a
-- zero constant part.
termToLinear :: (Num v, Ord t) => t -> Linear t v
termToLinear term = Linear 0 (Map.fromList [(term, 1)])
-- | A constant expression (no terms) whose value is the given integer,
-- converted to the coefficient type with 'fromInteger'.
integerToLinear :: (Num v, Ord t) => Integer -> Linear t v
integerToLinear n = constToLinear (fromInteger n)
-- | A constant expression: the given value as the constant part and an
-- empty term map.
constToLinear :: (Num v, Ord t) => v -> Linear t v
constToLinear c = Linear c mempty
-- | Flatten an expression into a list.
--
-- The first element is the constant part, tagged 'Nothing'. It is followed by
-- every term, tagged with 'Just', and its coefficient, in ascending term order.
linearToList :: (Ord t, Num v) => Linear t v -> [(Maybe t, v)]
linearToList (Linear constant terms) =
  (Nothing, constant) : [(Just term, coef) | (term, coef) <- Map.toList terms]
-- | Split an expression into its constant part and its (term, coefficient)
-- pairs, listed in ascending term order.
linearToListEx :: (Ord t, Num v) => Linear t v -> (v, [(t, v)])
linearToListEx expr = case expr of
  Linear constant terms -> (constant, Map.toAscList terms)
-- | Look up a coefficient of an expression.
--
-- 'Nothing' selects the constant part. @Just t@ selects the coefficient of
-- term @t@, which is 0 when the term does not appear.
getCoef :: (Num v, Ord t) => Maybe t -> Linear t v -> v
getCoef key (Linear constant terms) =
  maybe constant (\term -> Map.findWithDefault 0 term terms) key
-- | Scale an expression by a constant factor (applied on the left of each
-- coefficient).
--
-- Terms whose scaled coefficient is zero are dropped. A zero factor yields an
-- empty term map, without scaling each coefficient.
linearMult :: (Num v, Eq v, Ord t) => v -> Linear t v -> Linear t v
linearMult factor (Linear constant terms) = Linear (factor * constant) scaled
  where
    scaled
      | factor == 0 = Map.empty
      | otherwise = Map.filter (/= 0) (Map.map (factor *) terms)
-- | Multiply two expressions, provided at least one of them is a constant
-- (has no terms). Otherwise returns 'Nothing'.
--
-- The left operand is tested first. Whichever operand is constant is applied
-- via 'linearMult' as a left factor, so the right-constant case assumes the
-- coefficient multiplication is commutative.
linearMultiply :: (Num v, Eq v, Ord t) => Linear t v -> Linear t v -> Maybe (Linear t v)
linearMultiply lhs rhs = case (lhs, rhs) of
  (Linear k terms, other) | Map.null terms -> Just (linearMult k other)
  (other, Linear k terms) | Map.null terms -> Just (linearMult k other)
  _ -> Nothing
-- | The constant part of an expression that has no terms, or 'Nothing' if
-- any term is present.
linearToConst :: (Num v, Ord t) => Linear t v -> Maybe v
linearToConst (Linear constant terms)
  | Map.null terms = Just constant
  | otherwise = Nothing
-- | Recognise an expression that is exactly one bare term.
--
-- Returns the term when the constant part is 0 and the map holds a single
-- term with coefficient 1. Returns 'Nothing' otherwise.
linearToTerm :: (Num v, Eq v, Ord t) => Linear t v -> Maybe t
linearToTerm (Linear constant terms)
  | constant == 0
  , Map.size terms == 1
  , [(term, coef)] <- Map.toList terms
  , coef == 1
  = Just term
  | otherwise = Nothing
-- | 'Num' instance for linear expressions.
--
-- '+', '-', 'negate' and 'fromInteger' work on arbitrary expressions. '+' and
-- '-' drop terms whose coefficient becomes zero. '*' succeeds only when at
-- least one operand is a constant (see 'linearMultiply'). 'abs' and 'signum'
-- succeed only on constants. In the remaining cases these three call 'error'.
--
-- The context asks only for what the methods use. @Ord t@ is needed for the
-- term map and already implies @Eq t@. @Num v@ and @Eq v@ are needed for
-- coefficient arithmetic and the zero filter. No @Show t@ is required.
-- Dropping it lets term types without a 'Show' instance use this instance.
instance (Num v, Eq v, Ord t) => Num (Linear t v) where
  (Linear ac am) + (Linear bc bm) =
    Linear (ac + bc) $ Map.filter (/= 0) $ Map.unionWith (+) am bm
  (Linear ac am) - (Linear bc bm) =
    Linear (ac - bc) $ Map.filter (/= 0) $ Map.unionWith (+) am $ Map.map negate bm
  a * b = case linearMultiply a b of
    Just x -> x
    Nothing -> error "Cannot multiply generic linear expressions"
  negate (Linear ac am) = Linear (-ac) $ Map.map negate am
  abs (Linear ac am)
    | Map.null am = Linear (abs ac) Map.empty
  abs _ = error "Cannot take abs of generic linear expressions"
  signum (Linear ac am)
    | Map.null am = Linear (signum ac) Map.empty
  signum _ = error "Cannot take signum of generic linear expressions"
  fromInteger = integerToLinear