module Csound.Dynamic.Tfm.Liveness (
liveness
) where
import Prelude hiding (mapM, mapM_)
import Control.Monad.Trans.State.Strict
import Data.Traversable
import Data.Foldable
import qualified Data.Map as M
import Control.Monad.Trans.Class
import Control.Monad hiding (mapM, mapM_)
import Control.Monad.ST
import qualified Data.Array.Unboxed as A
import qualified Data.Array.MArray as A
import qualified Data.Array.ST as A
import qualified Csound.Dynamic.Tfm.DeduceTypes as D
import Csound.Dynamic.Tfm.DeduceTypes(varType, varId)
import Csound.Dynamic.Types.Exp(Rate(..))
-- | Renames the variables of a DAG so that registers whose values are no
-- longer needed are reused by later assignments.
--
-- @lastFreshId@ sizes the per-variable tables; every variable id in the DAG
-- is expected to lie in @[0 .. lastFreshId]@. Variables of rate 'Ir' and
-- 'Sr' keep their ids, while variables of any other rate are renumbered from
-- a per-rate pool of ids that starts at @0@.
liveness :: Traversable f => Int -> Dag f -> Dag f
liveness lastFreshId dag = runST $
  evalStateT (traverse substExp numbered) =<< initSt lastFreshId table
  where
    -- last-use line of every original variable, computed up front
    table = analyse lastFreshId dag
    numbered = countLines dag
-- | Zero-based position of a statement within the DAG.
type LineNumber = Int

-- | Tags every element with its zero-based position in the list.
countLines :: [a] -> [(LineNumber, a)]
countLines xs = zipWith (,) [0 ..] xs
-- | A variable tagged with its Csound rate (read back with 'varType').
type Var = D.Var Rate
-- | Variables assigned by a statement; each gets a register in 'substLhs'.
type Lhs = [Var]
-- | The variables a statement reads, held in some expression shape @f@.
type Rhs f = f Var
-- | One statement: the variables it assigns and the variables it reads.
type Exp f = (Lhs, Rhs f)
-- | A program as an ordered list of statements; list position is the line number.
type Dag f = [Exp f]
-- | A pool of register ids for a single rate.
--
-- The pools are created from the infinite list @[0 ..]@ (see 'initRegs'),
-- so the free list is never expected to run out.
data IdList = IdList
  [Int] -- ^ currently free ids, in ascending order
  Int   -- ^ bookkeeping of the largest id handed out; maintained by
        --   'allocId' and 'freeId' but not read anywhere else in this module

-- | Takes the smallest free id out of the pool.
--
-- Fails with a descriptive error if the free list is empty, which cannot
-- happen for a pool built from @[0 ..]@.
allocId :: IdList -> (Int, IdList)
allocId (IdList (i : rest) lastId) = (i, IdList rest (max i lastId))
allocId (IdList [] _) =
  error "Csound.Dynamic.Tfm.Liveness.allocId: no free register ids left"

-- | Returns an id to the pool, keeping the free list sorted.
--
-- Freeing an id that is already free leaves the free list unchanged. This
-- matters because 'substRhs' frees a dead variable once per occurrence, so a
-- variable read twice on its last line is freed twice.
freeId :: Int -> IdList -> IdList
freeId n (IdList is lastId) = IdList (insertSorted n is) lastId1
  where
    -- NOTE(review): only an approximation of the largest id in use; freeing
    -- the current maximum assumes the id just below it is still taken.
    lastId1 = if n == lastId then lastId - 1 else lastId

-- | Inserts an element into an ascending list, dropping it if it is already
-- present. The result is produced lazily, so this also works on an infinite
-- list as long as the insertion point is eventually reached.
insertSorted :: Int -> [Int] -> [Int]
insertSorted n = go
  where
    go [] = [n]
    go xxs@(x : xs)
      | n < x = n : xxs
      | n == x = xxs
      | otherwise = x : go xs
-- | Mutable unboxed table from variable ids to ints, used inside 'ST'.
type StArr s = A.STUArray s Int Int
-- | For each original variable id, the number of the last line whose
-- right-hand side reads it; ids never read on a right-hand side keep @0@.
type LivenessTable = A.UArray Int Int
-- | For each original variable id, the id it has been renamed to.
type SubstTable s = StArr s
-- | State threaded through the renaming pass.
data Registers s = Registers
  { registers :: M.Map Rate IdList
    -- ^ free-id pool for each rate
  , livenessTable :: LivenessTable
    -- ^ last-use line of every original variable; never modified during the pass
  , substTable :: SubstTable s
    -- ^ renaming from original ids to register ids, filled in as
    -- left-hand sides are processed
  }
