{-# LANGUAGE FlexibleContexts #-}
module Futhark.Analysis.Range
       ( rangeAnalysis
       , runRangeM
       , RangeM
       , analyseExp
       , analyseLambda
       , analyseBody
       , analyseStms
       )
       where

import qualified Data.Map.Strict as M
import Control.Monad.Reader
import Data.List

import qualified Futhark.Analysis.ScalExp as SE
import Futhark.Representation.Ranges
import Futhark.Analysis.AlgSimplify as AS

-- Entry point

-- | Perform variable range analysis on the given program, returning a
-- program with embedded range annotations.
rangeAnalysis :: (Attributes lore, CanBeRanged (Op lore)) =>
                 Prog lore -> Prog (Ranges lore)
rangeAnalysis = Prog . map analyseFun . progFunctions

-- Implementation

analyseFun :: (Attributes lore, CanBeRanged (Op lore)) =>
              FunDef lore -> FunDef (Ranges lore)
analyseFun (FunDef entry fname restype params body) =
  runRangeM $ bindFunParams params $
  FunDef entry fname restype params <$> analyseBody body

analyseBody :: (Attributes lore, CanBeRanged (Op lore)) =>
               Body lore
            -> RangeM (Body (Ranges lore))
analyseBody (Body lore origbnds result) =
  analyseStms origbnds $ \bnds' ->
    return $ mkRangedBody lore bnds' result

analyseStms :: (Attributes lore, CanBeRanged (Op lore)) =>
               Stms lore
            -> (Stms (Ranges lore) -> RangeM a)
            -> RangeM a
analyseStms = analyseStms' mempty . stmsToList
  where analyseStms' acc [] m =
          m acc
        analyseStms' acc (bnd:bnds) m = do
          bnd' <- analyseStm bnd
          bindPattern (stmPattern bnd') $
            analyseStms' (acc <> oneStm bnd') bnds m

analyseStm :: (Attributes lore, CanBeRanged (Op lore)) =>
              Stm lore -> RangeM (Stm (Ranges lore))
analyseStm (Let pat lore e) = do
  e' <- analyseExp e
  pat' <- simplifyPatRanges $ addRangesToPattern pat e'
  return $ Let pat' lore e'

analyseExp :: (Attributes lore, CanBeRanged (Op lore)) =>
              Exp lore
           -> RangeM (Exp (Ranges lore))
analyseExp = mapExpM analyse
  where analyse =
          Mapper { mapOnSubExp = return
                    , mapOnCertificates = return
                    , mapOnVName = return
                    , mapOnBody = const analyseBody
                    , mapOnRetType = return
                    , mapOnBranchType = return
                    , mapOnFParam = return
                    , mapOnLParam = return
                    , mapOnOp = return . addOpRanges
                    }

analyseLambda :: (Attributes lore, CanBeRanged (Op lore)) =>
                 Lambda lore
              -> RangeM (Lambda (Ranges lore))
analyseLambda lam = do
  body <- analyseBody $ lambdaBody lam
  return $ lam { lambdaBody = body
               , lambdaParams = lambdaParams lam
               }

-- Monad and utility definitions

type RangeEnv = M.Map VName Range

emptyRangeEnv :: RangeEnv
emptyRangeEnv = M.empty

type RangeM = Reader RangeEnv

runRangeM :: RangeM a -> a
runRangeM = flip runReader emptyRangeEnv

bindFunParams :: Typed attr =>
                 [ParamT attr] -> RangeM a -> RangeM a
bindFunParams []             m =
  m
bindFunParams (param:params) m = do
  ranges <- rangesRep
  local bindFunParam $
    local (refineDimensionRanges ranges dims) $
    bindFunParams params m
  where bindFunParam = M.insert (paramName param) unknownRange
        dims = arrayDims $ paramType param

bindPattern :: Typed attr =>
               PatternT (Range, attr) -> RangeM a -> RangeM a
bindPattern pat m = do
  ranges <- rangesRep
  local bindPatElems $
    local (refineDimensionRanges ranges dims)
    m
  where bindPatElems env =
          foldl bindPatElem env $ patternElements pat
        bindPatElem env patElem =
          M.insert (patElemName patElem) (fst $ patElemAttr patElem) env
        dims = nub $ concatMap arrayDims $ patternTypes pat

refineDimensionRanges :: AS.RangesRep -> [SubExp]
                      -> RangeEnv -> RangeEnv
refineDimensionRanges ranges = flip $ foldl refineShape
  where refineShape env (Var dim) =
          refineRange ranges dim dimBound env
        refineShape env _ =
          env
        -- A dimension is never negative.
        dimBound :: Range
        dimBound = (Just $ ScalarBound 0,
                    Nothing)

refineRange :: AS.RangesRep -> VName -> Range -> RangeEnv
            -> RangeEnv
refineRange =
  M.insertWith . refinedRange

-- New range, old range, result range.
refinedRange :: AS.RangesRep -> Range -> Range -> Range
refinedRange ranges (new_lower, new_upper) (old_lower, old_upper) =
  (simplifyBound ranges $ refineLowerBound new_lower old_lower,
   simplifyBound ranges $ refineUpperBound new_upper old_upper)

-- New bound, old bound, result bound.
refineLowerBound :: Bound -> Bound -> Bound
refineLowerBound = flip maximumBound

-- New bound, old bound, result bound.
refineUpperBound :: Bound -> Bound -> Bound
refineUpperBound = flip minimumBound

lookupRange :: VName -> RangeM Range
lookupRange = asks . M.findWithDefault unknownRange

simplifyPatRanges :: PatternT (Range, attr)
                  -> RangeM (PatternT (Range, attr))
simplifyPatRanges (Pattern context values) =
  Pattern <$> mapM simplifyPatElemRange context <*> mapM simplifyPatElemRange values
  where simplifyPatElemRange patElem = do
          let (range, innerattr) = patElemAttr patElem
          range' <- simplifyRange range
          return $ setPatElemLore patElem (range', innerattr)

simplifyRange :: Range -> RangeM Range
simplifyRange (lower, upper) = do
  ranges <- rangesRep
  lower' <- simplifyBound ranges <$> betterLowerBound lower
  upper' <- simplifyBound ranges <$> betterUpperBound upper
  return (lower', upper')

simplifyBound :: AS.RangesRep -> Bound -> Bound
simplifyBound ranges = fmap $ simplifyKnownBound ranges

simplifyKnownBound :: AS.RangesRep -> KnownBound -> KnownBound
simplifyKnownBound ranges bound
  | Just se <- boundToScalExp bound =
    ScalarBound $ AS.simplify se ranges
simplifyKnownBound ranges (MinimumBound b1 b2) =
  MinimumBound (simplifyKnownBound ranges b1) (simplifyKnownBound ranges b2)
simplifyKnownBound ranges (MaximumBound b1 b2) =
  MaximumBound (simplifyKnownBound ranges b1) (simplifyKnownBound ranges b2)
simplifyKnownBound _ bound =
  bound

betterLowerBound :: Bound -> RangeM Bound
betterLowerBound (Just (ScalarBound (SE.Id v t))) = do
  range <- lookupRange v
  return $ Just $ case range of
    (Just lower, _) -> lower
    _               -> ScalarBound $ SE.Id v t
betterLowerBound bound =
  return bound

betterUpperBound :: Bound -> RangeM Bound
betterUpperBound (Just (ScalarBound (SE.Id v t))) = do
  range <- lookupRange v
  return $ Just $ case range of
    (_, Just upper) -> upper
    _               -> ScalarBound $ SE.Id v t
betterUpperBound bound =
  return bound

-- The algebraic simplifier requires a loop nesting level for each
-- range.  We just put a zero because I don't think it's used for
-- anything in this case.
rangesRep :: RangeM AS.RangesRep
rangesRep = asks $ M.map addLeadingZero
  where addLeadingZero (x,y) =
          (0, boundToScalExp =<< x, boundToScalExp =<< y)