module CLaSH.Normalize.DEC
(collectGlobals
,isDisjoint
,mkDisjointGroup
)
where
import qualified Control.Lens as Lens
import Data.Bits ((.&.),complement)
import qualified Data.Either as Either
import qualified Data.Foldable as Foldable
import qualified Data.HashMap.Strict as HashMap
import qualified Data.IntMap.Strict as IM
import qualified Data.List as List
import qualified Data.Map.Strict as Map
import qualified Data.Maybe as Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import Unbound.Generics.LocallyNameless (Bind, bind, embed, fv, unbind,
unembed, unrec)
import qualified Unbound.Generics.LocallyNameless as Unbound
import CLaSH.Core.DataCon (DataCon, dcTag)
import CLaSH.Core.FreeVars (termFreeIds, typeFreeVars)
import CLaSH.Core.Literal (Literal (..))
import CLaSH.Core.Term (LetBinding, Pat (..), Term (..), TmName)
import CLaSH.Core.TyCon (tyConDataCons)
import CLaSH.Core.Type (Type, mkTyConApp, splitFunForallTy)
import CLaSH.Core.Util (collectArgs, mkApps, termType)
import CLaSH.Normalize.Types (NormalizeState)
import CLaSH.Normalize.Util (isConstant)
import CLaSH.Rewrite.Types (RewriteMonad, evaluator, tcCache, tupleTcCache)
import CLaSH.Rewrite.Util (mkInternalVar, mkSelectorCase,
isUntranslatableType)
import CLaSH.Util
-- | A tree recording the case-alternative and let-binding structure that
-- surrounds the occurrences of one collected global, with a payload at each
-- leaf.
--
-- In this module the payload is the argument list of one occurrence of the
-- collected function or primitive.
data CaseTree a
  = Leaf a
    -- ^ Payload for a single occurrence.
  | LB [LetBinding] (CaseTree a)
    -- ^ The sub-tree occurs underneath these let-bindings.
  | Branch Term [(Pat,CaseTree a)]
    -- ^ Case on the given scrutinee; only the alternatives in which an
    -- occurrence was collected are recorded.
  deriving (Eq,Show,Functor,Foldable)
-- | Check whether the occurrences recorded in a 'CaseTree' are spread over
-- different alternatives of some case expression.
--
-- A tree whose outermost node is a case with exactly one recorded
-- alternative is rejected immediately. Otherwise the tree is walked through
-- let-bindings and single-alternative cases. At the first case with two or
-- more recorded alternatives, the answer is whether every leaf below that
-- case carries the same list of type arguments. Reaching a leaf, or a case
-- with no alternatives, before that point yields 'False'.
isDisjoint :: CaseTree ([Either Term Type])
           -> Bool
isDisjoint tree = case tree of
  Branch _ [_] -> False
  _            -> walk tree
  where
    walk node = case node of
      Leaf _        -> False
      LB _ inner    -> walk inner
      Branch _ alts -> case alts of
        []       -> False
        [(_, x)] -> walk x
        _        -> allEqual [ Either.rights leaf | leaf <- Foldable.toList node ]
