{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Normalize.DEC
(collectGlobals
,isDisjoint
,mkDisjointGroup
)
where
import Control.Concurrent.Supply (splitSupply)
import qualified Control.Lens as Lens
import Data.Bits ((.&.),complement)
import qualified Data.Either as Either
import qualified Data.Foldable as Foldable
import qualified Data.HashMap.Strict as HashMap
import qualified Data.IntMap.Strict as IM
import qualified Data.List as List
import qualified Data.Map.Strict as Map
import qualified Data.Maybe as Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import Unbound.Generics.LocallyNameless
(Bind, bind, embed, fv, unbind, unembed, unrec)
import qualified Unbound.Generics.LocallyNameless as Unbound
import Clash.Core.DataCon (DataCon, dcTag)
import Clash.Core.Evaluator (whnf')
import Clash.Core.FreeVars (termFreeIds, typeFreeVars)
import Clash.Core.Name (Name (..), string2InternalName)
import Clash.Core.Literal (Literal (..))
import Clash.Core.Term (LetBinding, Pat (..), Term (..), TmOccName)
import Clash.Core.TyCon (tyConDataCons)
import Clash.Core.Type (Type, isPolyFunTy, mkTyConApp, splitFunForallTy)
import Clash.Core.Util (collectArgs, mkApps, termType)
import Clash.Normalize.Types (NormalizeState)
import Clash.Normalize.Util (isConstant)
import Clash.Rewrite.Types
(RewriteMonad, bindings, evaluator, globalHeap, tcCache, tupleTcCache, uniqSupply)
import Clash.Rewrite.Util (mkInternalVar, mkSelectorCase,
isUntranslatableType)
import Clash.Util
-- | Records where values were found while traversing a term, i.e. under
-- which case-alternatives and let-bindings.
data CaseTree a
  = Leaf a
  -- ^ A value found at this position
  | LB [LetBinding] (CaseTree a)
  -- ^ A subtree located in the scope of these let-bindings
  | Branch Term [(Pat,CaseTree a)]
  -- ^ A case-expression on the given scrutinee, with one subtree per
  -- alternative, each tagged with that alternative's pattern
  deriving (Eq,Show,Functor,Foldable)
-- | Decide whether a collected argument tree forms a disjoint group.
--
-- A tree whose root is a single-alternative 'Branch' is never disjoint.
-- Otherwise, let-bindings and single-alternative branches are looked
-- through until either:
--
-- * a leaf or an empty branch is reached, in which case the tree is not
--   disjoint; or
-- * a branch with at least two alternatives is reached, in which case the
--   tree is disjoint exactly when all of its leaves carry the same type
--   arguments.
isDisjoint :: CaseTree ([Either Term Type])
           -> Bool
isDisjoint tree = case tree of
    Branch _ [_] -> False
    _            -> check tree
  where
    check node = case node of
      Leaf _                -> False
      LB _ inner            -> check inner
      Branch _ []           -> False
      Branch _ [(_, inner)] -> check inner
      Branch _ (_:_:_)      ->
        allEqual (map Either.rights (Foldable.toList node))
-- | Prune a 'CaseTree' of lists.
--
-- Alternatives whose (pruned) subtree is an empty leaf are dropped. A
-- let-node or branch that is left with nothing collapses into an empty
-- leaf. Leaves themselves are returned unchanged.
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
removeEmpty tree = case tree of
    Leaf _ -> tree
    LB binds inner ->
      let inner' = removeEmpty inner
      in  if isEmptyLeaf inner' then Leaf [] else LB binds inner'
    Branch scrut alts ->
      let alts' = [ (pat, sub')
                  | (pat, sub) <- alts
                  , let sub' = removeEmpty sub
                  , not (isEmptyLeaf sub')
                  ]
      in  if null alts' then Leaf [] else Branch scrut alts'
  where
    -- Structural test, equivalent to comparing against @Leaf []@
    isEmptyLeaf (Leaf []) = True
    isEmptyLeaf _         = False
-- | 'True' when every element equals the first one. This holds trivially
-- for the empty and singleton lists.
allEqual :: Eq a => [a] -> Bool
allEqual xs = case xs of
  []           -> True
  (ref : rest) -> foldr (\y acc -> y == ref && acc) True rest
-- | Collect the applications whose head is accepted by 'interestingToLift'
-- and is not yet in the seen-list.
--
-- Each collected function is recorded together with the arguments it is
-- applied to. The arguments are organised as a 'CaseTree' that mirrors the
-- case-expressions and let-bindings under which the application was found.
--
-- Returns the rewritten term and, for every collected function:
--
-- * the seen-list at the point of collection, and
-- * the 'CaseTree' of its (recursively processed) arguments.
collectGlobals ::
     Set TmOccName
  -- ^ Passed on to 'interestingToLift'; a variable head is only accepted
  -- when its name is in this set
  -> [(Term,Term)]
  -- ^ Substitution: when a collected head is a key here, the whole
  -- application is replaced by the associated term
  -> [Term]
  -- ^ Functions already seen; these are not collected again
  -> Term
  -- ^ Term to traverse
  -> RewriteMonad NormalizeState
       (Term,[(Term,([Term],CaseTree [(Either Term Type)]))])
collectGlobals inScope substitution seen (Case scrut ty alts) = do
  -- 'RecursiveDo': the alternatives are processed first. The 'Branch'es
  -- they build refer to the rewritten scrutinee scrut', which is only
  -- computed afterwards. This works because collectGlobalsAlts merely
  -- stores the scrutinee and never inspects it.
  rec (alts' ,collected) <- collectGlobalsAlts inScope substitution seen scrut'
                                               alts
      -- Functions collected in the alternatives count as seen while
      -- traversing the scrutinee.
      (scrut',collected') <- collectGlobals inScope substitution
                                            (map fst collected ++ seen) scrut
  return (Case scrut' ty alts',collected ++ collected')
collectGlobals inScope substitution seen e@(collectArgs -> (fun, args@(_:_)))
  | not (isConstant e) = do
    tcm <- Lens.view tcCache
    bndrs <- Lens.use bindings
    primEval <- Lens.view evaluator
    gh <- Lens.use globalHeap
    ids <- Lens.use uniqSupply
    -- Split the unique supply: one half goes to the evaluator below, the
    -- other half is written back to the state.
    let (ids1,ids2) = splitSupply ids
    uniqSupply Lens..= ids2
    -- Term evaluator built from whnf'. It is handed to 'interestingToLift',
    -- which uses it e.g. for its power-of-two checks.
    let eval = snd . whnf' primEval bndrs tcm gh ids1 False
    eTy <- termType tcm e
    untran <- isUntranslatableType False eTy
    case untran of
      -- Only applications of a translatable type are considered.
      False -> case interestingToLift inScope eval fun args of
        -- A new interesting head: record it (with the seen-list as it was
        -- before adding it) and traverse the arguments with it marked as
        -- seen.
        Just fun' | fun' `notElem` seen -> do
          (args',collected) <- collectGlobalsArgs inScope substitution
                                                  (fun':seen) args
          let e' = Maybe.fromMaybe (mkApps fun' args') (List.lookup fun' substitution)
          return (e',(fun',(seen,Leaf args')):collected)
        -- Not interesting, or already seen: only traverse the arguments.
        _ -> do (args',collected) <- collectGlobalsArgs inScope substitution
                                                        seen args
                return (mkApps fun args',collected)
      -- Untranslatable type: the term is left untouched and nothing is
      -- collected.
      _ -> return (e,[])
collectGlobals inScope substitution seen (Letrec b) = do
  (unrec -> lbs,body) <- unbind b
  -- The body is traversed first. Functions found there count as seen for
  -- the let-bindings.
  (body',collected) <- collectGlobals inScope substitution seen body
  (lbs',collected') <- collectGlobalsLbs inScope substitution
                                         (map fst collected ++ seen)
                                         lbs
  -- Every collected argument tree is wrapped in the rewritten let-bindings
  -- (presumably because the arguments may refer to the let-bound
  -- variables).
  return (Letrec (bind (Unbound.rec lbs') body')
         ,map (second (second (LB lbs'))) (collected ++ collected')
         )
-- Any other term, including applications for which 'isConstant' holds and
-- terms without term/type arguments, is returned unchanged and nothing is
-- collected.
collectGlobals _ _ _ e = return (e,[])
-- | Run 'collectGlobals' over the term arguments of an application, from
-- left to right, threading the seen-list.
--
-- Functions collected in one argument count as seen for all arguments
-- after it. Type arguments are passed through unchanged and contribute
-- nothing.
collectGlobalsArgs ::
     Set TmOccName
  -> [(Term,Term)]
  -> [Term]
  -> [Either Term Type]
  -> RewriteMonad NormalizeState
       ([Either Term Type]
       ,[(Term,([Term],CaseTree [(Either Term Type)]))]
       )
collectGlobalsArgs inScope substitution seen args = do
    (_, perArg) <- mapAccumLM visitArg seen args
    case unzip perArg of
      (args', found) -> return (args', concat found)
  where
    visitArg acc arg = case arg of
      Right ty -> return (acc, (Right ty, []))
      Left tm  -> do
        (tm', found) <- collectGlobals inScope substitution acc tm
        return (map fst found ++ acc, (Left tm', found))
-- | Run 'collectGlobals' over every alternative of a case-expression and
-- merge the results per collected function.
--
-- For each collected function:
--
-- * the seen-lists from all alternatives are concatenated and
--   de-duplicated with 'List.nub';
-- * the per-alternative argument trees, each tagged with its alternative's
--   pattern, are gathered into a single 'Branch' on the given scrutinee.
--
-- Only alternatives in which the function was collected contribute a
-- subtree. The resulting list is ordered by the 'Ord' instance of 'Term'
-- (via 'Map.toList').
collectGlobalsAlts ::
     Set TmOccName
  -> [(Term,Term)]
  -> [Term]
  -- ^ Seen-list; every alternative starts from this same list
  -> Term
  -- ^ Scrutinee; it is only stored in the resulting 'Branch'es and never
  -- inspected (the Case clause of 'collectGlobals' passes a value that is
  -- still being computed via 'RecursiveDo')
  -> [Bind Pat Term]
  -> RewriteMonad NormalizeState
       ([Bind Pat Term]
       ,[(Term,([Term],CaseTree [(Either Term Type)]))]
       )
collectGlobalsAlts inScope substitution seen scrut alts = do
  (alts',collected) <- unzip <$> mapM go alts
  -- One map per alternative: function |-> (seen, [(pattern, tree)])
  let collectedM = map (Map.fromList . map (second (second (:[])))) collected
      -- Merge the alternatives; the combining function receives the
      -- earlier alternative's entry on the left, so subtrees keep the
      -- alternatives' order
      collectedUN = Map.unionsWith (\(l1,r1) (l2,r2) -> (List.nub (l1 ++ l2),r1 ++ r2)) collectedM
      collected' = map (second (second (Branch scrut))) (Map.toList collectedUN)
  return (alts',collected')
  where
    -- Traverse one alternative and tag every collected tree with the
    -- alternative's pattern.
    go pe = do (p,e) <- unbind pe
               (e',collected) <- collectGlobals inScope substitution seen e
               return (bind p e',map (second (second (p,))) collected)
-- | Run 'collectGlobals' over the right-hand sides of let-bindings, in
-- order, threading the seen-list.
--
-- Functions collected in one binding count as seen for all bindings after
-- it. Binders are kept unchanged; only the right-hand sides are rewritten.
collectGlobalsLbs ::
     Set TmOccName
  -> [(Term,Term)]
  -> [Term]
  -> [LetBinding]
  -> RewriteMonad NormalizeState
       ([LetBinding]
       ,[(Term,([Term],CaseTree [(Either Term Type)]))]
       )
collectGlobalsLbs inScope substitution seen lbs = do
    (_, perBinding) <- mapAccumLM visitBinding seen lbs
    case unzip perBinding of
      (lbs', found) -> return (lbs', concat found)
  where
    visitBinding
      :: [Term]
      -> LetBinding
      -> RewriteMonad NormalizeState
           ([Term]
           ,(LetBinding
            ,[(Term,([Term],CaseTree [(Either Term Type)]))]
            )
           )
    visitBinding acc (bndr, rhsEmb) = do
      (rhs', found) <- collectGlobals inScope substitution acc (unembed rhsEmb)
      return (map fst found ++ acc, ((bndr, embed rhs'), found))
-- | Turn a collected function and the 'CaseTree' of its arguments into a
-- single application of that function.
--
-- An argument position is passed directly when its values are "common"
-- across all leaves (see 'isCommon'). The remaining term arguments are
-- selected via 'disJointSelProj'. If that also yields a let-binding, the
-- resulting application is wrapped in it. The seen-list is returned
-- unchanged.
mkDisjointGroup :: Set TmOccName
                -- ^ Variables allowed to occur free in a common term argument
                -> (Term,([Term],CaseTree [(Either Term Type)]))
                -- ^ Function, seen-list and argument tree
                -> RewriteMonad NormalizeState (Term,[Term])
mkDisjointGroup fvs (fun,(seen,cs)) = do
    tcm <- Lens.view tcCache
    (selLetM, selArgs) <- case uncommonPerLeaf of
      []            -> return (Nothing, [])
      (firstLeaf:_) -> do
        argTys <- mapM (termType tcm) firstLeaf
        disJointSelProj argTys prunedTree
    let call    = mkApps fun (mkDJArgs 0 commonArgs selArgs)
        wrap lb = Letrec (bind (Unbound.rec [lb]) call)
    return (maybe call wrap selLetM, seen)
  where
    -- Column i holds the i-th argument of every leaf
    columns = zip [(0 :: Int) ..] (List.transpose (Foldable.toList cs))
    (commonCols, uncommonCols) = List.partition (isCommon fvs . snd) columns
    -- A common column is represented by its first value
    -- (List.transpose never produces an empty column)
    commonArgs = [ (ix, v) | (ix, v:_) <- commonCols ]
    -- The uncommon term arguments, regrouped per leaf
    uncommonPerLeaf = map Either.lefts (List.transpose (map snd uncommonCols))
    -- Argument tree with the common positions and the type arguments
    -- removed, and emptied parts pruned
    prunedTree = removeEmpty (fmap uncommonTerms cs)
    uncommonTerms leafArgs =
      Either.lefts (map snd (filter (`notElem` commonArgs) (zip [0..] leafArgs)))
-- | Build the selection logic for the uncommon arguments of a disjoint
-- group.
--
-- Returns an optional let-binding together with one selector term per
-- uncommon argument, in argument order.
--
-- * Each argument of an untranslatable type (see 'isUntranslatableType')
--   gets its own selection expression built with 'genCase'.
-- * If exactly one argument is translatable, it is handled the same way.
-- * If several arguments are translatable, they are packed into a tuple
--   selected by a single 'genCase' expression. That expression is bound by
--   the returned let-binding to a fresh @tupIn@ variable, and each of
--   those arguments becomes a projection out of it via 'mkSelectorCase'.
--
-- Fails with a descriptive 'error' when no suitable tuple type
-- constructor is available for the required arity.
disJointSelProj :: [Type]
                -- ^ Types of the uncommon arguments, one per argument
                -> CaseTree [Term]
                -- ^ Values of the uncommon arguments, one list per leaf
                -> RewriteMonad NormalizeState (Maybe LetBinding,[Term])
disJointSelProj _ (Leaf []) = return (Nothing,[])
disJointSelProj argTys cs = do
  let maxIndex = length argTys - 1
      -- One tree per argument position; each leaf holds just that argument
      css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
  (untran,tran) <- partitionM (isUntranslatableType False . snd) (zip [0..] argTys)
  let untranCs = map (css!!) (map fst untran)
      untranSels = zipWith (\(_,ty) cs' -> genCase ty Nothing [] cs')
                           untran untranCs
  (lbM,projs) <- case tran of
    [] -> return (Nothing,[])
    [(i,ty)] -> return (Nothing,[genCase ty Nothing [] (css!!i)])
    tys -> do
      tcm <- Lens.view tcCache
      tupTcm <- Lens.view tupleTcCache
      let m = length tys
          -- Explicit errors instead of partial pattern bindings, so that a
          -- missing tuple constructor is reported with its location and arity
          tupTcNm = Maybe.fromMaybe
            (error ($(curLoc) ++ "No tuple type constructor name for arity "
                              ++ show m))
            (IM.lookup m tupTcm)
          tupTc = Maybe.fromMaybe
            (error ($(curLoc) ++ "Tuple type constructor of arity "
                              ++ show m ++ " not found in the TyCon map"))
            (HashMap.lookup (nameOcc tupTcNm) tcm)
          tupDc = case tyConDataCons tupTc of
            [dc] -> dc
            dcs  -> error ($(curLoc) ++ "Tuple type constructor of arity "
                                     ++ show m ++ " has "
                                     ++ show (length dcs)
                                     ++ " data constructors instead of 1")
          (tyIxs,tys') = unzip tys
          tupTy = mkTyConApp tupTcNm tys'
          -- Per leaf, keep only the translatable arguments, in order
          cs' = fmap (\es -> map (es !!) tyIxs) cs
          djCase = genCase tupTy (Just tupDc) tys' cs'
      (scrutId,scrutVar) <- mkInternalVar (string2InternalName "tupIn") tupTy
      projections <- mapM (mkSelectorCase ($(curLoc) ++ "disJointSelProj")
                                          tcm scrutVar (dcTag tupDc)) [0..m-1]
      return (Just (scrutId,embed djCase),projections)
  -- Interleave the untranslatable selections and the translatable
  -- projections back into the original argument order.
  let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
  return (lbM,selProjs)
  where
    tranOrUnTran _ [] projs = projs
    tranOrUnTran _ sels [] = map snd sels
    tranOrUnTran n ((ut,s):uts) (p:projs)
      | n == ut = s : tranOrUnTran (n+1) uts (p:projs)
      | otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs
-- | Decide whether an argument column (one value per leaf) is common.
--
-- A column is common when all of its values are equal and the shared value
-- has no unwanted free variables:
--
-- * a type must have no free type variables at all;
-- * a term may only have free ids from the given set.
--
-- The first element decides whether the column is treated as terms or as
-- types. The empty column is common.
isCommon :: Set TmOccName -> [Either Term Type] -> Bool
isCommon fvs column = case column of
  []           -> True
  Right ty : _ -> Set.null (Lens.setOf typeFreeVars ty) && allEqual column
  Left tm : _  -> Set.null (Lens.setOf termFreeIds tm Set.\\ fvs) && allEqual column
-- | Interleave position-tagged common arguments with the uncommon term
-- arguments, starting at position @n@.
--
-- At each position the common argument is used when its index matches;
-- otherwise the next uncommon term is used. Once either list runs out, the
-- remainder of the other list is appended as-is.
mkDJArgs :: Int
         -> [(Int,Either Term Type)]
         -> [Term]
         -> [Either Term Type]
mkDJArgs pos commons uncommons = case (commons, uncommons) of
  (_, [])  -> map snd commons
  ([], _)  -> map Left uncommons
  ((ix, arg) : commons', u : uncommons')
    | pos == ix -> arg    : mkDJArgs (pos + 1) commons' uncommons
    | otherwise -> Left u : mkDJArgs (pos + 1) commons  uncommons'
-- | Build a selection term from a 'CaseTree' of term lists.
--
-- * A leaf becomes the given data constructor applied to the type
--   arguments and then to the leaf's terms. Without a data constructor, the
--   leaf becomes its first term; this errors on an empty leaf.
-- * A let-node becomes a 'Letrec'.
-- * A branch becomes a 'Case' of result type @ty@ on the recorded
--   scrutinee. A single-alternative branch is replaced by its body when the
--   body's free ids coincide with the free variables of the alternative
--   (presumably: the pattern binds nothing the body uses).
genCase :: Type
        -> Maybe DataCon
        -> [Type]
        -> CaseTree [Term]
        -> Term
genCase ty dcM argTys = build
  where
    build tree = case tree of
      Leaf tms ->
        maybe (head tms)
              (\dc -> mkApps (Data dc) (map Right argTys ++ map Left tms))
              dcM
      LB lb inner ->
        Letrec (bind (Unbound.rec lb) (build inner))
      Branch scrut [(p, inner)] ->
        let body = build inner
            alt  = bind p body
        in  if Lens.setOf termFreeIds body == Lens.setOf fv alt
              then body
              else Case scrut ty [alt]
      Branch scrut alts ->
        Case scrut ty [ bind p (build sub) | (p, sub) <- alts ]
-- | Decide whether the head of an application is a candidate for being
-- collected by 'collectGlobals'.
--
-- Returns @Just e@ (the head itself) for a candidate and 'Nothing'
-- otherwise:
--
-- * A variable is a candidate when its occurrence name is in @inScope@.
-- * A primitive listed in @interestingPrims@ is a candidate when its
--   associated check holds, or when not all of its term arguments are
--   constant. A check is 'False' when the relevant argument(s) evaluate to
--   a power-of-two integer literal, either directly or via one of the
--   primitives in @isFromInteger@.
-- * Otherwise, a primitive with a higher-order type (see @isHOTy@) is a
--   candidate when at least one of its term arguments is itself (an
--   application of) a candidate.
-- * Every other head is not a candidate.
--
-- NOTE(review): because the check is combined with @||@, a power-of-two
-- argument only prevents lifting when all term arguments are constant.
-- Confirm this is the intended behaviour.
interestingToLift
  :: Set TmOccName
  -- ^ Names of variables that may be lifted
  -> (Term -> Term)
  -- ^ Evaluator used to reduce arguments for the power-of-two checks
  -> Term
  -- ^ Head of the application
  -> [Either Term Type]
  -- ^ Arguments of the application
  -> Maybe Term
interestingToLift inScope _ e@(Var _ nm) _ =
  if nameOcc nm `Set.member` inScope
     then Just e
     else Nothing
interestingToLift inScope eval e@(Prim nm pty) args =
  case List.lookup nm interestingPrims of
    Just t | t || not (all isConstant lArgs) -> Just e
    _ | isHOTy pty
      , any (Maybe.isJust . uncurry (interestingToLift inScope eval) . collectArgs)
            lArgs
      -> Just e
      | otherwise
      -> Nothing
  where
    -- Primitives worth lifting, each paired with a check on its arguments
    interestingPrims =
      [("Clash.Sized.Internal.BitVector.*#",tailNonPow2)
      ,("Clash.Sized.Internal.BitVector.times#",tailNonPow2)
      ,("Clash.Sized.Internal.BitVector.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.BitVector.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Index.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Index.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Index.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Signed.times#",tailNonPow2)
      ,("Clash.Sized.Internal.Signed.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.div#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.mod#",lastNotPow2)
      ,("Clash.Sized.Internal.Unsigned.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Unsigned.times#",tailNonPow2)
      ,("Clash.Sized.Internal.Unsigned.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Unsigned.rem#",lastNotPow2)
      ,("GHC.Base.quotInt",lastNotPow2)
      ,("GHC.Base.remInt",lastNotPow2)
      ,("GHC.Base.divInt",lastNotPow2)
      ,("GHC.Base.modInt",lastNotPow2)
      ,("GHC.Classes.divInt#",lastNotPow2)
      ,("GHC.Classes.modInt#",lastNotPow2)
      ,("GHC.Integer.Type.timesInteger",allNonPow2)
      ,("GHC.Integer.Type.divInteger",lastNotPow2)
      ,("GHC.Integer.Type.modInteger",lastNotPow2)
      ,("GHC.Integer.Type.quotInteger",lastNotPow2)
      ,("GHC.Integer.Type.remInteger",lastNotPow2)
      ,("GHC.Prim.*#",allNonPow2)
      ,("GHC.Prim.quotInt#",lastNotPow2)
      ,("GHC.Prim.remInt#",lastNotPow2)
      ]

    -- Term arguments of the application (type arguments dropped)
    lArgs = Either.lefts args

    -- No term argument is a power of two
    allNonPow2 = not (any termIsPow2 lArgs)

    -- No term argument after the first is a power of two. The first term
    -- argument is presumably a KnownNat dictionary for the Clash primitives;
    -- TODO confirm.
    tailNonPow2 = case lArgs of
      []       -> True
      (_:rest) -> not (any termIsPow2 rest)

    -- The last term argument is not a power of two
    lastNotPow2 = case reverse lArgs of
      []    -> True
      (l:_) -> not (termIsPow2 l)

    -- Whether a term evaluates to a power-of-two integer literal, either
    -- directly or as the literal argument of a fromInteger primitive
    termIsPow2 e' = case eval e' of
      Literal (IntegerLiteral n) -> isPow2 n
      a -> case collectArgs a of
        (Prim nm' _,[Right _,Left _,Left (Literal (IntegerLiteral n))])
          | isFromInteger nm' -> isPow2 n
        _ -> False

    -- x is a positive power of two iff it equals its lowest set bit
    -- (complement x + 1 == negate x)
    isPow2 x = x /= 0 && (x .&. (complement x + 1)) == x

    -- fromInteger primitives whose literal argument is inspected above.
    -- The Index primitive lives in Clash.Sized.Internal.Index, like the
    -- Index primitives in interestingPrims.
    isFromInteger x = x `elem` ["Clash.Sized.Internal.BitVector.fromInteger##"
                               ,"Clash.Sized.Internal.BitVector.fromInteger#"
                               ,"Clash.Sized.Internal.Index.fromInteger#"
                               ,"Clash.Sized.Internal.Signed.fromInteger#"
                               ,"Clash.Sized.Internal.Unsigned.fromInteger#"
                               ]

    -- A type is higher-order when one of its argument types satisfies
    -- 'isPolyFunTy'
    isHOTy t = case splitFunForallTy t of
      (args',_) -> any isPolyFunTy (Either.rights args')
interestingToLift _ _ _ _ = Nothing