{-# LANGUAGE FlexibleContexts #-}
module Futhark.Analysis.ScalExp
  ( RelOp0(..)
  , ScalExp(..)
  , scalExpType
  , scalExpSize
  , subExpToScalExp
  , toScalExp
  , expandScalExp
  , LookupVar
  , module Futhark.Representation.Primitive
  )
where

import Data.List
import qualified Data.Set as S
import Data.Maybe
import Data.Monoid ((<>))

import Futhark.Representation.Primitive hiding (SQuot, SRem, SDiv, SMod, SSignum)
import Futhark.Representation.AST hiding (SQuot, SRem, SDiv, SMod, SSignum)
import qualified Futhark.Representation.AST as AST
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Util.Pretty hiding (pretty)

-----------------------------------------------------------------
-- BINARY OPERATORS for Numbers                                --
-- Note that MOD, BAND, XOR, BOR, SHIFTR, SHIFTL not supported --
--   `a SHIFTL/SHIFTR p' can be translated if desired as       --
--   `a * 2^p' or `a / 2^p'                                    --
-----------------------------------------------------------------

-- | Relational operators used by 'RelExp', which always compares its
-- operand against zero.
data RelOp0
  = LTH0 -- ^ Strictly less than zero (@e < 0@).
  | LEQ0 -- ^ Less than or equal to zero (@e <= 0@).
  deriving (Eq, Ord, Enum, Bounded, Show)

-- | Representation of a scalar expression, which is one of:
--
--    (i) an algebraic expression, e.g., @min(a+b, a*b)@,
--
--   (ii) a relational expression, e.g., @a+b < 5@,
--
--  (iii) a logical expression, e.g., @e1 and (not (a+b>5))@.
--
-- Relational expressions are kept in a form that compares a single
-- expression against zero (see 'RelOp0').
data ScalExp
  = Val     PrimValue            -- ^ A constant.
  | Id      VName PrimType       -- ^ A variable of the given type.
  | SNeg    ScalExp
  | SNot    ScalExp
  | SAbs    ScalExp
  | SSignum ScalExp
  | SPlus   ScalExp ScalExp
  | SMinus  ScalExp ScalExp
  | STimes  ScalExp ScalExp
  | SPow    ScalExp ScalExp
  | SDiv    ScalExp ScalExp
  | SMod    ScalExp ScalExp
  | SQuot   ScalExp ScalExp
  | SRem    ScalExp ScalExp
  | MaxMin  Bool [ScalExp]       -- ^ 'True' for minimum, 'False' for maximum.
  | RelExp  RelOp0 ScalExp       -- ^ Compares the expression against zero.
  | SLogAnd ScalExp ScalExp
  | SLogOr  ScalExp ScalExp
  deriving (Eq, Ord, Show)

-- Arithmetic on 'ScalExp' with light simplification: adding or
-- subtracting zero, multiplying by zero and multiplying by one are
-- folded away.  The comparisons against literals go through
-- 'fromInteger' below, so they only recognise 32-bit integer constants.
instance Num ScalExp where
  x + y
    | x == 0    = y
    | y == 0    = x
    | otherwise = SPlus x y

  x - y
    | y == 0    = x
    | otherwise = SMinus x y

  x * y
    | x == 0 || y == 0 = 0
    | x == 1           = y
    | y == 1           = x
    | otherwise        = STimes x y

  negate = SNeg
  abs    = SAbs
  signum = SSignum
  -- Every integer literal becomes a 32-bit integer constant, whatever
  -- the type of the surrounding expression; probably not OK.
  fromInteger n = Val $ IntValue $ Int32Value $ fromInteger n

-- Pretty-printing of scalar expressions.  Relational expressions are
-- shown as a comparison of their operand against zero, and 'MaxMin'
-- as a call to @min@ or @max@.
instance Pretty ScalExp where
  pprPrec prec se =
    case se of
      Val val         -> ppr val
      Id v _          -> ppr v
      SNeg e          -> text "-" <> pprPrec 9 e
      SNot e          -> prefixOp "not" e
      SAbs e          -> prefixOp "abs" e
      SSignum e       -> prefixOp "signum" e
      SPlus x y       -> ppBinOp prec "+" 4 4 x y
      SMinus x y      -> ppBinOp prec "-" 4 10 x y
      SPow x y        -> ppBinOp prec "^" 6 6 x y
      STimes x y      -> ppBinOp prec "*" 5 5 x y
      SDiv x y        -> ppBinOp prec "/" 5 10 x y
      SMod x y        -> ppBinOp prec "%" 5 10 x y
      SQuot x y       -> ppBinOp prec "//" 5 10 x y
      SRem x y        -> ppBinOp prec "%%" 5 10 x y
      SLogOr x y      -> ppBinOp prec "||" 0 0 x y
      SLogAnd x y     -> ppBinOp prec "&&" 1 1 x y
      RelExp LTH0 e   -> ppBinOp prec "<" 2 2 e (0::Int)
      RelExp LEQ0 e   -> ppBinOp prec "<=" 2 2 e (0::Int)
      MaxMin True es  -> text "min" <> parens (commasep $ map ppr es)
      MaxMin False es -> text "max" <> parens (commasep $ map ppr es)
    where prefixOp kw e = text kw <+> pprPrec 9 e

-- | Pretty-print a binary operator application, parenthesising it when
-- the context precedence exceeds the operator's own precedence.  The
-- left operand is printed at the operator's precedence and the right
-- operand at a separately supplied precedence, so callers can force
-- brackets around the right operand of e.g. subtraction.
ppBinOp :: (Pretty a, Pretty b) => Int -> String -> Int -> Int -> a -> b -> Doc
ppBinOp ctx op lprec rprec lhs rhs =
  parensIf (ctx > lprec) $ lhsDoc <+/> opDoc <+> rhsDoc
  where lhsDoc = pprPrec lprec lhs
        opDoc  = text op
        rhsDoc = pprPrec rprec rhs

-- Substitution of variable names.  Only 'Id' leaves carry names; every
-- other constructor is rebuilt unchanged, with its children substituted.
instance Substitute ScalExp where
  substituteNames subst e =
    case e of
      Id v t      -> Id (substituteNames subst v) t
      Val v       -> Val v
      SNeg x      -> SNeg $ go x
      SNot x      -> SNot $ go x
      SAbs x      -> SAbs $ go x
      SSignum x   -> SSignum $ go x
      SPlus x y   -> go x `SPlus` go y
      SMinus x y  -> go x `SMinus` go y
      SPow x y    -> go x `SPow` go y
      STimes x y  -> go x `STimes` go y
      SDiv x y    -> go x `SDiv` go y
      SMod x y    -> go x `SMod` go y
      -- A quotient must stay a quotient; it must not become an 'SDiv'.
      SQuot x y   -> go x `SQuot` go y
      SRem x y    -> go x `SRem` go y
      MaxMin m es -> MaxMin m $ map go es
      RelExp r x  -> RelExp r $ go x
      SLogAnd x y -> go x `SLogAnd` go y
      SLogOr x y  -> go x `SLogOr` go y
    where go :: ScalExp -> ScalExp
          go = substituteNames subst

-- Renaming is expressed in terms of the 'Substitute' instance above.
instance Rename ScalExp where
  rename = substituteRename

-- | The primitive type of a scalar expression.  Logical and relational
-- expressions are boolean.  Every other operator takes the type of its
-- (first) operand.  An empty 'MaxMin' has no operand to inspect and is
-- arbitrarily given the 32-bit integer type.
scalExpType :: ScalExp -> PrimType
scalExpType se =
  case se of
    Val v          -> primValueType v
    Id _ t         -> t
    SNot{}         -> Bool
    SLogAnd{}      -> Bool
    SLogOr{}       -> Bool
    RelExp{}       -> Bool
    SNeg e         -> scalExpType e
    SAbs e         -> scalExpType e
    SSignum e      -> scalExpType e
    SPlus e _      -> scalExpType e
    SMinus e _     -> scalExpType e
    STimes e _     -> scalExpType e
    SDiv e _       -> scalExpType e
    SMod e _       -> scalExpType e
    SPow e _       -> scalExpType e
    SQuot e _      -> scalExpType e
    SRem e _       -> scalExpType e
    MaxMin _ []    -> IntType Int32 -- arbitrary and probably wrong.
    MaxMin _ (e:_) -> scalExpType e

-- | Number of leaves ('Val' and 'Id' nodes) in the scalar expression.
-- Operator nodes are not counted themselves.  For example,
-- @SNeg (Id v t)@ has size 1, and an empty 'MaxMin' has size 0.
scalExpSize :: ScalExp -> Int
scalExpSize Val{} = 1
scalExpSize Id{} = 1
scalExpSize (SNeg    e) = scalExpSize e
scalExpSize (SNot    e) = scalExpSize e
scalExpSize (SAbs    e) = scalExpSize e
scalExpSize (SSignum e) = scalExpSize e
scalExpSize (SPlus   x y) = scalExpSize x + scalExpSize y
scalExpSize (SMinus  x y) = scalExpSize x + scalExpSize y
scalExpSize (STimes  x y) = scalExpSize x + scalExpSize y
scalExpSize (SDiv    x y) = scalExpSize x + scalExpSize y
scalExpSize (SMod    x y) = scalExpSize x + scalExpSize y
scalExpSize (SPow    x y) = scalExpSize x + scalExpSize y
scalExpSize (SQuot   x y) = scalExpSize x + scalExpSize y
scalExpSize (SRem    x y) = scalExpSize x + scalExpSize y
scalExpSize (SLogAnd x y) = scalExpSize x + scalExpSize y
scalExpSize (SLogOr  x y) = scalExpSize x + scalExpSize y
scalExpSize (RelExp  _ x) = scalExpSize x
-- The sum of an empty list is 0, so an empty 'MaxMin' needs no special case.
scalExpSize (MaxMin  _ es) = sum $ map scalExpSize es

-- | A function that, given a variable name, returns the scalar
-- expression to use in place of that variable, or 'Nothing' if no such
-- expression is known.
type LookupVar = VName -> Maybe ScalExp

-- | Non-recursively convert a subexpression to a 'ScalExp'.  A variable
-- becomes an 'Id' leaf and a constant becomes a 'Val' leaf.  The (scalar)
-- type of the subexpression must be given in advance; it is only used
-- for variables.
subExpToScalExp :: SubExp -> PrimType -> ScalExp
subExpToScalExp se t =
  case se of
    Var v        -> Id v t
    Constant val -> Val val

-- | Try to translate an expression into a 'ScalExp'.  Returns 'Nothing'
-- when the expression is not one of the forms handled below.  The
-- 'LookupVar' is consulted for variables, so a variable with a known
-- expression is replaced by that expression rather than becoming an
-- 'Id' leaf.
--
-- A variable or constant that forms the whole expression is only
-- translated if its type is accepted by 'typeIsOK'.  Operands of the
-- other forms are converted with 'subExpToScalExp'', which raises an
-- error for a variable of non-primitive type.
--
-- When a guard below fails, matching falls through to the later
-- equations and, ultimately, to the catch-all returning 'Nothing'.
toScalExp :: (HasScope t f, Monad f) =>
             LookupVar -> Exp lore -> f (Maybe ScalExp)
-- A variable: its known expression if any, otherwise an 'Id' leaf when
-- its type is a supported primitive type.
toScalExp look (BasicOp (SubExp (Var v)))
  | Just se <- look v =
    return $ Just se
  | otherwise = do
    t <- lookupType v
    case t of
      Prim bt | typeIsOK bt ->
        return $ Just $ Id v bt
      _ ->
        return Nothing
-- A constant of a supported type.
toScalExp _ (BasicOp (SubExp (Constant val)))
  | typeIsOK $ primValueType val =
    return $ Just $ Val val
-- Signed comparisons become comparisons against zero:
-- @x < y@ becomes @x - y < 0@ and @x <= y@ becomes @x - y <= 0@
-- (subtraction built with 'sminus').
toScalExp look (BasicOp (CmpOp (CmpSlt _) x y)) =
  Just . RelExp LTH0 <$> (sminus <$> subExpToScalExp' look x <*> subExpToScalExp' look y)
toScalExp look (BasicOp (CmpOp (CmpSle _) x y)) =
  Just . RelExp LEQ0 <$> (sminus <$> subExpToScalExp' look x <*> subExpToScalExp' look y)
-- Equality on a supported type.  For booleans, @x == y@ is encoded as
-- @(x && y) || (!x && !y)@; for other types as
-- @x - y <= 0 && y - x <= 0@.
toScalExp look (BasicOp (CmpOp (CmpEq t) x y))
  | typeIsOK t = do
  x' <- subExpToScalExp' look x
  y' <- subExpToScalExp' look y
  return $ Just $ case t of
    Bool ->
      SLogAnd x' y' `SLogOr` SLogAnd (SNot x') (SNot y')
    _ ->
      RelExp LEQ0 (x' `sminus` y') `SLogAnd` RelExp LEQ0 (y' `sminus` x')
-- Integer subtraction from a zero-ish constant is treated as negation.
toScalExp look (BasicOp (BinOp (Sub t) (Constant x) y))
  | typeIsOK $ IntType t, zeroIsh x =
  Just . SNeg <$> subExpToScalExp' look y
-- Logical negation.
toScalExp look (BasicOp (UnOp AST.Not e)) =
  Just . SNot <$> subExpToScalExp' look e
-- Any other binary operator for which 'binOpScalExp' has a translation.
toScalExp look (BasicOp (BinOp bop x y))
  | Just f <- binOpScalExp bop =
  Just <$> (f <$> subExpToScalExp' look x <*> subExpToScalExp' look y)

toScalExp _ _ = return Nothing

-- | Whether a primitive type may appear as a standalone variable or
-- constant in 'toScalExp': booleans and the integer types listed in
-- 'allIntTypes'.
typeIsOK :: PrimType -> Bool
typeIsOK t = t == Bool || any ((== t) . IntType) allIntTypes

-- | Convert a subexpression to a 'ScalExp', consulting the 'LookupVar'
-- first for variables.  A variable with no known expression becomes an
-- 'Id' leaf of its primitive type.  A variable of non-primitive type is
-- a fatal error.
subExpToScalExp' :: HasScope t f =>
                    LookupVar -> SubExp -> f ScalExp
subExpToScalExp' _ (Constant val) =
  pure $ Val val
subExpToScalExp' look (Var v) =
  maybe (leafFor <$> lookupType v) pure (look v)
  where leafFor (Prim t) = Id v t
        leafFor t =
          error $ "Cannot create ScalExp from variable " ++ pretty v ++
          " of type " ++ pretty t

-- | Grow the 'Id' leaves of a scalar expression that was built with
-- incomplete symbol table information.  Each 'Id' for which the
-- 'LookupVar' yields an expression is replaced by that expression; the
-- replacement itself is not expanded further.  Every other node is
-- rebuilt with its children expanded.
expandScalExp :: LookupVar -> ScalExp -> ScalExp
expandScalExp look = go
  where
    go se = case se of
      Val v       -> Val v
      Id v t      -> fromMaybe (Id v t) (look v)
      SNeg x      -> SNeg (go x)
      SNot x      -> SNot (go x)
      SAbs x      -> SAbs (go x)
      SSignum x   -> SSignum (go x)
      MaxMin b xs -> MaxMin b (map go xs)
      SPlus x y   -> SPlus (go x) (go y)
      SMinus x y  -> SMinus (go x) (go y)
      STimes x y  -> STimes (go x) (go y)
      SDiv x y    -> SDiv (go x) (go y)
      SMod x y    -> SMod (go x) (go y)
      SQuot x y   -> SQuot (go x) (go y)
      SRem x y    -> SRem (go x) (go y)
      SPow x y    -> SPow (go x) (go y)
      SLogAnd x y -> SLogAnd (go x) (go y)
      SLogOr x y  -> SLogOr (go x) (go y)
      RelExp op x -> RelExp op (go x)

-- | "Smart constructor" for subtraction.  When the right operand is a
-- zero-ish constant, the left operand is returned unchanged; otherwise
-- an 'SMinus' node is built.
sminus :: ScalExp -> ScalExp -> ScalExp
sminus x y =
  case y of
    Val v | zeroIsh v -> x
    _                 -> x `SMinus` y

-- XXX: Only integer and boolean operators are translated; is that enough?
-- | The 'ScalExp' constructor corresponding to a binary operator, if
-- any.  Covered are integer addition, subtraction, multiplication,
-- 'AST.SDiv' division and exponentiation (for every type in
-- 'allIntTypes'), plus logical conjunction and disjunction.
binOpScalExp :: BinOp -> Maybe (ScalExp -> ScalExp -> ScalExp)
binOpScalExp bop = lookup bop table
  where table = concatMap intOps allIntTypes ++
                [ (LogAnd, SLogAnd), (LogOr, SLogOr) ]
        intOps t = [ (Add t, SPlus)
                   , (Sub t, SMinus)
                   , (Mul t, STimes)
                   , (AST.SDiv t, SDiv)
                   , (AST.Pow t, SPow)
                   ]

-- Free variables of a scalar expression: the names of all its 'Id'
-- leaves.
instance FreeIn ScalExp where
  freeIn se =
    case se of
      Val _       -> mempty
      Id i _      -> S.singleton i
      SNeg e      -> freeIn e
      SNot e      -> freeIn e
      SAbs e      -> freeIn e
      SSignum e   -> freeIn e
      RelExp _ e  -> freeIn e
      SPlus x y   -> both x y
      SMinus x y  -> both x y
      SPow x y    -> both x y
      STimes x y  -> both x y
      SDiv x y    -> both x y
      SMod x y    -> both x y
      SQuot x y   -> both x y
      SRem x y    -> both x y
      SLogOr x y  -> both x y
      SLogAnd x y -> both x y
      MaxMin _ es -> mconcat (map freeIn es)
    where both x y = freeIn x <> freeIn y