-- | Analyses variables for reuse,
-- so that we spend less memory and allocate fewer variables and audio vectors.
module Csound.Dynamic.Tfm.Liveness (
    liveness
) where

import Prelude hiding (mapM, mapM_)

import Control.Monad.Trans.State.Strict
import Data.Traversable
import Data.Foldable

import Control.Monad.Trans.Class
import Control.Monad hiding (mapM, mapM_)
import Control.Monad.ST
import Data.Vector.Unboxed.Mutable qualified as UVector

import Csound.Dynamic.Tfm.InferTypes (Var (..))
import Csound.Dynamic.Types.Exp(Rate(..))

-- | Reuses variables. It analyses whether a variable is used further
-- in the code and, if it is not, tries to reuse it for the next assignments.
--
-- Only 'Ar' and 'Kr' variables are renamed; variables of any other rate are
-- passed through unchanged.
--
-- The 'Int' argument is handed to 'analyse' and 'initSt', which size their
-- id-indexed tables from it (presumably it bounds the variable ids occurring
-- in the DAG — TODO confirm against the caller).
liveness :: Traversable f => Int -> Dag f -> Dag f
liveness lastFreshId dag = runST $ do
  usage <- analyse lastFreshId dag
  regs  <- initSt lastFreshId usage
  evalStateT (forM (countLines dag) substExp) regs

-- | Zero-based position of an expression within the DAG.
type LineNumber = Int

-- | Pairs every element with its zero-based line number.
countLines :: [a] -> [(LineNumber, a)]
countLines = zipWith (,) [0 ..]

-- | Variables assigned by one expression.
type Lhs   = [Var]
-- | Right-hand side of an expression: a structure @f@ whose elements are the
-- variables it reads.
type Rhs f = f Var
-- | One line of the program: the assigned variables and the right-hand side.
type Exp f = (Lhs, Rhs f)

-- | The program as a list of expressions, processed in list order.
type Dag f = [Exp f]

-----------------------------------------------

-- | A pool of ids for one variable rate.
--
-- NOTE(review): the second field is only updated by 'allocId' and 'freeId'
-- and is not read anywhere else in the code visible here; 'freeId' lowers it
-- by exactly one, so after several frees it can overstate the biggest id in use.
data IdList = IdList
    [Int] -- ^ free ids, kept in ascending order (always an infinite list)
    !Int   -- ^ the biggest used id

-- | Takes the first free id from the pool (the smallest one, since the pool
-- is kept sorted) and records it as used.
--
-- The pool starts out infinite and ids are only ever taken one at a time,
-- so the empty case is unreachable.
allocId :: IdList -> (Int, IdList)
allocId (IdList (next : rest) biggest) = (next, IdList rest (max next biggest))
allocId (IdList [] _) = error "impossible: list of IDs is always infinite"

-- | Returns an id to the pool so that it can be allocated again.
-- If the id equals the recorded biggest used id, that bound is lowered by one.
freeId :: Int -> IdList -> IdList
freeId n (IdList is lastId)
  | n == lastId = IdList pool (lastId - 1)
  | otherwise   = IdList pool lastId
  where
    pool = insertSorted n is

-- | Inserts an id into an ascending list, leaving the list unchanged when the
-- id is already present. The result is produced incrementally, so this also
-- works on an infinite ascending list that contains some element >= @n@.
insertSorted :: Int -> [Int] -> [Int]
insertSorted n [] = [n]
insertSorted n (x : rest) =
  case compare n x of
    LT -> n : x : rest
    EQ -> x : rest
    GT -> x : insertSorted n rest

-- | A fresh pool in which every non-negative id is free; the biggest-used-id
-- bound starts at 0.
initIdList :: IdList
initIdList = IdList (enumFrom 0) 0

-----------------------------------------------

-- | Mutable unboxed table of 'Int's inside the 'ST' thread @s@.
type StArr s = UVector.STVector s Int

-- | For each original variable id, the last line on which it is read
-- (filled by 'analyse'); 0 for variables that are never read.
-- Same representation as 'StArr'.
type LivenessTable s = StArr s
-- | For each original variable id, the id it has been renamed to
-- (identity until 'saveSubst' overwrites a slot).
type SubstTable s    = StArr s

-- | State threaded through the renaming pass.
data Registers s = Registers
  { arRegisters   :: !IdList            -- ^ id pool for 'Ar' variables
  , krRegisters   :: !IdList            -- ^ id pool for 'Kr' variables
  , livenessTable :: !(LivenessTable s) -- ^ last-usage table
  , substTable    :: !(SubstTable s)    -- ^ original id -> renamed id
  }

-- | The renaming monad: strict state over 'Registers' on top of 'ST'.
type Memory s a = StateT (Registers s) (ST s) a

-- | Applies a transformation to the id pool that matches the given rate.
-- Rates other than 'Ar' and 'Kr' leave the state untouched.
onRegs :: Rate -> (IdList -> IdList) -> Memory s ()
onRegs rate f = modify' update
  where
    update regs =
      case rate of
        Ar -> regs { arRegisters = f (arRegisters regs) }
        Kr -> regs { krRegisters = f (krRegisters regs) }
        _  -> regs

-- | Replaces the 'Ar' id pool.
setArRegisters :: IdList -> Memory s ()
setArRegisters ids = onRegs Ar (const ids)

