{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE UndecidableInstances #-}
module Agda.TypeChecking.SizedTypes.Syntax where
#if MIN_VERSION_base(4,11,0)
import Prelude hiding ( (<>), null )
#else
import Prelude hiding ( null )
#endif
import Data.Maybe
import Data.Foldable (Foldable)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (Traversable)
import Agda.TypeChecking.SizedTypes.Utils
import Agda.Utils.Functor
import Agda.Utils.Null
import Agda.Utils.Pretty
-- | An integer offset, wrapped in a newtype so that it can carry its own
--   class instances.  Equality, ordering, arithmetic and enumeration are
--   inherited from the underlying 'Int' via newtype deriving.
newtype Offset = O Int
  deriving (Eq, Ord, Num, Enum)
-- | Offsets are shown as the bare underlying integer.
--
--   The instance is defined through 'showsPrec' rather than 'show' so that
--   the surrounding precedence is forwarded to the 'Int' instance: a
--   negative offset nested under a constructor application is then
--   parenthesised (@Just (-1)@) instead of being rendered as the
--   ill-formed @Just -1@.  The result of a top-level 'show' is unchanged.
instance Show Offset where
  showsPrec p (O n) = showsPrec p n
-- | An offset is pretty-printed exactly like the integer it wraps.
instance Pretty Offset where
  pretty off = case off of
    O k -> pretty k
-- | The meet of two offsets is the smaller one.
instance MeetSemiLattice Offset where
  meet a b = min a b
-- | Two offsets are combined by applying 'plus' to the wrapped integers
--   and re-wrapping the result.
instance Plus Offset Offset Offset where
  plus (O a) (O b) = O $ a `plus` b
-- | Identifier type used for the rigid variables of 'SizeExpr',
--   represented by a 'String'.
newtype Rigid = RigidId { rigidId :: String }
  deriving (Eq, Ord)
-- | Shows a rigid identifier as a constructor application,
--   e.g. @RigidId \"i\"@.
--
--   Defined via 'showsPrec' so that the value is wrapped in parentheses
--   when it appears as an argument of another constructor
--   (@Just (RigidId \"i\")@ rather than @Just RigidId \"i\"@).
--   The result of a top-level 'show' is unchanged.
instance Show Rigid where
  showsPrec d (RigidId s) =
    showParen (d > 10) $ showString "RigidId " . showsPrec 11 s
-- | A rigid identifier is pretty-printed as its bare name.
instance Pretty Rigid where
  pretty (RigidId name) = text name
-- | Identifier type used for the flexible variables of 'SizeExpr',
--   represented by a 'String'.
newtype Flex = FlexId { flexId :: String }
  deriving (Eq, Ord)
-- | Shows a flexible identifier as a constructor application,
--   e.g. @FlexId \"X\"@.
--
--   Defined via 'showsPrec' so that the value is wrapped in parentheses
--   when it appears as an argument of another constructor
--   (@Just (FlexId \"X\")@ rather than @Just FlexId \"X\"@).
--   The result of a top-level 'show' is unchanged.
instance Show Flex where
  showsPrec d (FlexId s) =
    showParen (d > 10) $ showString "FlexId " . showsPrec 11 s
-- | A flexible identifier is pretty-printed as its bare name.
instance Pretty Flex where
  pretty (FlexId name) = text name
-- | Size expressions, parameterised by the types of rigid and flexible
--   variables.
--
--   Note that the field selectors are partial: 'offset' is undefined on
--   'Infty', and 'rigid' \/ 'flex' only exist on their own constructor.
data SizeExpr' rigid flex
  = Const { offset :: Offset }
    -- ^ A constant.
  | Rigid { rigid :: rigid, offset :: Offset }
    -- ^ A rigid variable together with an offset.
  | Infty
    -- ^ Infinity (pretty-printed as @∞@).
  | Flex { flex :: flex, offset :: Offset }
    -- ^ A flexible variable together with an offset.
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-- | Size expressions over the concrete identifier types 'Rigid' and 'Flex'.
type SizeExpr = SizeExpr' Rigid Flex
-- | Comparison operator relating the two sides of a 'Constraint''.
data Cmp
  = Lt
    -- ^ Strict comparison (pretty-printed as @<@).
  | Le
    -- ^ Non-strict comparison (pretty-printed as @≤@).
  deriving (Show, Eq, Bounded, Enum)
-- | Composing two comparisons yields the smaller one with respect to the
--   'Ord' instance of 'Cmp'; 'Le' acts as the unit of composition.
instance Dioid Cmp where
  unitCompose = Le
  compose a b = min a b
-- | Comparison operators are ordered with 'Lt' below 'Le'.
--
--   Only @Le <= Lt@ is false; every other combination holds.
instance Ord Cmp where
  a <= b = case (a, b) of
    (Le, Lt) -> False
    _        -> True
-- | The meet of two comparison operators is the smaller one with respect
--   to the 'Ord' instance of 'Cmp'.
instance MeetSemiLattice Cmp where
  meet a b = min a b
-- | The top element is 'Le', the greatest value in the 'Ord' instance
--   of 'Cmp'.
instance Top Cmp where
  top = Le
-- | A constraint relating two size expressions by a comparison operator,
--   pretty-printed as @leftExpr cmp rightExpr@.
data Constraint' rigid flex = Constraint
  { leftExpr :: SizeExpr' rigid flex
    -- ^ Left-hand side of the comparison.
  , cmp :: Cmp
    -- ^ The comparison operator.
  , rightExpr :: SizeExpr' rigid flex
    -- ^ Right-hand side of the comparison.
  }
  deriving (Show, Functor, Foldable, Traversable)
-- | Constraints over the concrete identifier types 'Rigid' and 'Flex'.
type Constraint = Constraint' Rigid Flex
-- | Polarity attached to a flexible variable; pretty-printed as @-@
--   ('Least') or @+@ ('Greatest').
--   NOTE(review): presumably selects whether the least or the greatest
--   solution is wanted for the variable -- confirm against the solver.
data Polarity = Least | Greatest
  deriving (Eq, Ord)
-- | A polarity paired with the flexible variable it is assigned to.
data PolarityAssignment flex = PolarityAssignment Polarity flex
-- | Finite map from flexible variables to their polarity.
type Polarities flex = Map flex Polarity
-- | The polarity map without any entries.
emptyPolarities :: Polarities flex
emptyPolarities = Map.empty
-- | Collect a list of polarity assignments into a polarity map, keyed by
--   the flexible variable.  Built with 'Map.fromList', so if a variable is
--   assigned more than once the last assignment in the list wins.
polaritiesFromAssignments :: Ord flex => [PolarityAssignment flex] -> Polarities flex
polaritiesFromAssignments assignments =
  Map.fromList [ (x, p) | PolarityAssignment p x <- assignments ]
-- | Look up the polarity of a flexible variable.
--   Variables without an entry default to 'Least'.
getPolarity :: Ord flex => Polarities flex -> flex -> Polarity
getPolarity pols x =
  case Map.lookup x pols of
    Just p  -> p
    Nothing -> Least
-- | A finite map from flexible variables to size expressions.
--   Applying it with 'subst' replaces the flexible variables it binds and
--   leaves all others untouched.
newtype Solution rigid flex = Solution { theSolution :: Map flex (SizeExpr' rigid flex) }
  deriving (Show, Null)
-- | A solution is printed as a list of bindings of the form @x := e@,
--   in the key order of the underlying map.
instance (Pretty r, Pretty f) => Pretty (Solution r f) where
  pretty (Solution bindings) =
    prettyList
      [ pretty var <+> text ":=" <+> pretty expr
      | (var, expr) <- Map.toList bindings
      ]
-- | The solution that binds no flexible variable.
emptySolution :: Solution r f
emptySolution = Solution Map.empty
-- | Things to which a 'Solution' can be applied.
class Substitute r f a where
  -- | Apply the solution to the given value, returning a value of the
  --   same type.
  subst :: Solution r f -> a -> a
-- | Substitution into a size expression: a flexible variable bound by the
--   solution is replaced by its value, with the variable's own offset added
--   on top.  Every other expression is returned unchanged.
instance Ord f => Substitute r f (SizeExpr' r f) where
  subst (Solution sol) e
    | Flex x n <- e
    , Just e' <- Map.lookup x sol = e' `plus` n
    | otherwise                   = e
-- | Substitution into a constraint: applied to both sides, the comparison
--   operator is kept as is.
instance Ord f => Substitute r f (Constraint' r f) where
  subst sol c = c
    { leftExpr  = subst sol (leftExpr c)
    , rightExpr = subst sol (rightExpr c)
    }
-- | Substitution into a list: applied to every element.
instance Substitute r f a => Substitute r f [a] where
  subst sol xs = [ subst sol x | x <- xs ]
-- | Substitution into a map: applied to every value, keys are untouched.
instance Substitute r f a => Substitute r f (Map k a) where
  subst sol m = fmap (subst sol) m
-- | Substitution into a solution: the first solution is applied to every
--   expression bound by the second one; the set of bound variables of the
--   second solution does not change.
instance Ord f => Substitute r f (Solution r f) where
  subst s (Solution bindings) = Solution (subst s bindings)
-- | Add an offset to a size expression.  The offset of a constant, rigid
--   or flexible expression is increased; 'Infty' absorbs any offset.
instance Plus (SizeExpr' r f) Offset (SizeExpr' r f) where
  plus Infty       _ = Infty
  plus (Const n)   k = Const (n + k)
  plus (Rigid i n) k = Rigid i (n + k)
  plus (Flex x n)  k = Flex x (n + k)
-- | A constraint transformer: maps a constraint either to an error message
--   ('Left') or to a list of replacement constraints ('Right').
type CTrans r f = Constraint' r f -> Either String [Constraint' r f]
-- | Simplify a single size constraint.
--
--   The result is one of
--
--   * @Left msg@ if the constraint is recognised as inconsistent,
--   * @Right []@ if the constraint holds trivially,
--   * @Right [c']@ with a constraint @c'@ in which at most one side carries
--     a non-zero offset, or
--   * whatever the supplied transformer @test@ returns; @test@ is consulted
--     for the constraints that have a rigid variable on the left and no
--     flexible variable on the right, or a constant on the left and a
--     rigid variable on the right (unless already decided here).
--
--   The case analysis proceeds by the shape of the right-hand side.
simplify1 :: (Pretty f, Pretty r, Eq r) => CTrans r f -> CTrans r f
simplify1 test c =
  case c of
    -- Right-hand side is infinity.
    Constraint _           Le Infty -> trivial
    Constraint Const{}     Lt Infty -> trivial
    Constraint Infty       Lt Infty -> inconsistent
    Constraint (Rigid i _) Lt Infty -> test (Constraint (Rigid i 0) Lt Infty)
    Constraint (Flex x _)  Lt Infty -> keep (Constraint (Flex x 0) Lt Infty)

    -- Right-hand side is a constant.
    Constraint (Const n) op (Const m)
      | compareOffset n op m -> trivial
      | otherwise            -> inconsistent
    Constraint Infty _ Const{} -> inconsistent
    Constraint (Rigid i n) op (Const m)
      | compareOffset n op m ->
          test (Constraint (Rigid i 0) Le (Const (m - n - strictness op)))
      | otherwise            -> inconsistent
    Constraint (Flex x n) op (Const m)
      | compareOffset n op m ->
          keep (Constraint (Flex x 0) Le (Const (m - n - strictness op)))
      | otherwise            -> inconsistent

    -- Right-hand side is a rigid variable.
    Constraint Infty _ Rigid{} -> inconsistent
    Constraint (Const m) op (Rigid i n)
      | compareOffset m op n -> trivial
      | otherwise            -> test (Constraint (Const (m - n)) op (Rigid i 0))
    Constraint (Rigid j m) op (Rigid i n)
      -- Same variable on both sides: only the offsets matter.
      | i == j    -> if compareOffset m op n then trivial else inconsistent
      | otherwise -> test c
    Constraint (Flex x m) op (Rigid i n)
      | compareOffset m op n ->
          keep (Constraint (Flex x 0) Le (Rigid i (n - m - strictness op)))
      | otherwise ->
          keep (Constraint (Flex x (m - n + strictness op)) Le (Rigid i 0))

    -- Right-hand side is a flexible variable.
    Constraint Infty Le (Flex x _) -> keep (Constraint Infty Le (Flex x 0))
    Constraint Infty Lt Flex{}     -> inconsistent
    Constraint (Const m) op (Flex x n)
      | compareOffset m op n -> trivial
      | otherwise ->
          keep (Constraint (Const (m - n + strictness op)) Le (Flex x 0))
    Constraint (Rigid i m) op (Flex x n)
      | compareOffset m op n -> keep (Constraint (Rigid i 0) op (Flex x (n - m)))
      | otherwise            -> keep (Constraint (Rigid i (m - n)) op (Flex x 0))
    Constraint (Flex y m) op (Flex x n)
      | compareOffset m op n -> keep (Constraint (Flex y 0) op (Flex x (n - m)))
      | otherwise            -> keep (Constraint (Flex y (m - n)) op (Flex x 0))
  where
    -- The constraint cannot hold: report it.
    inconsistent = Left $ "size constraint " ++ prettyShow c ++ " is inconsistent"
    -- The constraint holds trivially: nothing remains.
    trivial = Right []
    -- A single simplified constraint remains.
    keep c' = Right [c']
    -- Correction applied when a strict comparison is turned into a
    -- non-strict one: 0 for 'Le', 1 for 'Lt'.
    strictness :: Cmp -> Offset
    strictness op = ifLe op 0 1
-- | Case distinction on a comparison operator: returns the first
--   alternative for 'Le' and the second one for 'Lt'.
ifLe :: Cmp -> a -> a -> a
ifLe op onLe onLt =
  case op of
    Le -> onLe
    Lt -> onLt
-- | Decide whether two offsets are related by the given comparison
--   operator: @n <= m@ for 'Le', @n < m@ for 'Lt'.
compareOffset :: Offset -> Cmp -> Offset -> Bool
compareOffset n op m =
  case op of
    Le -> n <= m
    Lt -> n < m
-- | Constants print as their number, infinity as @∞@, and variables as
--   their name, followed by @+n@ whenever the offset @n@ is not zero.
instance (Pretty r, Pretty f) => Pretty (SizeExpr' r f) where
  pretty e =
    case e of
      Const n   -> pretty n
      Infty     -> text "∞"
      Rigid i n -> withOffset (pretty i) n
      Flex x n  -> withOffset (pretty x) n
    where
      -- Append the offset to an already printed variable, omitting a
      -- zero offset.
      withOffset d 0 = d
      withOffset d n = d <> text ("+" ++ show n)
-- | 'Least' prints as @-@ and 'Greatest' as @+@.
instance Pretty Polarity where
  pretty pol = text $ case pol of
    Least    -> "-"
    Greatest -> "+"
-- | A polarity assignment prints as the polarity sign immediately followed
--   by the variable, without a separating space.
instance Pretty flex => Pretty (PolarityAssignment flex) where
  pretty (PolarityAssignment sign var) = pretty sign <> pretty var
-- | 'Le' prints as @≤@ and 'Lt' as @<@.
instance Pretty Cmp where
  pretty op = text $ case op of
    Le -> "≤"
    Lt -> "<"
-- | A constraint prints as left-hand side, comparison operator and
--   right-hand side, separated by single spaces.
instance (Pretty r, Pretty f) => Pretty (Constraint' r f) where
  pretty c = pretty (leftExpr c) <+> pretty (cmp c) <+> pretty (rightExpr c)
-- | Things whose offset can be checked for validity.
class ValidOffset a where
  -- | Is the offset of the given value valid?
  validOffset :: a -> Bool
-- | An offset is valid if it is not negative.
instance ValidOffset Offset where
  validOffset n = n >= 0
-- | 'Infty' carries no offset and is always valid; any other size
--   expression is valid if its offset is.
instance ValidOffset (SizeExpr' r f) where
  validOffset Infty = True
  validOffset e     = validOffset (offset e)
-- | Things whose offset can be truncated.
class TruncateOffset a where
  -- | Truncate the offset of the given value.
  truncateOffset :: a -> a
-- | Negative offsets are replaced by zero; non-negative offsets are
--   returned unchanged.
instance TruncateOffset Offset where
  truncateOffset = max 0
-- | Truncate the offset carried by a size expression.
--   'Infty' has no offset and is returned as is.
instance TruncateOffset (SizeExpr' r f) where
  truncateOffset Infty       = Infty
  truncateOffset (Const n)   = Const (truncateOffset n)
  truncateOffset (Rigid i n) = Rigid i (truncateOffset n)
  truncateOffset (Flex x n)  = Flex x (truncateOffset n)
-- | Things from which the set of rigid variables can be collected.
class Rigids r a where
  -- | The set of rigid variables occurring in the given value.
  rigids :: a -> Set r
-- | The rigid variables of a list are the union of the rigid variables of
--   its elements.
instance (Ord r, Rigids r a) => Rigids r [a] where
  rigids xs = foldr (\ x acc -> rigids x `Set.union` acc) Set.empty xs
-- | Only the 'Rigid' constructor contributes a rigid variable; every
--   other size expression yields the empty set.
instance Rigids r (SizeExpr' r f) where
  rigids e =
    case e of
      Rigid x _ -> Set.singleton x
      _         -> Set.empty
-- | The rigid variables of a constraint are those of its two sides.
instance Ord r => Rigids r (Constraint' r f) where
  rigids c = rigids (leftExpr c) `Set.union` rigids (rightExpr c)
-- | Things from which the set of flexible variables can be collected.
--   The functional dependency states that the type of the flexible
--   variables is determined by the type of the value.
class Flexs flex a | a -> flex where
  -- | The set of flexible variables occurring in the given value.
  flexs :: a -> Set flex
-- | The flexible variables of a list are the union of the flexible
--   variables of its elements.
instance (Ord flex, Flexs flex a) => Flexs flex [a] where
  flexs xs = foldr (\ x acc -> flexs x `Set.union` acc) Set.empty xs
-- | Only the 'Flex' constructor contributes a flexible variable; every
--   other size expression yields the empty set.
instance Flexs flex (SizeExpr' rigid flex) where
  flexs e =
    case e of
      Flex x _ -> Set.singleton x
      _        -> Set.empty
-- | The flexible variables of a constraint are those of its two sides.
instance (Ord flex) => Flexs flex (Constraint' rigid flex) where
  flexs c = flexs (leftExpr c) `Set.union` flexs (rightExpr c)