-- | The renaming monad: 'Registers' state layered over 'ST'.
type Memory s a = StateT (Registers s) (ST s) a
-- | Applies a function to the per-rate free-id pools and leaves the rest of
-- the state untouched.
onRegs :: (M.Map Rate IdList -> M.Map Rate IdList) -> (Registers s -> Registers s)
onRegs f st@Registers { registers = pools } = st { registers = f pools }
-- | Initial register pools: for every rate, all ids @0, 1, 2, ...@ are free.
initRegs :: M.Map Rate IdList
initRegs = M.fromList [ (rate, IdList [0 ..] 0) | rate <- [minBound .. maxBound] ]
-- | True when some line after @lineNum@ still reads the variable, i.e. its
-- recorded last use lies strictly after the given line.
isAlive :: LineNumber -> Var -> Memory s Bool
isAlive lineNum v = gets $ \regs -> lineNum < livenessTable regs A.! varId v
-- | Reads the id that the given original variable id has been renamed to.
lookUpSubst :: Int -> Memory s Int
lookUpSubst i = gets substTable >>= \tab -> lift (A.readArray tab i)
-- | Records that the original variable id @from@ is renamed to @to@.
saveSubst :: Int -> Int -> Memory s ()
saveSubst from to = gets substTable >>= \tab -> lift (A.writeArray tab from to)
-- | Gives an assigned variable its register and records the renaming, so
-- that later reads of the variable are rewritten to the same register.
substLhs :: Var -> Memory s Var
substLhs v =
  allocAndSkipInits v >>= \fresh ->
    fresh <$ saveSubst (varId v) (varId fresh)
-- | Rewrites a variable read on line @lineNum@ to its register. When no
-- later line reads the variable, the register goes back to its rate's pool.
substRhs :: LineNumber -> Var -> Memory s Var
substRhs lineNum v = do
  reg <- lookUpSubst (varId v)
  let renamed = D.Var reg (varType v)
  dead <- not <$> isAlive lineNum v
  when dead (free renamed)
  pure renamed
-- | Chooses the register for an assigned variable. Variables of rate 'Ir'
-- or 'Sr' keep their original id; any other variable gets a fresh register
-- of its rate.
allocAndSkipInits :: Var -> Memory s Var
allocAndSkipInits v
  | rate `elem` [Ir, Sr] = pure v
  | otherwise            = alloc rate
  where
    rate = varType v
-- | Takes the smallest free id of the given rate out of that rate's pool and
-- returns it as a variable of that rate.
alloc :: Rate -> Memory s Var
alloc rate = do
  regs <- get
  let (reg, remaining) = allocId (registers regs M.! rate)
  put (onRegs (M.adjust (const remaining) rate) regs)
  return (D.Var reg rate)
-- | Puts a variable's id back into the free pool of the variable's rate.
free :: Var -> Memory s ()
free v = modify $ onRegs $ M.adjust (freeId (varId v)) (varType v)
-- | Computes, for every variable id in @[0 .. lastFreshId]@, the number of
-- the last line whose right-hand side reads that variable.
--
-- Ids that never occur on a right-hand side keep the initial value @0@.
-- Every variable id read in the DAG is assumed to lie within
-- @[0 .. lastFreshId]@; an id outside that range makes 'A.readArray' fail
-- with an out-of-bounds error.
analyse :: Traversable f => Int -> Dag f -> LivenessTable
analyse lastFreshId as = A.runSTUArray $ do
  arr <- A.newArray (0, lastFreshId) 0
  mapM_ (go arr) $ countLines as
  return arr
  where
    -- Record every variable read on one line. Only the array effects
    -- matter, so mapM_ is used instead of rebuilding the @f@ structure.
    go :: Traversable f => StArr s -> (LineNumber, Exp f) -> ST s ()
    go arr (lineNum, (_, rhs)) = mapM_ (countVar arr lineNum) rhs

    -- Lines are visited in increasing order, so the max always keeps the
    -- latest use; it also guards against out-of-order input.
    countVar :: StArr s -> LineNumber -> Var -> ST s ()
    countVar arr lineNum v = do
        val <- A.readArray arr i
        A.writeArray arr i (val `max` lineNum)
      where
        i = varId v
-- | Renames one numbered statement. The left-hand side is processed first,
-- which allocates registers. The right-hand side comes second, which may
-- free registers. This order means a statement never writes into a
-- register it has just released.
substExp :: Traversable f => (LineNumber, Exp f) -> Memory s (Exp f)
substExp (lineNum, (lhs, rhs)) =
  (,) <$> traverse substLhs lhs <*> traverse (substRhs lineNum) rhs
-- | Builds the initial renaming state. It holds the initial register pools,
-- the given liveness table, and a freshly allocated renaming table.
initSt :: Int -> LivenessTable -> ST s (Registers s)
initSt lastFreshId livenessTab = do
  substTab <- initSubstTable lastFreshId
  return Registers
    { registers = initRegs
    , livenessTable = livenessTab
    , substTable = substTab
    }
-- | Allocates the renaming table with every id in @[0 .. n + 1]@ mapped to
-- itself.
--
-- NOTE(review): 'analyse' sizes its table as @(0, n)@, so the last slot
-- here is one past any id that 'analyse' accepts.
initSubstTable :: Int -> ST s (SubstTable s)
initSubstTable n = A.newListArray (lo, hi) [lo .. hi]
  where
    lo = 0
    hi = n + 1