{-# LANGUAGE MultiWayIf #-}
module MatchSigs.Matching.Env
  ( Env
  , (/\)
  , checkOr
  , checkAnd
  , introVars
  , tryAssignVar
  , initEnv
  ) where

import           Control.Monad.State.Strict
import           Data.List
import qualified Data.IntMap.Strict as IM

import           MatchSigs.Sig (FreeVarIdx)

-- | A var level, as tracked by the @level@ field of 'Env'.
type Level = Int
-- | Map from a free var index to the level recorded for it.
type VarLevel = IM.IntMap Int
-- | Map from a free var index to the free var index it is assigned to.
type VarAssign = IM.IntMap FreeVarIdx

-- | Context for matching the free vars in two 'Sigs'
data Env =
  MkEnv { level    :: !Level     -- ^ current var level
        , vass     :: !VarAssign -- ^ map from B vars to A vars
        , vlA      :: !VarLevel  -- ^ the level at which an A var was introduced
        , vlB      :: !VarLevel  -- ^ the level at which a B var was introduced
        }

-- | The starting context: level 0, no assignments and no vars introduced
-- on either side.
initEnv :: Env
initEnv =
  MkEnv { level    = 0
        , vass     = mempty
        , vlA      = mempty
        , vlB      = mempty
        }

-- | Identify var from one sig with var in other sig
--
-- The first argument is the A var, the second the B var. The result is
--
-- * 'True' if the B var is already assigned to this A var (the 'Env' is
--   left unchanged);
-- * 'True' if the B var is unassigned and both vars were introduced at the
--   same level, in which case the assignment is recorded in the 'Env';
-- * 'False' otherwise (B var assigned to a different A var, either var not
--   yet introduced, or levels differ), leaving the 'Env' unchanged.
tryAssignVar :: FreeVarIdx
             -> FreeVarIdx
             -> State Env Bool
tryAssignVar ai bi = do
  env <- get
  let mb = IM.lookup bi $ vass env
  if -- already assigned
     | Just x <- mb
     , x == ai -> pure True

     -- not assigned and levels match
     | Nothing <- mb
     , Just lA <- IM.lookup ai $ vlA env
     , Just lB <- IM.lookup bi $ vlB env
     , lA == lB
     -> do put env { vass = IM.insert bi ai $ vass env }
           pure True

     | otherwise -> pure False

-- | Add vars from both sigs to the context, accounting for level
--
-- Both lists must have the same length, otherwise the result is 'False'
-- and the 'Env' is left unchanged. When they do, every var in each list is
-- recorded at the current level and the current level is then incremented.
-- If both lists are empty the result is 'True' and the level is not
-- incremented.
introVars :: [FreeVarIdx]
          -> [FreeVarIdx]
          -> State Env Bool
introVars [] [] = pure True
introVars va vb
  | length va == length vb
  = (True <$) . modify' $ \env ->
      let lvl = level env
          -- The union is left-biased, so a var that is re-introduced takes
          -- the new level rather than keeping its previous one.
       in env { vlA = IM.fromList (zip va $ repeat lvl) <> vlA env
              , vlB = IM.fromList (zip vb $ repeat lvl) <> vlB env
              , level = lvl + 1
              }
  | otherwise = pure False

-- | Logical conjunction
--
-- Runs the left check first. The right check is only run (continuing from
-- the state left by the first) when the left one yields 'True'; otherwise
-- the result is 'False' with the state as left by the first check.
(/\) :: State env Bool
     -> State env Bool
     -> State env Bool
a /\ b = do
  r <- a
  if r then b else pure False

-- | Conjunction of a list of checks, combined left to right with '/\'.
--
-- Checks are run in order, threading the state, until one yields 'False';
-- the remaining checks are then not run. An empty list yields 'True'.
checkAnd :: [State Env Bool]
         -> State Env Bool
checkAnd = foldl' (/\) (pure True)

-- | Logical disjunction. Discards state if False
--
-- Both alternatives are run from the same incoming state. If the left one
-- yields 'True' its result state is kept and the right one is never
-- evaluated. Otherwise the left one's state changes are thrown away and
-- the right one is tried from the original state. If neither succeeds the
-- result is 'False' and the incoming state is returned unchanged.
(\/) :: State env Bool
     -> State env Bool
     -> State env Bool
a \/ b = state $ \env ->
  let (ar, envA) = runState a env
      -- let-bound, hence lazy: only forced when the left alternative fails
      (br, envB) = runState b env
   in if ar
        then (True, envA)
        else if br
          then (True, envB)
          else (False, env)

-- | Disjunction of a list of checks, combined left to right with '\/'.
--
-- Each alternative is tried from the incoming state. The result and state
-- of the first alternative that yields 'True' are kept and later ones are
-- not run. If none succeeds (or the list is empty) the result is 'False'
-- and the incoming state is returned unchanged.
checkOr :: [State env Bool]
        -> State env Bool
checkOr = foldl' (\/) (pure False)