{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
module Futhark.Construct
  ( letSubExp
  , letSubExps
  , letExp
  , letExps
  , letTupExp
  , letTupExp'
  , letInPlace

  , eSubExp
  , eIf
  , eIf'
  , eBinOp
  , eCmpOp
  , eConvOp
  , eNegate
  , eNot
  , eAbs
  , eSignum
  , eCopy
  , eAssert
  , eBody
  , eLambda
  , eDivRoundingUp
  , eRoundToMultipleOf
  , eSliceArray
  , eSplitArray

  , eWriteArray

  , asIntZ, asIntS

  , resultBody
  , resultBodyM
  , insertStmsM
  , mapResult

  , foldBinOp
  , binOpLambda
  , cmpOpLambda
  , fullSlice
  , fullSliceNum
  , isFullSlice
  , ifCommon

  , module Futhark.Binder

  -- * Result types
  , instantiateShapes
  , instantiateShapes'
  , instantiateShapesFromIdentList
  , instantiateExtTypes
  , instantiateIdents
  , removeExistentials

  -- * Convenience
  , simpleMkLetNames

  , ToExp(..)
  )
where

import qualified Data.Map.Strict as M
import Data.Loc (SrcLoc)
import Data.List
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer

import Futhark.Representation.AST
import Futhark.MonadFreshNames
import Futhark.Binder
import Futhark.Util

letSubExp :: MonadBinder m =>
             String -> Exp (Lore m) -> m SubExp
letSubExp _ (BasicOp (SubExp se)) = return se
letSubExp desc e = Var <$> letExp desc e

letExp :: MonadBinder m =>
          String -> Exp (Lore m) -> m VName
letExp _ (BasicOp (SubExp (Var v))) =
  return v
letExp desc e = do
  n <- length <$> expExtType e
  vs <- replicateM n $ newVName desc
  idents <- letBindNames vs e
  case idents of
    [ident] -> return $ identName ident
    _       -> fail $ "letExp: tuple-typed expression given:\n" ++ pretty e

letInPlace :: MonadBinder m =>
              String -> VName -> Slice SubExp -> Exp (Lore m)
           -> m VName
letInPlace desc src slice e = do
  tmp <- letSubExp (desc ++ "_tmp") e
  letExp desc $ BasicOp $ Update src slice tmp

letSubExps :: MonadBinder m =>
              String -> [Exp (Lore m)] -> m [SubExp]
letSubExps desc = mapM $ letSubExp desc

letExps :: MonadBinder m =>
           String -> [Exp (Lore m)] -> m [VName]
letExps desc = mapM $ letExp desc

letTupExp :: (MonadBinder m) =>
             String -> Exp (Lore m)
          -> m [VName]
letTupExp _ (BasicOp (SubExp (Var v))) =
  return [v]
letTupExp name e = do
  numValues <- length <$> expExtType e
  names <- replicateM numValues $ newVName name
  map identName <$> letBindNames names e

letTupExp' :: (MonadBinder m) =>
              String -> Exp (Lore m)
           -> m [SubExp]
letTupExp' _ (BasicOp (SubExp se)) = return [se]
letTupExp' name ses = map Var <$> letTupExp name ses

eSubExp :: MonadBinder m =>
           SubExp -> m (Exp (Lore m))
eSubExp = pure . BasicOp . SubExp

eIf :: (MonadBinder m, BranchType (Lore m) ~ ExtType) =>
       m (Exp (Lore m)) -> m (Body (Lore m)) -> m (Body (Lore m))
    -> m (Exp (Lore m))
eIf ce te fe = eIf' ce te fe IfNormal

-- | As 'eIf', but an 'IfSort' can be given.
eIf' :: (MonadBinder m, BranchType (Lore m) ~ ExtType) =>
        m (Exp (Lore m)) -> m (Body (Lore m)) -> m (Body (Lore m))
     -> IfSort
     -> m (Exp (Lore m))
eIf' ce te fe if_sort = do
  ce' <- letSubExp "cond" =<< ce
  te' <- insertStmsM te
  fe' <- insertStmsM fe
  -- We need to construct the context.
  ts <- generaliseExtTypes <$> bodyExtType te' <*> bodyExtType fe'
  te'' <- addContextForBranch ts te'
  fe'' <- addContextForBranch ts fe'
  return $ If ce' te'' fe'' $ IfAttr ts if_sort
  where addContextForBranch ts (Body _ stms val_res) = do
          body_ts <- extendedScope (traverse subExpType val_res) stmsscope
          let ctx_res = map snd $ sortOn fst $
                        M.toList $ shapeExtMapping ts body_ts
          mkBodyM stms $ ctx_res++val_res
            where stmsscope = scopeOf stms

eBinOp :: MonadBinder m =>
          BinOp -> m (Exp (Lore m)) -> m (Exp (Lore m))
       -> m (Exp (Lore m))
eBinOp op x y = do
  x' <- letSubExp "x" =<< x
  y' <- letSubExp "y" =<< y
  return $ BasicOp $ BinOp op x' y'

eCmpOp :: MonadBinder m =>
          CmpOp -> m (Exp (Lore m)) -> m (Exp (Lore m))
       -> m (Exp (Lore m))
eCmpOp op x y = do
  x' <- letSubExp "x" =<< x
  y' <- letSubExp "y" =<< y
  return $ BasicOp $ CmpOp op x' y'

eConvOp :: MonadBinder m =>
           ConvOp -> m (Exp (Lore m))
        -> m (Exp (Lore m))
eConvOp op x = do
  x' <- letSubExp "x" =<< x
  return $ BasicOp $ ConvOp op x'

eNegate :: MonadBinder m =>
           m (Exp (Lore m)) -> m (Exp (Lore m))
eNegate em = do
  e <- em
  e' <- letSubExp "negate_arg" e
  t <- subExpType e'
  case t of
    Prim (IntType int_t) ->
      return $ BasicOp $
      BinOp (Sub int_t) (intConst int_t 0) e'
    Prim (FloatType float_t) ->
      return $ BasicOp $
      BinOp (FSub float_t) (floatConst float_t 0) e'
    _ ->
      fail $ "eNegate: operand " ++ pretty e ++ " has invalid type."

eNot :: MonadBinder m =>
        m (Exp (Lore m)) -> m (Exp (Lore m))
eNot e = BasicOp . UnOp Not <$> (letSubExp "not_arg" =<< e)

eAbs :: MonadBinder m =>
        m (Exp (Lore m)) -> m (Exp (Lore m))
eAbs em = do
  e <- em
  e' <- letSubExp "abs_arg" e
  t <- subExpType e'
  case t of
    Prim (IntType int_t) ->
      return $ BasicOp $ UnOp (Abs int_t) e'
    Prim (FloatType float_t) ->
      return $ BasicOp $ UnOp (FAbs float_t) e'
    _ ->
      fail $ "eAbs: operand " ++ pretty e ++ " has invalid type."

eSignum :: MonadBinder m =>
        m (Exp (Lore m)) -> m (Exp (Lore m))
eSignum em = do
  e <- em
  e' <- letSubExp "signum_arg" e
  t <- subExpType e'
  case t of
    Prim (IntType int_t) ->
      return $ BasicOp $ UnOp (SSignum int_t) e'
    _ ->
      fail $ "eSignum: operand " ++ pretty e ++ " has invalid type."

eCopy :: MonadBinder m =>
         m (Exp (Lore m)) -> m (Exp (Lore m))
eCopy e = BasicOp . Copy <$> (letExp "copy_arg" =<< e)

eAssert :: MonadBinder m =>
         m (Exp (Lore m)) -> ErrorMsg SubExp -> SrcLoc -> m (Exp (Lore m))
eAssert e msg loc = do e' <- letSubExp "assert_arg" =<< e
                       return $ BasicOp $ Assert e' msg (loc, mempty)

eBody :: (MonadBinder m) =>
         [m (Exp (Lore m))]
      -> m (Body (Lore m))
eBody es = insertStmsM $ do
             es' <- sequence es
             xs <- mapM (letTupExp "x") es'
             mkBodyM mempty $ map Var $ concat xs

eLambda :: MonadBinder m =>
           Lambda (Lore m) -> [m (Exp (Lore m))] -> m [SubExp]
eLambda lam args = do zipWithM_ bindParam (lambdaParams lam) args
                      bodyBind $ lambdaBody lam
  where bindParam param arg = letBindNames_ [paramName param] =<< arg

-- | Note: unsigned division.
eDivRoundingUp :: MonadBinder m =>
                  IntType -> m (Exp (Lore m)) -> m (Exp (Lore m)) -> m (Exp (Lore m))
eDivRoundingUp t x y =
  eBinOp (SQuot t) (eBinOp (Add t) x (eBinOp (Sub t) y (eSubExp one))) y
  where one = intConst t 1

eRoundToMultipleOf :: MonadBinder m =>
                      IntType -> m (Exp (Lore m)) -> m (Exp (Lore m)) -> m (Exp (Lore m))
eRoundToMultipleOf t x d =
  ePlus x (eMod (eMinus d (eMod x d)) d)
  where eMod = eBinOp (SMod t)
        eMinus = eBinOp (Sub t)
        ePlus = eBinOp (Add t)

-- | Construct an 'Index' expressions that slices an array with unit stride.
eSliceArray :: MonadBinder m =>
               Int -> VName -> m (Exp (Lore m)) -> m (Exp (Lore m))
            -> m (Exp (Lore m))
eSliceArray d arr i n = do
  arr_t <- lookupType arr
  let skips = map (slice (constant (0::Int32))) $ take d $ arrayDims arr_t
  i' <- letSubExp "slice_i" =<< i
  n' <- letSubExp "slice_n" =<< n
  return $ BasicOp $ Index arr $ fullSlice arr_t $ skips ++ [slice i' n']
  where slice j m = DimSlice j m (constant (1::Int32))

-- | Construct an 'Index' expressions that splits an array in different parts along the outer dimension.
eSplitArray :: MonadBinder m =>
               VName -> [m (Exp (Lore m))] -> m [Exp (Lore m)]
eSplitArray arr sizes = do
  sizes' <- mapM (letSubExp "split_size") =<< sequence sizes
  -- Compute the starting offset for each slice.
  (_, offsets) <- mapAccumLM increase (intConst Int32 0) sizes'
  zipWithM (eSliceArray 0 arr) (map eSubExp offsets) (map eSubExp sizes')
  where increase offset size = do
          offset' <- letSubExp "offset" $ BasicOp $ BinOp (Add Int32) offset size
          return (offset', offset)

-- | Write to an index of the array, if within bounds.  Otherwise,
-- nothing.  Produces the updated array.
eWriteArray :: (MonadBinder m, BranchType (Lore m) ~ ExtType) =>
               VName -> [m (Exp (Lore m))] -> m (Exp (Lore m))
            -> m (Exp (Lore m))
eWriteArray arr is v = do
  arr_t <- lookupType arr
  let ws = arrayDims arr_t
  is' <- mapM (letSubExp "write_i") =<< sequence is
  v' <- letSubExp "write_v" =<< v
  let checkDim w i = do
        less_than_zero <- letSubExp "less_than_zero" $
          BasicOp $ CmpOp (CmpSlt Int32) i (constant (0::Int32))
        greater_than_size <- letSubExp "greater_than_size" $
          BasicOp $ CmpOp (CmpSle Int32) w i
        letSubExp "outside_bounds_dim" $
          BasicOp $ BinOp LogOr less_than_zero greater_than_size

  outside_bounds <-
    letSubExp "outside_bounds" =<<
    foldBinOp LogOr (constant False) =<<
    zipWithM checkDim ws is'

  outside_bounds_branch <- insertStmsM $ resultBodyM [Var arr]

  in_bounds_branch <- insertStmsM $ do
    res <- letInPlace "write_out_inside_bounds" arr
           (fullSlice arr_t (map DimFix is')) $ BasicOp $ SubExp v'
    resultBodyM [Var res]

  return $
    If outside_bounds outside_bounds_branch in_bounds_branch $
    ifCommon [arr_t]

-- | Sign-extend to the given integer type.
asIntS :: MonadBinder m => IntType -> SubExp -> m SubExp
asIntS = asInt SExt

-- | Zero-extend to the given integer type.
asIntZ :: MonadBinder m => IntType -> SubExp -> m SubExp
asIntZ = asInt ZExt

asInt :: MonadBinder m =>
         (IntType -> IntType -> ConvOp) -> IntType -> SubExp -> m SubExp
asInt ext to_it e = do
  e_t <- subExpType e
  case e_t of
    Prim (IntType from_it)
      | to_it == from_it -> return e
      | otherwise -> letSubExp s $ BasicOp $ ConvOp (ext from_it to_it) e
    _ -> fail "asInt: wrong type"
  where s = case e of Var v -> baseString v
                      _     -> "to_" ++ pretty to_it


-- | Apply a binary operator to several subexpressions.  A left-fold.
foldBinOp :: MonadBinder m =>
             BinOp -> SubExp -> [SubExp] -> m (Exp (Lore m))
foldBinOp _ ne [] =
  return $ BasicOp $ SubExp ne
foldBinOp bop ne (e:es) =
  eBinOp bop (pure $ BasicOp $ SubExp e) (foldBinOp bop ne es)

-- | Create a two-parameter lambda whose body applies the given binary
-- operation to its arguments.  It is assumed that both argument and
-- result types are the same.  (This assumption should be fixed at
-- some point.)
binOpLambda :: (MonadBinder m, Bindable (Lore m)) =>
               BinOp -> PrimType -> m (Lambda (Lore m))
binOpLambda bop t = binLambda (BinOp bop) t t

-- | As 'binOpLambda', but for 'CmpOp's.
cmpOpLambda :: (MonadBinder m, Bindable (Lore m)) =>
               CmpOp -> PrimType -> m (Lambda (Lore m))
cmpOpLambda cop t = binLambda (CmpOp cop) t Bool

binLambda :: (MonadBinder m, Bindable (Lore m)) =>
             (SubExp -> SubExp -> BasicOp (Lore m)) -> PrimType -> PrimType
          -> m (Lambda (Lore m))
binLambda bop arg_t ret_t = do
  x   <- newVName "x"
  y   <- newVName "y"
  body <- insertStmsM $ do
    res <- letSubExp "res" $ BasicOp $ bop (Var x) (Var y)
    return $ resultBody [res]
  return Lambda {
             lambdaParams     = [Param x (Prim arg_t),
                                 Param y (Prim arg_t)]
           , lambdaReturnType = [Prim ret_t]
           , lambdaBody       = body
           }

-- | @fullSlice t slice@ returns @slice@, but with 'DimSlice's of
-- entire dimensions appended to the full dimensionality of @t@.  This
-- function is used to turn incomplete indexing complete, as required
-- by 'Index'.
fullSlice :: Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice t slice =
  slice ++
  map (\d -> DimSlice (constant (0::Int32)) d (constant (1::Int32)))
  (drop (length slice) $ arrayDims t)

-- | Like 'fullSlice', but the dimensions are simply numeric.
fullSliceNum :: Num d => [d] -> [DimIndex d] -> Slice d
fullSliceNum dims slice =
  slice ++ map (\d -> DimSlice 0 d 1) (drop (length slice) dims)

-- | Does the slice describe the full size of the array?  The most
-- obvious such slice is one that 'DimSlice's the full span of every
-- dimension, but also one that fixes all unit dimensions.
isFullSlice :: Shape -> Slice SubExp -> Bool
isFullSlice shape slice = and $ zipWith allOfIt (shapeDims shape) slice
  where allOfIt (Constant v) DimFix{} = oneIsh v
        allOfIt d (DimSlice _ n _) = d == n
        allOfIt _ _ = False

ifCommon :: [Type] -> IfAttr ExtType
ifCommon ts = IfAttr (staticShapes ts) IfNormal

-- | Conveniently construct a body that contains no bindings.
resultBody :: Bindable lore => [SubExp] -> Body lore
resultBody = mkBody mempty

-- | Conveniently construct a body that contains no bindings - but
-- this time, monadically!
resultBodyM :: MonadBinder m =>
               [SubExp]
            -> m (Body (Lore m))
resultBodyM = mkBodyM mempty

-- | Evaluate the action, producing a body, then wrap it in all the
-- bindings it created using 'addStm'.
insertStmsM :: (MonadBinder m) =>
               m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM m = do
  (Body _ bnds res, otherbnds) <- collectStms m
  mkBodyM (otherbnds <> bnds) res

-- | Change that result where evaluation of the body would stop.  Also
-- change type annotations at branches.
mapResult :: Bindable lore =>
             (Result -> Body lore) -> Body lore -> Body lore
mapResult f (Body _ bnds res) =
  let Body _ bnds2 newres = f res
  in mkBody (bnds<>bnds2) newres

-- | Instantiate all existential parts dimensions of the given
-- type, using a monadic action to create the necessary 'SubExp's.
-- You should call this function within some monad that allows you to
-- collect the actions performed (say, 'Writer').
instantiateShapes :: Monad m =>
                     (Int -> m SubExp)
                  -> [TypeBase ExtShape u]
                  -> m [TypeBase Shape u]
instantiateShapes f ts = evalStateT (mapM instantiate ts) M.empty
  where instantiate t = do
          shape <- mapM instantiate' $ shapeDims $ arrayShape t
          return $ t `setArrayShape` Shape shape
        instantiate' (Ext x) = do
          m <- get
          case M.lookup x m of
            Just se -> return se
            Nothing -> do se <- lift $ f x
                          put $ M.insert x se m
                          return se
        instantiate' (Free se) = return se

instantiateShapes' :: MonadFreshNames m =>
                      [TypeBase ExtShape u]
                   -> m ([TypeBase Shape u], [Ident])
instantiateShapes' ts =
  runWriterT $ instantiateShapes instantiate ts
  where instantiate _ = do v <- lift $ newIdent "size" $ Prim int32
                           tell [v]
                           return $ Var $ identName v

instantiateShapesFromIdentList :: [Ident] -> [ExtType] -> [Type]
instantiateShapesFromIdentList idents ts =
  evalState (instantiateShapes instantiate ts) idents
  where instantiate _ = do
          idents' <- get
          case idents' of
            [] -> fail "instantiateShapesFromIdentList: insufficiently sized context"
            ident:idents'' -> do put idents''
                                 return $ Var $ identName ident

instantiateExtTypes :: [VName] -> [ExtType] -> [Ident]
instantiateExtTypes names rt =
  let (shapenames,valnames) = splitAt (shapeContextSize rt) names
      shapes = [ Ident name (Prim int32) | name <- shapenames ]
      valts  = instantiateShapesFromIdentList shapes rt
      vals   = [ Ident name t | (name,t) <- zip valnames valts ]
  in shapes ++ vals

instantiateIdents :: [VName] -> [ExtType]
                  -> Maybe ([Ident], [Ident])
instantiateIdents names ts
  | let n = shapeContextSize ts,
    n + length ts == length names = do
    let (context, vals) = splitAt n names
        nextShape _ = do
          (context', remaining) <- get
          case remaining of []   -> lift Nothing
                            x:xs -> do let ident = Ident x (Prim int32)
                                       put (context'++[ident], xs)
                                       return $ Var x
    (ts', (context', _)) <-
      runStateT (instantiateShapes nextShape ts) ([],context)
    return (context', zipWith Ident vals ts')
  | otherwise = Nothing

removeExistentials :: ExtType -> Type -> Type
removeExistentials t1 t2 =
  t1 `setArrayDims`
  zipWith nonExistential
  (shapeDims $ arrayShape t1)
  (arrayDims t2)
  where nonExistential (Ext _)    dim = dim
        nonExistential (Free dim) _   = dim

-- | Can be used as the definition of 'mkLetNames' for a 'Bindable'
-- instance for simple representations.
simpleMkLetNames :: (ExpAttr lore ~ (), LetAttr lore ~ Type,
                     MonadFreshNames m, TypedOp (Op lore), HasScope lore m) =>
                    [VName] -> Exp lore -> m (Stm lore)
simpleMkLetNames names e = do
  et <- expExtType e
  (ts, shapes) <- instantiateShapes' et
  let shapeElems = [ PatElem shape shapet | Ident shape shapet <- shapes ]
  let valElems = zipWith PatElem names ts
  return $ Let (Pattern shapeElems valElems) (StmAux mempty ()) e

-- | Instances of this class can be converted to Futhark expressions
-- within a 'MonadBinder'.
class ToExp a where
  toExp :: MonadBinder m => a -> m (Exp (Lore m))

instance ToExp SubExp where
  toExp = return . BasicOp . SubExp

instance ToExp VName where
  toExp = return . BasicOp . SubExp . Var