-- | Prune the parts of a case tree whose leaves are all empty lists.
--
-- Leaves are returned unchanged. A let whose body prunes to @'Leaf' []@
-- collapses to @'Leaf' []@. Alternatives whose sub-tree prunes to
-- @'Leaf' []@ are dropped, and a case left without alternatives collapses
-- to @'Leaf' []@.
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
removeEmpty tree = case tree of
  Leaf _ -> tree
  LB binds body ->
    let body' = removeEmpty body
    in  if isEmptyLeaf body' then Leaf [] else LB binds body'
  Branch scrut alts ->
    let kept = [ (pat, sub') | (pat, sub) <- alts
                             , let sub' = removeEmpty sub
                             , not (isEmptyLeaf sub') ]
    in  if null kept then Leaf [] else Branch scrut kept
  where
    isEmptyLeaf (Leaf []) = True
    isEmptyLeaf _         = False
-- | 'True' when every element of the list is equal to the first one.
-- The empty list counts as all-equal.
allEqual :: Eq a => [a] -> Bool
allEqual xs = case xs of
  []          -> True
  (hd : rest) -> all (== hd) rest
-- | Traverse a term and collect the applications whose head is considered
-- interesting by 'interestingToLift'.
--
-- For each collected function the result holds a 'CaseTree' that records
-- the case/let structure leading to its occurrence and, at the leaf, the
-- (recursively processed) arguments of that occurrence. The rewritten term
-- is returned alongside the collected pairs.
collectGlobals ::
     Set TmName
  -- ^ Variables whose applications are considered interesting
  -> [(Term,Term)]
  -- ^ Maps a collected function to the term that replaces its whole
  -- application (looked up when the application is collected)
  -> [Term]
  -- ^ Functions already collected; they are not collected again
  -> Term
  -- ^ Term to traverse
  -> RewriteMonad NormalizeState
                  (Term,[(Term,CaseTree [(Either Term Type)])])
collectGlobals inScope substitution seen (Case scrut ty alts) = do
  -- Recursive do: the alternatives are processed first, and their case
  -- trees mention the rewritten scrutinee scrut', which is only produced by
  -- the second statement. The scrutinee traversal, in turn, skips the
  -- functions already collected from the alternatives.
  rec (alts' ,collected)  <- collectGlobalsAlts inScope substitution seen scrut'
                                                alts
      (scrut',collected') <- collectGlobals inScope substitution
                                            (map fst collected ++ seen) scrut
  return (Case scrut' ty alts',collected ++ collected')
collectGlobals inScope substitution seen e@(collectArgs -> (fun, args@(_:_)))
  | not (isConstant e) = do
    tcm <- Lens.view tcCache
    eval <- Lens.view evaluator
    eTy <- termType tcm e
    case splitFunForallTy eTy of
      -- Only applications whose type has no remaining argument types (as
      -- split off by splitFunForallTy) are candidates.
      ([],_) -> case interestingToLift inScope (eval tcm False) fun args of
        Just fun' | fun' `notElem` seen -> do
          -- Collect inside the arguments first (with fun' marked as seen),
          -- then replace the whole application if the substitution has an
          -- entry for fun'; otherwise the original term is kept.
          (args',collected) <- collectGlobalsArgs inScope substitution
                                                  (fun':seen) args
          let e' = Maybe.fromMaybe e (List.lookup fun' substitution)
          return (e',(fun',Leaf args'):collected)
        -- Not interesting, or already collected: only look inside the
        -- arguments and rebuild the application.
        _ -> do (args',collected) <- collectGlobalsArgs inScope substitution
                                                        seen args
                return (mkApps fun args',collected)
      _ -> return (e,[])
collectGlobals inScope substitution seen (Letrec b) = do
  (unrec -> lbs,body) <- unbind b
  -- The body is traversed before the bindings; the bindings skip the
  -- functions already collected from the body.
  (body',collected) <- collectGlobals inScope substitution seen body
  (lbs',collected') <- collectGlobalsLbs inScope substitution
                                         (map fst collected ++ seen)
                                         lbs
  -- Every collected tree, whether from the body or from a binding, is
  -- wrapped in the (rewritten) let-bindings.
  return (Letrec (bind (Unbound.rec lbs') body')
         ,map (second (LB lbs')) (collected ++ collected')
         )
collectGlobals _ _ _ e = return (e,[])
-- | Run 'collectGlobals' over the term arguments of an application, from
-- left to right.
--
-- Functions collected from earlier arguments are prepended to the @seen@
-- list used for later ones. Type arguments are passed through unchanged.
-- Returns the rewritten arguments and all collected pairs, in argument
-- order.
collectGlobalsArgs ::
     Set TmName
  -> [(Term,Term)]
  -> [Term]
  -> [Either Term Type]
  -> RewriteMonad NormalizeState
                  ([Either Term Type]
                  ,[(Term,CaseTree [(Either Term Type)])]
                  )
collectGlobalsArgs inScope substitution seen args = do
    (args', perArg) <- unzip . snd <$> mapAccumLM step seen args
    return (args', concat perArg)
  where
    step seenSoFar arg = case arg of
      Left tm -> do
        (tm', found) <- collectGlobals inScope substitution seenSoFar tm
        return (map fst found ++ seenSoFar, (Left tm', found))
      Right ty ->
        return (seenSoFar, (Right ty, []))
-- | Run 'collectGlobals' over every alternative of a case expression.
--
-- Each alternative is traversed with the same @seen@ list. The collected
-- pairs are grouped per function across all alternatives. Every function
-- then gets a single @'Branch' scrut@ node whose alternatives are the
-- patterns (in alternative order) under which it was collected.
--
-- @scrut@ is only stored inside the produced 'Branch' nodes and is never
-- inspected here. 'collectGlobals' relies on this, because it passes a
-- lazily tied (recursive-do) value.
collectGlobalsAlts ::
     Set TmName
  -> [(Term,Term)]
  -> [Term]
  -> Term
  -> [Bind Pat Term]
  -> RewriteMonad NormalizeState
                  ([Bind Pat Term]
                  ,[(Term,CaseTree [(Either Term Type)])]
                  )
collectGlobalsAlts inScope substitution seen scrut alts = do
    (alts', perAlt) <- unzip <$> mapM visit alts
    let grouped = Map.unionsWith (++)
                    [ Map.fromList [ (fun, [branch]) | (fun, branch) <- found ]
                    | found <- perAlt ]
    return ( alts'
           , [ (fun, Branch scrut branches) | (fun, branches) <- Map.toList grouped ]
           )
  where
    visit bound = do
      (pat, rhs) <- unbind bound
      (rhs', found) <- collectGlobals inScope substitution seen rhs
      return (bind pat rhs', [ (fun, (pat, tree)) | (fun, tree) <- found ])
-- | Run 'collectGlobals' over the right-hand sides of a group of
-- let-bindings, in order.
--
-- Functions collected from earlier bindings are prepended to the @seen@
-- list used for later ones. Returns the rewritten bindings and all collected
-- pairs, in binding order.
collectGlobalsLbs ::
     Set TmName
  -> [(Term,Term)]
  -> [Term]
  -> [LetBinding]
  -> RewriteMonad NormalizeState
                  ([LetBinding]
                  ,[(Term,CaseTree [(Either Term Type)])]
                  )
collectGlobalsLbs inScope substitution seen lbs = do
    (lbs', perLb) <- unzip . snd <$> mapAccumLM visit seen lbs
    return (lbs', concat perLb)
  where
    visit :: [Term] -> LetBinding
          -> RewriteMonad NormalizeState
                          ([Term]
                          ,(LetBinding
                           ,[(Term,CaseTree [(Either Term Type)])]
                           )
                          )
    visit seenSoFar (binder, embedded) = do
      let rhs = unembed embedded
      (rhs', found) <- collectGlobals inScope substitution seenSoFar rhs
      return (map fst found ++ seenSoFar, ((binder, embed rhs'), found))
-- | Build one application of @fun@ that stands in for all the occurrences
-- recorded in the case tree @cs@.
--
-- Argument positions whose values satisfy 'isCommon' across all occurrences
-- are passed directly. The remaining term arguments are selected through
-- case trees built by 'disJointSelProj'. When that produces a let-binding,
-- the application is wrapped in a 'Letrec' for it.
mkDisjointGroup :: Set TmName
                -- ^ Free variables a common argument may mention
                -> (Term,CaseTree [(Either Term Type)])
                -- ^ The function and the case tree of its occurrences'
                -- arguments
                -> RewriteMonad NormalizeState Term
mkDisjointGroup fvs (fun,cs) = do
    let argss    = Foldable.toList cs
        -- Each argument position paired with the column of values it takes
        -- across all occurrences.
        argssT   = zip [0..] (List.transpose argss)
        (commonT,uncommonT) = List.partition (isCommon fvs . snd) argssT
        -- Columns produced by transpose are non-empty, so head is safe here.
        common   = map (second head) commonT
        -- Per occurrence, the term arguments at the non-common positions.
        -- NOTE(review): type arguments at non-common positions are dropped by
        -- Either.lefts here and in cs'', so mkDJArgs would not reproduce them;
        -- confirm such positions cannot occur.
        uncommon = map (Either.lefts) (List.transpose (map snd uncommonT))
        cs'      = fmap (zip [0..]) cs
        -- Case tree holding, per occurrence, only the non-common term
        -- arguments; sub-trees with empty leaves are pruned.
        cs''     = removeEmpty
                 $ fmap (Either.lefts . map snd)
                        (if null common
                            then cs'
                            else fmap (filter (`notElem` common)) cs')
    tcm <- Lens.view tcCache
    (uncommonCaseM,uncommonProjections) <- case uncommon of
      [] -> return (Nothing,[])
      (uc:_) -> do
        -- The argument types are taken from the first occurrence.
        argTys <- mapM (termType tcm) uc
        disJointSelProj argTys cs''
    let newArgs = mkDJArgs 0 common uncommonProjections
    case uncommonCaseM of
      Just lb -> return (Letrec (bind (Unbound.rec [lb]) (mkApps fun newArgs)))
      Nothing -> return (mkApps fun newArgs)
-- | Build selector terms for the non-common arguments of a disjoint group.
--
-- @argTys@ are the types of the non-common term arguments. Each leaf of the
-- case tree holds those arguments for one occurrence. The result is:
--
-- * an optional let-binding: when more than one argument has a translatable
--   type, a case tree that selects a tuple of all translatable arguments;
-- * one term per argument position, in the order of @argTys@. For a
--   translatable argument this is a projection out of that tuple. For an
--   untranslatable argument (or when there is only a single translatable
--   one) it is a case tree selecting the argument directly.
disJointSelProj :: [Type]
                -> CaseTree [Term]
                -> RewriteMonad NormalizeState (Maybe LetBinding,[Term])
disJointSelProj _ (Leaf []) = return (Nothing,[])
disJointSelProj argTys cs = do
    let maxIndex = length argTys - 1
        -- One case tree per argument position, with only that argument at
        -- the leaves.
        css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
    (untran,tran) <- partitionM (isUntranslatableType . snd) (zip [0..] argTys)
    let untranCs   = map ((css!!) . fst) untran
        untranSels = zipWith (\(_,ty) cs' -> genCase ty Nothing [] cs')
                             untran untranCs
    (lbM,projs) <- case tran of
      [] -> return (Nothing,[])
      [(i,ty)] -> return (Nothing,[genCase ty Nothing [] (css!!i)])
      tys -> do
        tcm    <- Lens.view tcCache
        tupTcm <- Lens.view tupleTcCache
        let m       = length tys
            tupTcNm = Maybe.fromMaybe
                        (error $ $(curLoc)
                              ++ "disJointSelProj: no tuple type constructor of arity "
                              ++ show m)
                        (IM.lookup m tupTcm)
            tupTc   = Maybe.fromMaybe
                        (error $ $(curLoc)
                              ++ "disJointSelProj: tuple type constructor missing from tcCache")
                        (HashMap.lookup tupTcNm tcm)
            tupDc   = case tyConDataCons tupTc of
                        [dc] -> dc
                        dcs  -> error $ $(curLoc)
                                     ++ "disJointSelProj: expected one tuple data constructor, found "
                                     ++ show (length dcs)
            (tyIxs,tys') = unzip tys
            tupTy   = mkTyConApp tupTcNm tys'
            -- Keep only the translatable arguments at every leaf.
            cs'     = fmap (\es -> map (es !!) tyIxs) cs
            djCase  = genCase tupTy (Just tupDc) tys' cs'
        (scrutId,scrutVar) <- mkInternalVar "tupIn" tupTy
        projections <- mapM (mkSelectorCase ($(curLoc) ++ "disJointSelProj")
                                            tcm scrutVar (dcTag tupDc))
                            [0 .. m - 1]
        return (Just (scrutId,embed djCase),projections)
    let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
    return (lbM,selProjs)
  where
    -- Merge the direct selectors (tagged with their argument index) with the
    -- tuple projections (in index order) back into argument order.
    tranOrUnTran _ [] projs = projs
    tranOrUnTran _ sels [] = map snd sels
    tranOrUnTran n ((ut,s):uts) (p:projs)
      | n == ut   = s : tranOrUnTran (n+1) uts (p:projs)
      | otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs
-- | Whether an argument column (the values one argument position takes
-- across all occurrences) can be passed as a single shared argument.
--
-- All values must be equal. In addition, a type value must have no free
-- type variables, and a term value may only mention free ids from @fvs@.
-- An empty column is common.
isCommon :: Set TmName -> [Either Term Type] -> Bool
isCommon fvs column = case column of
  []           -> True
  Right ty : _ -> Set.null (Lens.setOf typeFreeVars ty) && allEqual column
  Left tm : _  -> Set.null (Lens.setOf termFreeIds tm Set.\\ fvs) && allEqual column
-- | Rebuild an argument list from the common arguments, each tagged with its
-- position, and the selector terms for the remaining positions (in order).
--
-- The first argument is the position of the next argument to emit. Whenever
-- it matches the next common argument's position, that argument is used;
-- otherwise the next selector fills the slot. Once either list runs out, the
-- rest of the other list is appended.
mkDJArgs :: Int
         -> [(Int,Either Term Type)]
         -> [Term]
         -> [Either Term Type]
mkDJArgs pos commons uncommons = case (commons, uncommons) of
  (_, [])  -> map snd commons
  ([], _)  -> map Left uncommons
  ((idx, arg) : commons', proj : uncommons')
    | pos == idx -> arg : mkDJArgs (pos + 1) commons' uncommons
    | otherwise  -> Left proj : mkDJArgs (pos + 1) commons uncommons'
-- | Turn a case tree of terms into a term with that case/let structure.
--
-- At a leaf, when a data constructor is given, it is applied to the type
-- arguments followed by the leaf's terms. Without one, the leaf's first
-- term is used. Every generated case expression gets result type @ty@.
genCase :: Type
        -> Maybe DataCon
        -> [Type]
        -> CaseTree [Term]
        -> Term
genCase ty dcM argTys = build
  where
    build (Leaf tms) =
      maybe (head tms)
            (\dc -> mkApps (Data dc) (map Right argTys ++ map Left tms))
            dcM
    build (LB lb ct) = Letrec (bind (Unbound.rec lb) (build ct))
    build (Branch scrut [(p,ct)])
      -- Single alternative: when the body's free term ids equal the free
      -- variables of the whole alternative, the pattern's binders are not
      -- needed, so the case is dropped.
      | Lens.setOf termFreeIds body == Lens.setOf fv alt = body
      | otherwise                                        = Case scrut ty [alt]
      where
        body = build ct
        alt  = bind p body
    build (Branch scrut alts) =
      Case scrut ty [ bind p (build ct) | (p, ct) <- alts ]
-- | Decide whether the head of an application is a candidate for lifting,
-- returning the head itself when it is.
--
-- * A variable is a candidate when its name is in @inScope@.
-- * A primitive is a candidate when it appears in the table below and
--   either its per-primitive power-of-two predicate holds, or not all of
--   its term arguments are constant.
-- * Nothing else is a candidate.
interestingToLift :: Set TmName
                  -- ^ Variables considered interesting
                  -> (Term -> Term)
                  -- ^ Evaluator applied to arguments before the
                  -- power-of-two check
                  -> Term
                  -- ^ Head of the application
                  -> [Either Term Type]
                  -- ^ Arguments of the application
                  -> Maybe Term
interestingToLift inScope _ e@(Var _ nm) _ =
  if nm `Set.member` inScope
     then Just e
     else Nothing
interestingToLift _ eval e@(Prim nm _) args =
    case List.lookup nm interestingPrims of
      Just t | t || not (all isConstant lArgs) -> Just e
      _ -> Nothing
  where
    -- Each primitive paired with the predicate over its term arguments
    -- (evaluated lazily, only for the matching entry).
    interestingPrims =
      [("CLaSH.Sized.Internal.BitVector.*#",tailNonPow2)
      ,("CLaSH.Sized.Internal.BitVector.times#",tailNonPow2)
      ,("CLaSH.Sized.Internal.BitVector.quot#",lastNotPow2)
      ,("CLaSH.Sized.Internal.BitVector.rem#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Index.*#",tailNonPow2)
      ,("CLaSH.Sized.Internal.Index.quot#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Index.rem#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Signed.*#",tailNonPow2)
      ,("CLaSH.Sized.Internal.Signed.times#",tailNonPow2)
      ,("CLaSH.Sized.Internal.Signed.rem#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Signed.quot#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Signed.div#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Signed.mod#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Unsigned.*#",tailNonPow2)
      ,("CLaSH.Sized.Internal.Unsigned.times#",tailNonPow2)
      ,("CLaSH.Sized.Internal.Unsigned.quot#",lastNotPow2)
      ,("CLaSH.Sized.Internal.Unsigned.rem#",lastNotPow2)
      ,("GHC.Base.quotInt",lastNotPow2)
      ,("GHC.Base.remInt",lastNotPow2)
      ,("GHC.Base.divInt",lastNotPow2)
      ,("GHC.Base.modInt",lastNotPow2)
      ,("GHC.Classes.divInt#",lastNotPow2)
      ,("GHC.Classes.modInt#",lastNotPow2)
      ,("GHC.Integer.Type.timesInteger",allNonPow2)
      ,("GHC.Integer.Type.divInteger",lastNotPow2)
      ,("GHC.Integer.Type.modInteger",lastNotPow2)
      ,("GHC.Integer.Type.quotInteger",lastNotPow2)
      ,("GHC.Integer.Type.remInteger",lastNotPow2)
      ,("GHC.Prim.*#",allNonPow2)
      ,("GHC.Prim.quotInt#",lastNotPow2)
      ,("GHC.Prim.remInt#",lastNotPow2)
      ]
    lArgs = Either.lefts args
    allNonPow2  = all (not . termIsPow2) lArgs
    tailNonPow2 = all (not . termIsPow2) (tail lArgs)
    lastNotPow2 = not (termIsPow2 (last lArgs))
    -- A term counts as a power of two when it evaluates to an integer
    -- literal that is one, or to a recognised fromInteger primitive applied
    -- to such a literal.
    termIsPow2 e' = case eval e' of
      Literal (IntegerLiteral n) -> isPow2 n
      a -> case collectArgs a of
        (Prim nm' _,[Right _,Left _,Left (Literal (IntegerLiteral n))])
          | isFromInteger nm' -> isPow2 n
        _ -> False
    -- x .&. (-x) isolates the lowest set bit; it equals x exactly when x is a
    -- positive power of two.
    isPow2 x = x /= 0 && (x .&. (complement x + 1)) == x
    -- The Index entry previously read "CLaSH.Sized.Integer.Index.fromInteger"
    -- (wrong module, missing '#'), which never matched any primitive.
    isFromInteger x = x `elem` ["CLaSH.Sized.Internal.BitVector.fromInteger#"
                               ,"CLaSH.Sized.Internal.Index.fromInteger#"
                               ,"CLaSH.Sized.Internal.Signed.fromInteger#"
                               ,"CLaSH.Sized.Internal.Unsigned.fromInteger#"
                               ]
interestingToLift _ _ _ _ = Nothing