{-# LANGUAGE MultiParamTypeClasses #-}

{-# LANGUAGE FlexibleInstances #-}

{-# LANGUAGE StandaloneDeriving #-}

{-# LANGUAGE ScopedTypeVariables #-}



module Data.Linear (

  Linear,

  integerToLinear,

  constToLinear,

  termToLinear,

--  linearOpLinear,

--  linearOpLinears,

  linearToConst,

  linearToTerm,

  linearMultiply,

  linearMult,

  linearToList, linearToListEx,

  getCoef,

) where



import qualified Data.Map as Map

import Data.Map(Map)



-- | A linear expression: a constant term together with a map from terms to
-- their coefficients, read as @c + sum [k * t | (t, k) <- Map.toList m]@.
--
-- Arithmetic in this module ('+', '-', 'linearMult') drops coefficients that
-- become zero after combining maps.  That keeps the derived structural
-- equality meaningful.
--
-- No datatype context is declared here.  The old @(Ord t, Num v) =>@ context
-- is illegal without the deprecated @DatatypeContexts@ extension.  Every
-- function that needs those constraints already states them in its own
-- signature.
data Linear t v = Linear v (Map t v)



-- Structural equality: both the constant terms and the coefficient maps must
-- be equal.
deriving instance (Eq t, Eq v, Num v, Ord t) => Eq (Linear t v)

-- Structural (lexicographic) ordering: the constant term is compared first,
-- then the coefficient maps.
deriving instance (Eq t, Num v, Ord t, Ord v) => Ord (Linear t v)

-- Debug rendering of the raw constructor and its two fields.
deriving instance (Num v, Ord t, Show t, Show v) => Show (Linear t v)



-- | The expression made of a single term with coefficient one and a zero
-- constant.
termToLinear :: (Num v, Ord t) => t -> Linear t v
termToLinear term = Linear 0 (Map.singleton term 1)



-- | Embed an 'Integer' as a constant expression with no terms.  The value is
-- converted to the coefficient type with 'fromInteger'.
integerToLinear :: (Num v, Ord t) => Integer -> Linear t v
integerToLinear n = constToLinear (fromInteger n)



-- | A constant expression: the given value paired with an empty term map.
constToLinear :: (Num v, Ord t) => v -> Linear t v
constToLinear = flip Linear Map.empty



-- linearOpLinear :: (Num v, Ord t) => v -> Linear t v -> v -> Linear t v -> Linear t v

-- linearOpLinear a (Linear ac am) b (Linear bc bm) = Linear (a*ac+b*bc) $ Map.filter (/=0) $ Map.unionWith (\ax bx -> ax*a+bx*b) am bm



-- linearOpLinears :: (Num v, Ord t) => [(v,Linear t v)] -> Linear t v

-- linearOpLinears l = foldr (\(c,t) a -> linearOpLinear 1 a c t) (integerToLinear 0) l



-- | Flatten an expression into @(key, coefficient)@ pairs.
--
-- The constant comes first, under the key 'Nothing'.  It is followed by every
-- term, wrapped in 'Just', in ascending term order.
linearToList :: (Ord t, Num v) => Linear t v -> [(Maybe t, v)]
linearToList (Linear constant terms) =
  (Nothing, constant) : [(Just t, k) | (t, k) <- Map.toList terms]



-- | Split an expression into its constant and its @(term, coefficient)@
-- pairs, listed in ascending term order.
linearToListEx :: (Ord t, Num v) => Linear t v -> (v, [(t, v)])
linearToListEx (Linear constant terms) = (constant, Map.toAscList terms)



-- | Look up a coefficient.
--
-- 'Nothing' selects the constant term.  @Just t@ selects the coefficient of
-- term @t@, which is @0@ when @t@ does not occur in the expression.
getCoef :: (Num v, Ord t) => Maybe t -> Linear t v -> v
getCoef key (Linear constant terms) =
  case key of
    Nothing -> constant
    Just t -> Map.findWithDefault 0 t terms



-- | Scale every part of an expression (constant and all coefficients) by a
-- factor.
--
-- A zero factor empties the term map outright.  Otherwise any coefficient
-- whose product with the factor is zero is dropped, so the map stays free of
-- zero entries.
linearMult :: (Num v, Eq v, Ord t) => v -> Linear t v -> Linear t v
linearMult k (Linear c terms) = Linear (k * c) scaled
  where
    scaled
      | k == 0 = Map.empty
      | otherwise = Map.mapMaybe keepNonZero terms
    -- Scale one coefficient, discarding it if the product is zero.
    keepNonZero x =
      let y = k * x
       in if y == 0 then Nothing else Just y



-- | Multiply two expressions when the product is still linear.
--
-- This requires at least one operand to be a pure constant (an empty term
-- map).  The other operand is then scaled by that constant with
-- 'linearMult'; the left operand is tried first.  Returns 'Nothing' when both
-- operands contain terms.
linearMultiply :: (Num v, Eq v, Ord t) => Linear t v -> Linear t v -> Maybe (Linear t v)
linearMultiply lhs@(Linear lc lm) rhs
  | Map.null lm = Just (linearMult lc rhs)
  | otherwise =
      case rhs of
        Linear rc rm
          | Map.null rm -> Just (linearMult rc lhs)
        _ -> Nothing



-- | Extract the value of a constant expression.  Returns 'Nothing' if the
-- expression contains any term.
linearToConst :: (Num v, Ord t) => Linear t v -> Maybe v
linearToConst (Linear constant terms) =
  if Map.null terms then Just constant else Nothing



-- | Recognise an expression that is exactly one bare term.
--
-- Succeeds only when the constant is zero and the map holds a single term
-- whose coefficient is one.  Returns 'Nothing' otherwise.
linearToTerm :: (Num v, Eq v, Ord t) => Linear t v -> Maybe t
linearToTerm (Linear c m)
  | c == 0 && Map.size m == 1
  , (t, 1) <- Map.findMin m -- safe: the map has exactly one entry here
  = Just t
linearToTerm _ = Nothing



-- | Arithmetic on linear expressions.
--
-- Addition and subtraction combine coefficients term by term and drop any
-- coefficient that becomes zero.  The remaining methods call 'error' when the
-- result would not be linear:
--
-- * '*' needs at least one operand to be a constant (see 'linearMultiply').
-- * 'abs' and 'signum' need the operand itself to be a constant.
--
-- The context asks only for what the methods use.  An unused @Show t@ and a
-- redundant @Eq t@ (already implied by @Ord t@) were dropped, so term types
-- without a 'Show' instance are accepted too.
instance (Num v, Eq v, Ord t) => Num (Linear t v) where
  (Linear ac am) + (Linear bc bm) =
    Linear (ac + bc) $ Map.filter (/= 0) $ Map.unionWith (+) am bm

  -- The right-hand map is negated before the union.  That way a term present
  -- only in @bm@ is subtracted too; a plain @unionWith (-)@ would copy it
  -- across unchanged.
  (Linear ac am) - (Linear bc bm) =
    Linear (ac - bc) $ Map.filter (/= 0) $ Map.unionWith (+) am $ Map.map negate bm

  a * b = case linearMultiply a b of
    Just x -> x
    Nothing -> error "Cannot multiply generic linear expressions"

  negate (Linear ac am) = Linear (negate ac) $ Map.map negate am

  abs (Linear ac am)
    | Map.null am = Linear (abs ac) Map.empty
  abs _ = error "Cannot take abs of generic linear expressions"

  signum (Linear ac am)
    | Map.null am = Linear (signum ac) Map.empty
  signum _ = error "Cannot take signum of generic linear expressions"

  fromInteger x = integerToLinear x