-- Version of the Yhc.Core.Strictness module, but based on
-- Core Annotations.

module Yhc.Core.StrictAnno (
  coreStrictAnno) where

import Yhc.Core.Type
import Yhc.Core.Prim
import Yhc.Core.Annotation
import Yhc.Core.AnnotatePrims

import qualified Data.Map as Map
import Data.List(foldl', intersect, nub, partition)

{-
ALGORITHM:

SCC PARTIAL SORT:
First sort the functions so that they occur in the childmost order:
x1 < x2, if x1 doesn't transitively call x2, and x2 does transitively call x1
Being wrong is fine, but being better gives better results

PRIM STRICTNESS:
The strictness of the various primitive operations

BASE STRICTNESS:
If all paths case on a particular value, then the function is strict in that value
If a path calls onwards, then strictness is taken from the called function
-}


-- | Given a function, return a list of arguments.
--   True is strict in that argument, False is not.
--   [] is unknown strictness
--

coreStrictAnno :: CoreAnnotations -> Core -> (CoreFuncName -> [Bool])

coreStrictAnno anno core = \funcname -> Map.findWithDefault [] funcname mp
    where mp = mapStrictAnno anno $ sccSort $ coreFuncs core


mapStrictAnno anno funcs = foldl f Map.empty funcs
    where
        f mp (prim@CorePrim{coreFuncName = name}) = 
          case getAnnotation prim "Strictness" anno of
            Nothing -> mp
            Just (CoreStrictness bs) -> Map.insert name bs mp

        f mp func@(CoreFunc name args body) = case getAnnotation func "Strictness" anno of
          Just (CoreStrictness bs) -> Map.insert name bs mp
          Nothing -> Map.insert name (map (`elem` strict) args) mp
            where
                strict = strictVars body

                -- which variables are strict
                strictVars :: CoreExpr -> [String]
                strictVars (CorePos _ x) = strictVars x
                strictVars (CoreVar x) = [x]

                strictVars (CoreCase (CoreVar x) alts) = 
                  nub $ x : intersectList (map (strictVars . snd) alts)

                strictVars (CoreCase x alts) = strictVars x

                strictVars (CoreApp (CoreFun x) xs)
                    | length xs == length res
                    = nub $ concatMap strictVars $ map snd $ filter fst $ zip res xs
                    where res = Map.findWithDefault [] x mp

                strictVars (CoreApp x xs) = strictVars x

                strictVars _ = []




intersectList :: Eq a => [[a]] -> [a]
intersectList [] = []
intersectList xs = foldr1 intersect xs


-- do a sort in approximate SCC order
sccSort :: [CoreFunc] -> [CoreFunc]
sccSort xs = prims ++ funcs
    where (prims,funcs) = partition isCorePrim xs