-- | Replaces the 'Kr' id pool.
setKrRegisters :: IdList -> Memory s ()
setKrRegisters ids = onRegs Kr (const ids)


-- | Whether the variable (looked up by its original id) is still read on some
-- line after @lineNum@, according to the liveness table.
isAlive :: LineNumber -> Var -> Memory s Bool
isAlive lineNum v = do
  table     <- gets livenessTable
  lastUsage <- lift (UVector.read table (varId v))
  pure (lastUsage > lineNum)

-- | The id that the variable with original id @i@ has been renamed to.
lookUpSubst :: Int -> Memory s Int
lookUpSubst i = gets substTable >>= \table -> lift (UVector.read table i)

-- | Records that the variable with original id @from@ is renamed to @to@.
saveSubst :: Int -> Int -> Memory s ()
saveSubst from to = gets substTable >>= \table -> lift (UVector.write table from to)

-- | Renames an assigned variable: allocates an id from the pool of its rate
-- and records the renaming so later reads pick it up. Variables that are
-- neither 'Ar' nor 'Kr' are returned unchanged.
substLhs :: Var -> Memory s Var
substLhs = onlyForAK rename
  where
    rename v = do
      v1 <- alloc v
      saveSubst (varId v) (varId v1)
      pure v1

-- | Renames a variable read on line @lineNum@ to the id recorded for it by
-- 'substLhs'. If this line is its last use, that id is returned to the pool
-- so later assignments can reuse it. Variables that are neither 'Ar' nor
-- 'Kr' are returned unchanged.
substRhs :: LineNumber -> Var -> Memory s Var
substRhs lineNum = onlyForAK $ \v -> do
  live      <- isAlive lineNum v
  renamedId <- lookUpSubst (varId v)
  let renamed = Var (varType v) renamedId
  unless live (free renamed)
  pure renamed

-- | Gives the variable a fresh id from the pool of its rate, keeping the rate.
-- Variables that are neither 'Ar' nor 'Kr' are returned unchanged.
alloc :: Var -> Memory s Var
alloc v =
  case varType v of
    Ar -> takeFrom arRegisters setArRegisters
    Kr -> takeFrom krRegisters setKrRegisters
    _  -> pure v
  where
    -- Pops an id from the selected pool and stores the shrunken pool back.
    takeFrom getPool setPool = do
      (name, rest) <- gets (allocId . getPool)
      setPool rest
      pure (Var (varType v) name)

-- | Returns the variable's id to the pool of its rate.
-- Has no effect for rates other than 'Ar' and 'Kr'.
free :: Var -> Memory s ()
free v = onRegs (varType v) (freeId (varId v))

--------------------------------------------------------------------------

-- | Builds the liveness table. For every 'Ar'/'Kr' variable read on some
-- right-hand side, the slot at its original id holds the number of the last
-- line that reads it. Slots of variables that are never read stay 0.
--
-- The table has @lastFreshId + 1@ slots, the same index range as the
-- substitution table from 'initSubstTable'. This keeps an id equal to
-- @lastFreshId@ in bounds for the bounds-checked 'UVector.write' here and
-- 'UVector.read' in 'isAlive'.
analyse :: Traversable f => Int -> Dag f -> ST s (LivenessTable s)
analyse lastFreshId as = do
  arr <- UVector.replicate (lastFreshId + 1) 0
  forM_ (countLines as) $ \(lineNum, (_, rhs)) ->
    for_ rhs (markUse arr lineNum)
  pure arr
  where
    -- Overwrites the slot on every read. Lines are visited in ascending
    -- order, so the value left behind is the last line that reads the var.
    markUse arr lineNum v =
      when (isAOrK v) $ UVector.write arr (varId v) lineNum

-- | Runs the action on 'Ar' and 'Kr' variables only; any other variable is
-- returned unchanged.
onlyForAK :: Monad f => (Var -> f Var) -> Var -> f Var
onlyForAK go v = if isAOrK v then go v else pure v

-- | Liveness optimisation is applied only to 'Ar' and 'Kr' variables;
-- this predicate selects them.
isAOrK :: Var -> Bool
isAOrK = isAK . varType
  where
    isAK Ar = True
    isAK Kr = True
    isAK _  = False

-- | Renames one line. Assigned variables are processed first (allocating ids),
-- then the variables read on the right-hand side (releasing ids whose last use
-- is this line). Because of that order, ids released by this line's reads are
-- only reused from the next line on.
substExp :: Traversable f => (LineNumber, Exp f) -> Memory s (Exp f)
substExp (lineNum, (lhs, rhs)) =
  (,) <$> traverse substLhs lhs
      <*> traverse (substRhs lineNum) rhs

-- | Initial renaming state: fresh pools for both rates, the given liveness
-- table, and an identity substitution table sized from @lastFreshId@.
initSt :: Int -> LivenessTable s -> ST s (Registers s)
initSt lastFreshId livenessTab = do
  subst <- initSubstTable lastFreshId
  pure (Registers initIdList initIdList livenessTab subst)

-- | Identity substitution table with @size + 1@ slots: slot @i@ holds @i@.
initSubstTable :: Int ->  ST s (SubstTable s)
initSubstTable size = UVector.generate (succ size) (\slot -> slot)