{-# LANGUAGE FlexibleInstances, UndecidableInstances, FlexibleContexts, TemplateHaskell #-}
module AST.Term.Row
( RowConstraints(..), RowKey
, RowExtend(..), eKey, eVal, eRest, KWitness(..)
, FlatRowExtends(..), freExtends, freRest
, flattenRow, flattenRowExtend, unflattenRow
, verifyRowExtendConstraints, rowExtendStructureMismatch
, rowElementInfer
) where
import AST
import AST.TH.Internal.Instances (makeCommonInstances)
import AST.Unify
import AST.Unify.Lookup (semiPruneLookup)
import AST.Unify.New (newTerm, newUnbound)
import AST.Unify.Term (UTerm(..), _UTerm, uBody)
import Control.DeepSeq (NFData)
import Control.Lens (Prism', Lens', makeLenses, contains)
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Control.Monad (foldM)
import Data.Binary (Binary)
import Data.Foldable (sequenceA_)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import Generics.Constraints (Constraints, makeDerivings, makeInstances)
import GHC.Generics (Generic)
import Text.Show.Combinators ((@|), showCon)
import Prelude.Compat
-- | Type constraints for row types.
--
-- On top of the general 'TypeConstraints', the constraints of a row carry
-- a set of 'forbidden' keys: 'verifyRowExtendConstraints' rejects any
-- extension whose key is in that set.
class
    (Ord (RowConstraintsKey c), TypeConstraints c) =>
    RowConstraints c where
    -- | The type of the keys that label the fields of the row.
    type RowConstraintsKey c
    -- | Keys that the row must not be extended with.
    forbidden :: Lens' c (Set (RowConstraintsKey c))
-- | The key type of a row type @t@, obtained through its type constraints.
type RowKey t = RowConstraintsKey (TypeConstraintsOf t)
-- | A row extended by a single field: the field's key, the field's value,
-- and the rest of the row (which may itself be another extension).
--
-- Lenses 'eKey', 'eVal' and 'eRest' are generated for the fields.
data RowExtend key val rest k = RowExtend
    { _eKey :: key
      -- ^ The key of the added field.
    , _eVal :: k # val
      -- ^ The value stored under the key.
    , _eRest :: k # rest
      -- ^ The remainder of the row.
    } deriving Generic
-- | A row flattened into all of its extensions at once, as produced by
-- 'flattenRow' / 'flattenRowExtend' and consumed by 'unflattenRow'.
--
-- Lenses 'freExtends' and 'freRest' are generated for the fields.
data FlatRowExtends key val rest k = FlatRowExtends
    { _freExtends :: Map key (k # val)
      -- ^ Every field of the row, keyed by its key.
    , _freRest :: k # rest
      -- ^ What remains once no further extension is found.
    } deriving Generic
-- Lenses for the record fields: eKey, eVal, eRest, freExtends, freRest.
makeLenses ''RowExtend
makeLenses ''FlatRowExtends
-- NOTE(review): makeCommonInstances comes from AST.TH.Internal.Instances;
-- presumably it generates the standard class instances for FlatRowExtends.
-- RowExtend gets its instances individually below, alongside its
-- hand-written Show instance.
makeCommonInstances [''FlatRowExtends]
makeZipMatch ''RowExtend
-- Instances over the @k@ parameter (and their base classes) for both types.
makeKTraversableApplyAndBases ''RowExtend
makeKTraversableApplyAndBases ''FlatRowExtends
-- Eq / Ord and Binary / NFData for RowExtend via generic-constraints.
makeDerivings [''Eq, ''Ord] [''RowExtend]
makeInstances [''Binary, ''NFData] [''RowExtend]
-- | Shows a 'RowExtend' as the constructor @RowExtend@ applied to its key,
-- value and rest (built with 'showCon' and '@|'), respecting the given
-- precedence.
instance
    Constraints (RowExtend key val rest k) Show =>
    Show (RowExtend key val rest k) where
    showsPrec prec (RowExtend key val rest) = render prec
        where
            render = showCon "RowExtend" @| key @| val @| rest
{-# INLINE flattenRowExtend #-}
-- | Flatten a row that starts with the given extension.
--
-- The extension's field is added to the flattened form of its rest,
-- which is obtained by 'flattenRow' using @nextExtend@ to discover
-- further extensions.
--
-- A well-formed row never repeats a key. If it does, @error "Colliding
-- keys"@ is raised as soon as the resulting field map is evaluated,
-- rather than being left as a hidden bottom value inside the map.
flattenRowExtend ::
    (Ord key, Monad m) =>
    (Tree v rest -> m (Maybe (Tree (RowExtend key val rest) v))) ->
    Tree (RowExtend key val rest) v ->
    m (Tree (FlatRowExtends key val rest) v)
flattenRowExtend nextExtend (RowExtend k v rest) =
    flattenRow nextExtend rest
    <&> freExtends %~ addField
    where
        -- Data.Map is lazy in its values, so combining with an 'error'
        -- function only stored an error thunk as the colliding key's
        -- value. An explicit membership check makes the collision fail
        -- as soon as the map itself is forced.
        addField fields
            | Map.member k fields = error "Colliding keys"
            | otherwise = Map.insert k v fields
{-# INLINE flattenRow #-}
-- | Flatten a row by repeatedly asking @nextExtend@ whether the current
-- row is an extension.
--
-- When it is not, that row becomes the 'freRest' of the result and the
-- collected fields start out empty. Otherwise flattening continues with
-- 'flattenRowExtend'.
flattenRow ::
    (Ord key, Monad m) =>
    (Tree v rest -> m (Maybe (Tree (RowExtend key val rest) v))) ->
    Tree v rest ->
    m (Tree (FlatRowExtends key val rest) v)
flattenRow nextExtend row =
    do
        step <- nextExtend row
        case step of
            Nothing ->
                pure FlatRowExtends { _freExtends = mempty, _freRest = row }
            Just extension -> flattenRowExtend nextExtend extension
{-# INLINE unflattenRow #-}
-- | Rebuild a row from its flattened form.
--
-- Starting from the flat row's rest, each field is wrapped around the
-- accumulated row with @mkExtend@. Fields are visited in ascending key
-- order, so the greatest key ends up outermost.
unflattenRow ::
    Monad m =>
    (Tree (RowExtend key val rest) v -> m (Tree v rest)) ->
    Tree (FlatRowExtends key val rest) v -> m (Tree v rest)
unflattenRow mkExtend (FlatRowExtends fields rest) =
    foldM wrap rest (Map.toAscList fields)
    where
        wrap inner (key, val) = mkExtend (RowExtend key val inner)
{-# INLINE verifyRowExtendConstraints #-}
-- | Check a row extension against the row's constraints and, when it is
-- allowed, attach the constraints each child must satisfy.
--
-- Returns 'Nothing' when the extension's key is in the 'forbidden' set.
-- Otherwise:
--
-- * the value gets the row constraints with an empty 'forbidden' set,
--   converted by @toChildC@;
-- * the rest gets the row constraints with the key added to 'forbidden',
--   so verifying an extension of the rest by the same key yields
--   'Nothing'.
verifyRowExtendConstraints ::
    RowConstraints (TypeConstraintsOf rowTyp) =>
    (TypeConstraintsOf rowTyp -> TypeConstraintsOf valTyp) ->
    TypeConstraintsOf rowTyp ->
    Tree (RowExtend (RowKey rowTyp) valTyp rowTyp) k ->
    Maybe (Tree (RowExtend (RowKey rowTyp) valTyp rowTyp) (WithConstraint k))
verifyRowExtendConstraints toChildC c (RowExtend key val rest) =
    if c ^. forbidden . contains key
        then Nothing
        else Just (RowExtend key (WithConstraint valC val) (WithConstraint restC rest))
    where
        valC = toChildC (c & forbidden .~ mempty)
        restC = c & forbidden . contains key .~ True
{-# INLINE rowExtendStructureMismatch #-}
-- | Unify two row-extension terms field by field.
--
-- Both rows are first flattened. Variables are looked up with
-- 'semiPruneLookup', and flattening continues while the looked-up term
-- matches @extend@. Fields whose keys occur in both rows are unified
-- pairwise with @match@.
--
-- A fresh unbound row variable, constrained by @c0 <> c1@, then stands
-- for everything neither row spells out. Each row's rest is unified with
-- the other row's exclusive fields laid over that fresh variable.
--
-- NOTE(review): presumably used when the two extensions do not line up
-- directly (e.g. their outermost keys differ) — confirm with callers.
rowExtendStructureMismatch ::
    Ord key =>
    ( Unify m rowTyp
    , Unify m valTyp
    ) =>
    (forall c. Unify m c => Tree (UVarOf m) c -> Tree (UVarOf m) c -> m (Tree (UVarOf m) c)) ->
    Prism' (Tree rowTyp (UVarOf m))
        (Tree (RowExtend key valTyp rowTyp) (UVarOf m)) ->
    (TypeConstraintsOf rowTyp, Tree (RowExtend key valTyp rowTyp) (UVarOf m)) ->
    (TypeConstraintsOf rowTyp, Tree (RowExtend key valTyp rowTyp) (UVarOf m)) ->
    m ()
rowExtendStructureMismatch match extend (c0, r0) (c1, r1) =
    do
        lhs <- flattenRowExtend nextExtend r0
        rhs <- flattenRowExtend nextExtend r1
        -- Fields present on both sides must agree.
        sequenceA_ (Map.intersectionWith match (lhs ^. freExtends) (rhs ^. freExtends))
        -- Shared tail standing for the fields neither side mentions.
        sharedRest <- newVar binding (UUnbound (c0 <> c1))
        let onlyIn this other =
                unflattenRow mkExtend FlatRowExtends
                    { _freExtends = (this ^. freExtends) `Map.difference` (other ^. freExtends)
                    , _freRest = sharedRest
                    }
                >>= match (other ^. freRest)
        _ <- onlyIn lhs rhs
        _ <- onlyIn rhs lhs
        pure ()
    where
        mkExtend ext = newTerm (extend # ext)
        nextExtend v = (^? Lens._2 . _UTerm . uBody . extend) <$> semiPruneLookup v
{-# INLINE rowElementInfer #-}
-- | Create a fresh row that has field @k@ in front of an unknown rest.
--
-- The rest is a new unbound variable whose constraints are the current
-- scope constraints with @k@ added to 'forbidden'. The field's value is
-- another new unbound variable.
--
-- Returns the field's value variable and the new row term (built by
-- @extendToRow@).
rowElementInfer ::
    ( Unify m valTyp
    , Unify m rowTyp
    , RowConstraints (TypeConstraintsOf rowTyp)
    ) =>
    (Tree (RowExtend (RowKey rowTyp) valTyp rowTyp) (UVarOf m) -> Tree rowTyp (UVarOf m)) ->
    RowKey rowTyp ->
    m (Tree (UVarOf m) valTyp, Tree (UVarOf m) rowTyp)
rowElementInfer extendToRow k =
    do
        scopeC <- scopeConstraints
        restVar <- newVar binding (UUnbound (scopeC & forbidden . contains k .~ True))
        fieldVar <- newUnbound
        rowVar <- newTerm (extendToRow (RowExtend k fieldVar restVar))
        pure (fieldVar, rowVar)