{-# LANGUAGE CPP #-}
module Data.UnitsOfMeasure.Plugin.NormalForm
( Atom(..)
, BaseUnit
, NormUnit
, one
, varUnit
, baseUnit
, famUnit
, mkNormUnit
, (*:)
, (/:)
, (^:)
, invert
, isOne
, isConstant
, maybeConstant
, isBase
, divisible
, occurs
, ascending
, leftover
, divideExponents
, substUnit
) where
import Type
import TyCon
import VarSet
import FastString
import Outputable
import Util ( thenCmp )
import qualified Data.Foldable as Foldable
import qualified Data.Map as Map
import Data.List ( sortBy )
import Data.Maybe
import Data.Ord
import TcPluginExtras
-- | Name of a base unit, as extracted from a type-level string literal
-- (see 'maybeConstant').
type BaseUnit = FastString
-- | A single factor of a normalised unit. 'BaseAtom' wraps a type (a
-- constant base unit when that type is a string literal), 'VarAtom' wraps
-- a type variable, and 'FamAtom' wraps a type constructor applied to its
-- argument types (presumably a type family application; confirm).
data Atom = BaseAtom Type | VarAtom TyVar | FamAtom TyCon [Type]
-- | Two atoms are equal exactly when the 'Ord' instance compares them as
-- 'EQ'. Defining equality in terms of 'compare' keeps 'Eq' and 'Ord'
-- consistent, as required when atoms are used as 'Map.Map' keys. The
-- previous definition (@a == b = a == b@) called itself and never
-- terminated.
instance Eq Atom where
  a == b = compare a b == EQ
-- | Total order on atoms: every base atom precedes every variable atom,
-- which precedes every family atom. Atoms built with the same constructor
-- are compared by their payloads; family atoms compare the constructor
-- first, then the argument types.
instance Ord Atom where
  compare x y = case (x, y) of
      (BaseAtom s, BaseAtom t)     -> cmpType s t
      (VarAtom s, VarAtom t)       -> compare s t
      (FamAtom g xs, FamAtom h ys) -> cmpTyCon g h `thenCmp` cmpTypes xs ys
      _                            -> compare (rank x) (rank y)
    where
      -- Position of each constructor in the order; only consulted when
      -- the two atoms use different constructors.
      rank :: Atom -> Int
      rank (BaseAtom _)  = 0
      rank (VarAtom _)   = 1
      rank (FamAtom _ _) = 2
-- | Pretty-print an atom. Base and variable atoms print as the wrapped
-- type or variable. A family atom prints its constructor, a single space,
-- and then its argument list.
instance Outputable Atom where
  ppr a = case a of
    BaseAtom ty     -> ppr ty
    VarAtom tv      -> ppr tv
    FamAtom tc args -> ppr tc <> text " " <> ppr args
-- | A unit in normal form: a map from each 'Atom' to its integer exponent.
-- Zero exponents are removed by 'mkNormUnitMap' (used by 'mkNormUnit',
-- '(*:)' and 'divideExponents'), so that 'isOne' can test for the
-- dimensionless unit by checking for an empty map.
newtype NormUnit = NormUnit { _NormUnit :: Map.Map Atom Integer }
-- | Pretty-print the underlying atom-to-exponent map, rendering each
-- exponent with 'show'.
instance Outputable NormUnit where
  ppr (NormUnit exponents) = ppr (Map.map show exponents)
-- | The dimensionless unit, which contains no atoms at all.
one :: NormUnit
one = mkNormUnit []
-- | The unit consisting of a single atom raised to the first power.
atom :: Atom -> NormUnit
atom a = NormUnit (Map.singleton a 1)
-- | The unit consisting of a single type variable, to the first power.
varUnit :: TyVar -> NormUnit
varUnit tv = atom (VarAtom tv)
-- | The unit consisting of a single base-unit type, to the first power.
baseUnit :: Type -> NormUnit
baseUnit ty = atom (BaseAtom ty)
-- | The unit consisting of a single constructor application (a family
-- atom), to the first power.
famUnit :: TyCon -> [Type] -> NormUnit
famUnit tc tys = atom (FamAtom tc tys)
-- | Build a normal unit from @(atom, exponent)@ pairs and drop any pair
-- whose exponent is zero. If an atom is listed more than once, the last
-- occurrence wins, following 'Map.fromList'. NOTE(review): summing the
-- exponents of duplicate atoms may be the intended meaning; confirm with
-- callers.
mkNormUnit :: [(Atom, Integer)] -> NormUnit
mkNormUnit pairs = mkNormUnitMap (Map.fromList pairs)
-- | Wrap an exponent map as a 'NormUnit' and discard zero exponents, so
-- that, for example, 'isOne' can test for an empty map.
mkNormUnitMap :: Map.Map Atom Integer -> NormUnit
mkNormUnitMap m = NormUnit (Map.filter (0 /=) m)
-- | Multiply two units by adding their exponents atom by atom. Atoms whose
-- exponents cancel out to zero are removed from the result.
(*:) :: NormUnit -> NormUnit -> NormUnit
NormUnit xs *: NormUnit ys = mkNormUnitMap (Map.unionWith (+) xs ys)
-- | Divide one unit by another, by multiplying by the inverse of the
-- divisor.
(/:) :: NormUnit -> NormUnit -> NormUnit
(/:) u v = (*:) u (invert v)
-- | Raise a unit to an integer power by scaling every exponent. A power of
-- zero yields 'one'. For any other power, a non-zero exponent stays
-- non-zero, so the map is not re-filtered (this assumes the input already
-- has no zero exponents).
(^:) :: NormUnit -> Integer -> NormUnit
u ^: n
  | n == 0    = one
  | otherwise = NormUnit (Map.map (n *) (_NormUnit u))
-- Multiplication and division associate to the left at precedence 7, like
-- Prelude multiplication and division. Exponentiation binds more tightly
-- and associates to the right at precedence 8, like Prelude exponentiation.
infixl 7 *:, /:
infixr 8 ^:
-- | The multiplicative inverse of a unit, formed by negating every
-- exponent.
invert :: NormUnit -> NormUnit
invert (NormUnit m) = NormUnit (fmap negate m)
-- | Test whether a unit is the dimensionless unit, that is, whether it
-- contains no atoms.
isOne :: NormUnit -> Bool
isOne (NormUnit m) = Map.null m
-- | True when every atom is a base atom whose type is a type-level string
-- literal. A unit containing any variable or family atom is therefore not
-- constant.
isConstant :: NormUnit -> Bool
isConstant (NormUnit m) = Foldable.all isBaseLiteral (Map.keys m)
-- | If every atom is a base atom carrying a type-level string literal,
-- return the @(name, exponent)@ pairs in ascending atom order. Otherwise
-- return 'Nothing'.
maybeConstant :: NormUnit -> Maybe [(BaseUnit, Integer)]
maybeConstant (NormUnit m) = mapM literalFactor (Map.toList m)
  where
    literalFactor (BaseAtom ty, e) = fmap (\ name -> (name, e)) (isStrLitTy ty)
    literalFactor _                = Nothing
-- | Is this a base-unit atom, regardless of which type it wraps?
isBase :: Atom -> Bool
isBase a = case a of
  BaseAtom _ -> True
  _          -> False
-- | Is this a base-unit atom whose type is a type-level string literal?
isBaseLiteral :: Atom -> Bool
isBaseLiteral a = case a of
  BaseAtom ty -> isJust (isStrLitTy ty)
  _           -> False
-- | @divisible i u@ holds when every exponent of @u@ is an exact multiple
-- of @i@.
--
-- A divisor of zero is handled explicitly rather than passed to 'rem',
-- which would raise a divide-by-zero exception for any non-empty unit.
-- Zero divides only zero, so with @i == 0@ the test succeeds only when
-- every exponent is zero. For a normalised unit, that means only when the
-- unit is 'one'.
divisible :: Integer -> NormUnit -> Bool
divisible i = Foldable.all divides . _NormUnit
  where
    divides j
      | i == 0    = j == 0
      | otherwise = j `rem` i == 0
-- | Does the type variable occur anywhere in the unit? It occurs if it is
-- a variable atom itself, or if it is among the type variables of a base
-- atom type or of a family atom argument list.
occurs :: TyVar -> NormUnit -> Bool
occurs a (NormUnit m) = any mentions (Map.keys m)
  where
    mentions atm = case atm of
      BaseAtom ty   -> a `elemVarSet` tyVarsOfType ty
      VarAtom b     -> a == b
      FamAtom _ tys -> a `elemVarSet` tyVarsOfTypes tys
-- | List the atom/exponent pairs ordered by increasing absolute exponent.
-- 'sortBy' is stable, so pairs with equal magnitude keep ascending atom
-- order.
ascending :: NormUnit -> [(Atom, Integer)]
ascending (NormUnit m) = sortBy byMagnitude (Map.toList m)
  where
    byMagnitude (_, x) (_, y) = compare (abs x) (abs y)
-- | Remove the variable atom for the given type variable, together with
-- its exponent, and leave every other factor untouched.
leftover :: TyVar -> NormUnit -> NormUnit
leftover a (NormUnit m) = NormUnit (Map.delete (VarAtom a) m)
-- | Divide every exponent by @i@, truncating toward zero with 'quot', and
-- drop any exponent that becomes zero. Presumably callers check
-- 'divisible' first. A divisor of zero raises an arithmetic exception once
-- a non-empty result is evaluated.
divideExponents :: Integer -> NormUnit -> NormUnit
divideExponents i (NormUnit m) = mkNormUnitMap (Map.map (\ e -> e `quot` i) m)
-- | @substUnit a v u@ replaces the variable @a@ in @u@ with the unit @v@.
-- If @a@ occurs with exponent @i@, the result is @v@ raised to @i@ times
-- the rest of @u@. If @a@ does not occur, @u@ is returned unchanged.
substUnit :: TyVar -> NormUnit -> NormUnit -> NormUnit
substUnit a v u =
  maybe u (\ i -> (v ^: i) *: leftover a u) (Map.lookup (VarAtom a) (_NormUnit u))
-- Compatibility shims: on GHC versions newer than 7.10 these definitions
-- supply 'tyVarsOfType' and 'tyVarsOfTypes' (used by 'occurs') as aliases
-- for 'tyCoVarsOfType' and 'tyCoVarsOfTypes'.
#if __GLASGOW_HASKELL__ > 710
-- | Type (and coercion) variables of a single type.
tyVarsOfType :: Type -> TyCoVarSet
tyVarsOfType = tyCoVarsOfType
-- | Type (and coercion) variables of a list of types.
tyVarsOfTypes :: [Type] -> TyCoVarSet
tyVarsOfTypes = tyCoVarsOfTypes
#endif