{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE OverloadedStrings #-}
module Language.Fixpoint.Defunctionalize
( defunctionalize
, Defunc(..)
, defuncAny
) where
import qualified Data.HashMap.Strict as M
import Data.Hashable
import Control.Monad.State
import Language.Fixpoint.Misc (fM, secondM, mapSnd)
import Language.Fixpoint.Solver.Sanitize (symbolEnv)
import Language.Fixpoint.Types hiding (allowHO)
import Language.Fixpoint.Types.Config
import Language.Fixpoint.Types.Visitor (mapMExpr)
-- | Defunctionalize every component of a constraint set.
--
-- The traversal starts from the state built by 'makeInitDFState', and only
-- the final result is kept; the 'DF' state is discarded.
defunctionalize :: (Fixpoint a) => Config -> SInfo a -> SInfo a
defunctionalize cfg si = evalState (defunc si) initialState
  where
    initialState = makeInitDFState cfg si
-- | Defunctionalize a single value under a caller-supplied symbol environment.
--
-- No binder ids are marked for full treatment: the binder set starts as
-- 'emptyIBindEnv'.
defuncAny :: Defunc a => Config -> SymEnv -> a -> a
defuncAny cfg env e = evalState (defunc e) startState
  where
    startState = makeDFState cfg env emptyIBindEnv
-- | Rewrite an expression with 'defuncExpr' only when 'dfHO' is set.
-- Otherwise the expression is returned unchanged.
txExpr :: Expr -> DF Expr
txExpr e = do
  enabled <- gets dfHO
  if not enabled
    then return e
    else defuncExpr e
-- | Two passes over an expression, both via 'mapMExpr':
--
-- 1. 'reBind' gives every lambda binder a fresh symbol.
-- 2. 'normalizeLams' renumbers lambda binders canonically.
defuncExpr :: Expr -> DF Expr
defuncExpr e = mapMExpr reBind e >>= mapMExpr (fM normalizeLams)
-- | Replace the binder of a lambda with a fresh symbol of the same sort.
--
-- Occurrences of the old binder in the body are substituted with the new
-- variable. Non-lambda expressions are returned unchanged.
reBind :: Expr -> DF Expr
reBind (ELam (x, s) body) = do
  y <- freshSym s
  return (ELam (y, s) (subst1 body (x, EVar y)))
reBind other = return other
-- | Build a lambda over @e@ whose binder is @lamArgSymbol i@ of sort @t@.
--
-- Occurrences of the old binder @x@ in @e@ are replaced by the new binder,
-- wrapped in an explicit sort cast to @t@.
shiftLam :: Int -> Symbol -> Sort -> Expr -> Expr
shiftLam i x t e = ELam (binder, t) (subst1 e (x, castVar))
  where
    binder  = lamArgSymbol i
    castVar = ECst (EVar binder) t
-- | Renumber lambda binders starting from index 1, discarding the next-free
-- index that 'normalizeLamsFromTo' also returns.
normalizeLams :: Expr -> Expr
normalizeLams = snd . normalizeLamsFromTo 1
-- | Rename lambda binders to canonical @lamArgSymbol j@ names, starting at @i@.
--
-- The body of a lambda is numbered first, and the lambda itself then takes
-- the next free index. Innermost lambdas therefore receive the smallest
-- indices. Both sides of an 'EApp' are numbered from the same starting index,
-- and the larger resulting index is propagated upward.
--
-- Returns the next unused index together with the rewritten expression.
--
-- NOTE(review): only 'ELam', 'EApp', 'ECst' and 'PAll' are descended into.
-- Any other constructor is returned as-is with index @i@, so lambdas nested
-- beneath it are not counted here. Confirm this cannot make an enclosing
-- binder reuse an index that is already bound inside it.
normalizeLamsFromTo :: Int -> Expr -> (Int, Expr)
normalizeLamsFromTo i = go
  where
    go (ELam (y, sy) body) =
      -- Use a distinct name (j) for the index returned by the body, rather
      -- than shadowing the outer starting index i.
      let (j, body') = go body
      in  (j + 1, shiftLam j y sy body')
    go (EApp e1 e2) =
      let (i1, e1') = go e1
          (i2, e2') = go e2
      in  (max i1 i2, EApp e1' e2')
    go (ECst e s) = mapSnd (`ECst` s) (go e)
    go (PAll binds e) = mapSnd (PAll binds) (go e)
    go e = (i, e)
-- | Values that can be rebuilt inside the 'DF' state monad during
-- defunctionalization.
--
-- Each instance decides which parts it visits. Running in 'DF' allows an
-- instance to draw fresh symbols (see 'freshSym') while traversing.
class Defunc a where
  -- | Rebuild the value in the 'DF' monad.
  defunc :: a -> DF a
-- | Defunctionalize the constraints, well-formedness constraints, literal
-- environments, bind environment and asserts of a 'GInfo'. All other fields
-- are left as they are.
instance (Defunc (c a), TaggedC c a) => Defunc (GInfo c a) where
  -- Applicative sequencing runs the effects left to right, in the order
  -- cm, ws, gLits, dLits, bs, asserts.
  defunc fi =
    rebuild <$> defunc (cm fi)
            <*> defunc (ws fi)
            <*> defunc (gLits fi)
            <*> defunc (dLits fi)
            <*> defunc (bs fi)
            <*> defunc (asserts fi)
    where
      rebuild newCm newWs newGLits newDLits newBs newAsserts =
        fi { cm      = newCm
           , ws      = newWs
           , gLits   = newGLits
           , dLits   = newDLits
           , bs      = newBs
           , asserts = newAsserts
           }
-- | Defunctionalize the payload of a trigger; the trigger tag is kept as is.
instance (Defunc a) => Defunc (Triggered a) where
  defunc (TR trig body) = fmap (TR trig) (defunc body)
-- | Only the right-hand side of a simple constraint is defunctionalized.
instance Defunc (SimpC a) where
  defunc sc = setRhs <$> defunc (_crhs sc)
    where
      setRhs rhs = sc { _crhs = rhs }
-- | Defunctionalize the sort in 'wrft'. For a 'GWfC', the attached
-- 'wexpr' is defunctionalized as well, after the sort.
instance Defunc (WfC a) where
  defunc wf = do
    let (x, t, k) = wrft wf
    t' <- defunc t
    let newWrft = (x, t', k)
    case wf of
      WfC {}  -> return (wf { wrft = newWrft })
      GWfC {} -> do
        e' <- defunc (wexpr wf)
        return (wf { wrft = newWrft, wexpr = e' })
-- | Defunctionalize the refinement; the sort component is kept unchanged.
instance Defunc SortedReft where
  defunc (RR srt rft) = do
    rft' <- defunc rft
    return (RR srt rft')
-- | Defunctionalize the second component of the pair; the symbol is kept.
instance Defunc (Symbol, SortedReft) where
  defunc = traverse defunc
-- | Defunctionalize the second component of the pair; the symbol is kept.
instance Defunc (Symbol, Sort) where
  defunc = traverse defunc
-- | Defunctionalize the predicate of a refinement; the value variable is kept.
instance Defunc Reft where
  defunc (Reft (v, body)) = do
    body' <- defunc body
    return (Reft (v, body'))
-- | Expressions are handled by 'txExpr', which respects the 'dfHO' flag.
instance Defunc Expr where
  defunc e = txExpr e
-- | Defunctionalize every entry of a symbol environment via 'mapMSEnv'.
instance Defunc a => Defunc (SEnv a) where
  defunc env = mapMSEnv defunc env
-- | Binders whose id is in 'dfBEnv' get their whole (symbol, sorted-reft)
-- pair defunctionalized. For every other binder only the sort is passed
-- through 'defunc', and the refinement is kept unchanged.
instance Defunc BindEnv where
  defunc benv = do
    marked <- gets dfBEnv
    let step (i, entry)
          | i `memberIBindEnv` marked = (,) i <$> defunc entry
          | otherwise                 = (,) i <$> matchSort entry
    mapWithKeyMBindEnv step benv
    where
      matchSort (x, RR s r) = (\s' -> (x, RR s' r)) <$> defunc s
-- | Sorts are returned unchanged.
instance Defunc Sort where
  defunc srt = pure srt
-- | Defunctionalize each list element, left to right.
instance Defunc a => Defunc [a] where
  defunc xs = traverse defunc xs
-- | Defunctionalize every value of the map; keys are kept unchanged.
-- Entries are visited in 'M.toList' order and the map is then rebuilt.
instance (Defunc a, Eq k, Hashable k) => Defunc (M.HashMap k a) where
  defunc m = do
    entries <- traverse (secondM defunc) (M.toList m)
    return (M.fromList entries)
-- | The defunctionalization monad: a 'State' computation over 'DFST'.
type DF = State DFST
-- | State threaded through the 'DF' monad.
data DFST = DFST
  { dfFresh :: !Int
    -- ^ Counter read and incremented by 'freshSym'.
  , dfEnv :: !SymEnv
    -- ^ Symbol environment: 'symbolEnv' of the input in 'makeInitDFState',
    --   or the caller's environment in 'defuncAny'.
  , dfBEnv :: !IBindEnv
    -- ^ Bind ids whose entries are fully defunctionalized by the 'BindEnv'
    --   instance; for other ids only the sort is visited.
  , dfHO :: !Bool
    -- ^ When 'False', 'txExpr' returns expressions unchanged.
  , dfLams :: ![Expr]
    -- ^ Initialized empty by 'makeDFState'.
    --   NOTE(review): not otherwise used in this chunk; purpose unconfirmed.
  , dfRedex :: ![Expr]
    -- ^ Initialized empty by 'makeDFState'.
    --   NOTE(review): not otherwise used in this chunk; purpose unconfirmed.
  , dfBinds :: !(SEnv Sort)
    -- ^ Sorts of the symbols allocated by 'freshSym'.
  }
-- | Build an initial 'DFST' with a zero fresh-name counter and empty
-- accumulators. Higher-order processing is enabled when either 'allowHO'
-- or 'defunction' is set in the configuration.
makeDFState :: Config -> SymEnv -> IBindEnv -> DFST
makeDFState cfg env ibind =
  DFST { dfEnv   = env
       , dfBEnv  = ibind
       , dfHO    = higherOrder
       , dfFresh = 0
       , dfLams  = []
       , dfRedex = []
       , dfBinds = mempty
       }
  where
    higherOrder = allowHO cfg || defunction cfg
-- | Initial state for 'defunctionalize'.
--
-- The marked binder set is the combination (via 'mconcat') of the 'senv'
-- of every constraint in 'cm' and the 'wenv' of every well-formedness
-- constraint in 'ws'.
makeInitDFState :: Config -> SInfo a -> DFST
makeInitDFState cfg si = makeDFState cfg env marked
  where
    env        = symbolEnv cfg si
    marked     = mconcat (consEnvs ++ wfEnvs)
    consEnvs   = map senv (M.elems (cm si))
    wfEnvs     = map wenv (M.elems (ws si))
-- | Allocate a fresh symbol of sort @t@.
--
-- The name is built by 'intSymbol' from the prefix @"lambda_fun_"@ and the
-- current 'dfFresh' counter. The counter is then incremented, and the new
-- symbol's sort is recorded in 'dfBinds'.
freshSym :: Sort -> DF Symbol
freshSym t = do
  counter <- gets dfFresh
  let sym = intSymbol "lambda_fun_" counter
      bump st = st { dfFresh = counter + 1
                   , dfBinds = insertSEnv sym t (dfBinds st)
                   }
  modify bump
  return sym