{-# LANGUAGE TupleSections #-}

module Data.UnitsOfMeasure.Plugin.NormalForm
  ( Atom(..)
  , BaseUnit
  , NormUnit
    -- * Constructors
  , one
  , varUnit
  , baseUnit
  , famUnit
  , mkNormUnit

    -- * Algebraic operations
  , (*:)
  , (/:)
  , (^:)
  , invert

    -- * Predicates
  , isOne
  , maybeSingleVariable
  , isConstant
  , maybeConstant
  , isBase
  , divisible
  , occurs

    -- * Destructors
  , ascending
  , leftover
  , divideExponents
  , substUnit
  ) where

import Prelude hiding ((<>))
import GhcApi (elemVarSet, tyCoVarsOfType, tyCoVarsOfTypes, text, (<>))
import GhcApi.Compare (cmpType, cmpTypes, cmpTyCon, thenCmp)

import GHC.TcPlugin.API

import qualified Data.Foldable as Foldable
import qualified Data.Map as Map
import Data.List ( sortOn )
import Data.Maybe

-- | Base units are just represented as strings, for simplicity.
-- This is the type 'maybeConstant' returns for each literal base unit.
type BaseUnit = FastString

-- | An atom in the normal form is either a base unit, a variable or a
-- stuck type family application (but not one of the built-in type
-- families that correspond to group operations).
--
-- * 'BaseAtom' carries the base unit as a 'Type'; it need not be a
--   literal (see 'isBase' and 'isBaseLiteral').
--
-- * 'VarAtom' carries a type variable.
--
-- * 'FamAtom' carries the family's 'TyCon' and its argument types.
data Atom = BaseAtom Type | VarAtom TyVar | FamAtom TyCon [Type]

-- | Equality on atoms, defined to agree with the 'Ord' instance: two
-- atoms are equal exactly when 'compare' returns 'EQ'.  Defining '=='
-- in terms of itself would never terminate.
instance Eq Atom where
  a == b = compare a b == EQ

-- | Total order on atoms.  Atoms built with different constructors are
-- ordered 'BaseAtom' < 'VarAtom' < 'FamAtom'.  Within one constructor,
-- base units are compared with 'cmpType', variables with their own
-- 'Ord' instance, and family applications first by 'TyCon' and then by
-- their argument lists.
--
-- TODO: using cmpTypes here probably isn't ideal, but does it matter?
instance Ord Atom where
  compare (BaseAtom x)    (BaseAtom y)      = cmpType x y
  compare (BaseAtom _)    _                 = LT
  compare (VarAtom  _)    (BaseAtom _)      = GT
  compare (VarAtom  a)    (VarAtom  b)      = compare a b
  compare (VarAtom  _)    (FamAtom _ _)     = LT
  compare (FamAtom f tys) (FamAtom f' tys') = cmpTyCon f f' `thenCmp` cmpTypes tys tys'
  compare (FamAtom _ _)   _                 = GT

-- | Pretty-print an atom: base units and variables print as their
-- payload; a family application prints the 'TyCon', a single space,
-- and then the argument list.
instance Outputable Atom where
  ppr (BaseAtom b) = ppr b
  ppr (VarAtom  v) = ppr v
  ppr (FamAtom tc tys) = ppr tc <> text " " <> ppr tys


-- | A unit normal form is a signed multiset of atoms; we maintain the
-- invariant that the map does not contain any zero values.  Each atom
-- is mapped to its (non-zero) exponent.
newtype NormUnit = NormUnit { _NormUnit :: Map.Map Atom Integer }

-- | Pretty-print a unit by printing its underlying map, with every
-- exponent first rendered via 'show'.
instance Outputable NormUnit where
    ppr = ppr . Map.map show . _NormUnit


-- | The group identity, representing the dimensionless unit.  It is
-- the unit with no atoms at all, i.e. the empty map.
one :: NormUnit
one = NormUnit Map.empty

-- | Construct a normalised unit from an atom, giving it exponent 1.
atom :: Atom -> NormUnit
atom a = NormUnit $ Map.singleton a 1

-- | Construct a normalised unit from a single variable (exponent 1).
varUnit :: TyVar -> NormUnit
varUnit = atom . VarAtom

-- | Construct a normalised unit from a single base unit (exponent 1).
-- The base unit is given as a 'Type' and is stored as-is.
baseUnit :: Type -> NormUnit
baseUnit = atom . BaseAtom

-- | Construct a normalised unit from a stuck type family application:
-- this must not be one of the built-in type families!  The 'TyCon' and
-- its arguments are stored unchanged in a 'FamAtom' with exponent 1.
famUnit :: TyCon -> [Type] -> NormUnit
famUnit tc = atom . FamAtom tc

-- | Construct a normalised unit from a list of atom-exponent pairs.
-- The list is turned into a map with 'Map.fromList' (so if an atom is
-- repeated, the last pair for it wins), and zero exponents are then
-- removed by 'mkNormUnitMap'.
mkNormUnit :: [(Atom, Integer)] -> NormUnit
mkNormUnit = mkNormUnitMap . Map.fromList

-- | Construct a normalised unit from an atom-exponent map, applying
-- the signed multiset invariant: every entry whose exponent is zero is
-- dropped before wrapping the map.
mkNormUnitMap :: Map.Map Atom Integer -> NormUnit
mkNormUnitMap = NormUnit . Map.filter (/= 0)


-- | Multiplication of normalised units: exponents of atoms present in
-- both operands are added.  Sums that cancel to zero are removed by
-- 'mkNormUnitMap'.
(*:) :: NormUnit -> NormUnit -> NormUnit
u *: v = mkNormUnitMap $ Map.unionWith (+) (_NormUnit u) (_NormUnit v)

-- | Division of normalised units, defined as multiplication by the
-- inverse of the divisor.
(/:) :: NormUnit -> NormUnit -> NormUnit
u /: v = u *: invert v

-- | Exponentiation of normalised units: every exponent is multiplied
-- by the given integer.
(^:) :: NormUnit -> Integer -> NormUnit
-- Raising anything to the power 0 gives the identity.  Handling it
-- here is what lets the general case use the raw constructor: with a
-- non-zero @n@, scaling non-zero exponents cannot produce a zero.
_ ^: 0 = one
u ^: n = NormUnit $ Map.map (* n) $ _NormUnit u

-- Multiplication and division associate to the left at precedence 7;
-- exponentiation associates to the right at precedence 8, so it binds
-- more tightly than either of them.
infixl 7 *:, /:
infixr 8 ^:

-- | Invert a normalised unit by negating every exponent.  Negation
-- never turns a non-zero exponent into zero, so the raw constructor
-- can be used without re-filtering.
invert :: NormUnit -> NormUnit
invert = NormUnit . Map.map negate . _NormUnit


-- | Test whether a unit is dimensionless, i.e. whether its underlying
-- map has no atoms.
isOne :: NormUnit -> Bool
isOne = Map.null . _NormUnit

-- | Test whether a unit consists of a single variable with multiplicity 1.
-- Returns that variable if so, and 'Nothing' for every other shape
-- (no atoms, several atoms, a non-variable atom, or another exponent).
maybeSingleVariable :: NormUnit -> Maybe TyVar
maybeSingleVariable x = case Map.toList (_NormUnit x) of
    [(VarAtom v, 1)] -> Just v
    _                -> Nothing

-- | Test whether a unit is constant (contains only base literals).
-- The dimensionless unit has no atoms and therefore counts as constant.
isConstant :: NormUnit -> Bool
isConstant = all isBaseLiteral . Map.keys . _NormUnit

-- | Extract the base units if a unit is constant.  Returns each
-- literal base unit paired with its exponent, or 'Nothing' as soon as
-- any atom is not a literal base unit.
maybeConstant :: NormUnit -> Maybe [(BaseUnit, Integer)]
maybeConstant = mapM getBase . Map.toList . _NormUnit
  where
    -- Only a 'BaseAtom' whose type is a string literal yields a result;
    -- the exponent is carried through unchanged.
    getBase :: (Atom, t) -> Maybe (BaseUnit, t)
    getBase (BaseAtom ty, i) = (, i) <$> isStrLitTy ty
    getBase _                = Nothing

-- | Test whether an atom is a base unit (but not necessarily a
-- *literal*, e.g. it could be @Base b@ for some variable @b@).
-- Only the constructor is inspected.
isBase :: Atom -> Bool
isBase (BaseAtom _) = True
isBase _            = False

-- | Test whether an atom is a literal base unit, i.e. a 'BaseAtom'
-- whose type 'isStrLitTy' recognises as a string literal.
isBaseLiteral :: Atom -> Bool
isBaseLiteral (BaseAtom ty) = isJust $ isStrLitTy ty
isBaseLiteral _             = False

-- | Test whether all exponents in a unit are divisible by an integer.
-- The dimensionless unit is trivially divisible.  The divisor must be
-- non-zero: 'rem' by zero raises a divide-by-zero exception as soon as
-- an exponent is examined.
divisible :: Integer -> NormUnit -> Bool
divisible i = Foldable.all (\ j -> j `rem` i == 0) . _NormUnit

-- | Test whether a type variable occurs in a unit (possibly under a
-- type family application).  Exponents are ignored; only the atoms
-- themselves are searched.
occurs :: TyVar -> NormUnit -> Bool
occurs a = any occursAtom . Map.keys . _NormUnit
  where
    -- Base units and family arguments are arbitrary types, so look in
    -- their free variables; a variable atom is compared directly.
    occursAtom :: Atom -> Bool
    occursAtom (BaseAtom ty)   = elemVarSet a $ tyCoVarsOfType ty
    occursAtom (VarAtom b)     = a == b
    occursAtom (FamAtom _ tys) = elemVarSet a $ tyCoVarsOfTypes tys


-- | View a unit as a list of atoms in order of ascending absolute
-- exponent.  'sortOn' is stable, so atoms whose exponents have the
-- same absolute value stay in the map's key order.
ascending :: NormUnit -> [(Atom, Integer)]
ascending = sortOn (abs . snd) . Map.toList . _NormUnit

-- | Drop a variable from a unit: the 'VarAtom' entry for the variable
-- is removed whatever its exponent, and all other atoms are kept.
-- Occurrences of the variable inside base units or type family
-- arguments are not affected.
leftover :: TyVar -> NormUnit -> NormUnit
leftover a = NormUnit . Map.delete (VarAtom a) . _NormUnit

-- | Divide all the exponents in a unit by an integer.  Division is
-- 'quot', which truncates towards zero; any exponent whose quotient is
-- zero is removed by 'mkNormUnitMap'.  The divisor must be non-zero.
divideExponents :: Integer -> NormUnit -> NormUnit
divideExponents i = mkNormUnitMap . Map.map (`quot` i) . _NormUnit

-- | Substitute the first unit for the variable in the second unit.
-- If the variable does not appear as an atom of @u@, @u@ is returned
-- unchanged.  Otherwise, with @i@ the variable's exponent in @u@, the
-- result is @v@ raised to the power @i@ multiplied by the rest of @u@.
substUnit :: TyVar -> NormUnit -> NormUnit -> NormUnit
substUnit a v u = case Map.lookup (VarAtom a) $ _NormUnit u of
                    Nothing -> u
                    Just i  -> (v ^: i) *: leftover a u