{-# LANGUAGE FlexibleContexts #-}
-- | Variation of "Futhark.CodeGen.ImpCode" that contains the notion
-- of a kernel invocation.
module Futhark.CodeGen.ImpCode.Kernels
  ( Program
  , Function
  , FunctionT (Function)
  , Code
  , KernelCode
  , KernelConst (..)
  , KernelConstExp
  , HostOp (..)
  , KernelOp (..)
  , AtomicOp (..)
  , CallKernel (..)
  , MapKernel (..)
  , Kernel (..)
  , LocalMemoryUse
  , KernelUse (..)
  , module Futhark.CodeGen.ImpCode
  , module Futhark.Representation.Kernels.Sizes
  -- * Utility functions
  , getKernels
  )
  where

import Control.Monad.Writer
import Data.List
import qualified Data.Set as S

import Futhark.CodeGen.ImpCode hiding (Function, Code)
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.Representation.Kernels.Sizes
import Futhark.Representation.AST.Attributes.Names
import Futhark.Representation.AST.Pretty ()
import Futhark.Util.Pretty

type Program = Functions HostOp
type Function = Imp.Function HostOp
-- | Host-level code that can call kernels.
type Code = Imp.Code CallKernel
-- | Code inside a kernel.
type KernelCode = Imp.Code KernelOp

-- | A run-time constant related to kernels.
newtype KernelConst = SizeConst VName
                    deriving (Eq, Ord, Show)

-- | An expression whose variables are kernel constants.
type KernelConstExp = PrimExp KernelConst

data HostOp = CallKernel CallKernel
            | GetSize VName VName SizeClass
            | CmpSizeLe VName VName SizeClass Imp.Exp
            | GetSizeMax VName SizeClass
            deriving (Show)

data CallKernel = Map MapKernel
                | AnyKernel Kernel
                | MapTranspose PrimType VName Exp VName Exp Exp Exp Exp Exp Exp
            deriving (Show)

-- | A generic kernel containing arbitrary kernel code.
data MapKernel = MapKernel { mapKernelThreadNum :: VName
                             -- ^ Stm position - also serves as a unique
                             -- name for the kernel.
                           , mapKernelDesc :: String
                           -- ^ Used to name the kernel for readability.
                           , mapKernelBody :: Imp.Code KernelOp
                           , mapKernelUses :: [KernelUse]
                           , mapKernelNumGroups :: DimSize
                           , mapKernelGroupSize :: DimSize
                           , mapKernelSize :: Imp.Exp
                           -- ^ Do not actually execute threads past this.
                           }
                     deriving (Show)

data Kernel = Kernel
              { kernelBody :: Imp.Code KernelOp
              , kernelLocalMemory :: [LocalMemoryUse]
              -- ^ The local memory used by this kernel.

              , kernelUses :: [KernelUse]
                -- ^ The host variables referenced by the kernel.

              , kernelNumGroups :: DimSize
              , kernelGroupSize :: DimSize
              , kernelName :: VName
                -- ^ Unique name for the kernel.
              , kernelDesc :: String
               -- ^ A short descriptive name - should be
               -- alphanumeric and without spaces.
              }
            deriving (Show)

-- ^ In-kernel name and per-workgroup size in bytes.
type LocalMemoryUse = (VName, Either MemSize KernelConstExp)

data KernelUse = ScalarUse VName PrimType
               | MemoryUse VName Imp.DimSize
               | ConstUse VName KernelConstExp
                 deriving (Eq, Show)

getKernels :: Program -> [CallKernel]
getKernels = nubBy sameKernel . execWriter . traverse getFunKernels
  where getFunKernels (CallKernel kernel) =
          tell [kernel]
        getFunKernels _ =
          return ()
        sameKernel (MapTranspose bt1 _ _ _ _ _ _ _ _ _) (MapTranspose bt2 _ _ _ _ _ _ _ _ _) =
          bt1 == bt2
        sameKernel _ _ = False

instance Pretty KernelConst where
  ppr (SizeConst key) = text "get_size" <> parens (ppr key)

instance Pretty KernelUse where
  ppr (ScalarUse name t) =
    text "scalar_copy" <> parens (commasep [ppr name, ppr t])
  ppr (MemoryUse name size) =
    text "mem_copy" <> parens (commasep [ppr name, ppr size])
  ppr (ConstUse name e) =
    text "const" <> parens (commasep [ppr name, ppr e])

instance Pretty HostOp where
  ppr (GetSize dest key size_class) =
    ppr dest <+> text "<-" <+>
    text "get_size" <> parens (commasep [ppr key, ppr size_class])
  ppr (GetSizeMax dest size_class) =
    ppr dest <+> text "<-" <+> text "get_size_max" <> parens (ppr size_class)
  ppr (CmpSizeLe dest name size_class x) =
    ppr dest <+> text "<-" <+>
    text "get_size" <> parens (commasep [ppr name, ppr size_class]) <+>
    text "<" <+> ppr x
  ppr (CallKernel c) =
    ppr c

instance FreeIn HostOp where
  freeIn (CallKernel c) = freeIn c
  freeIn (CmpSizeLe dest name _ x) =
    freeIn dest <> freeIn name <> freeIn x
  freeIn (GetSizeMax dest _) =
    freeIn dest
  freeIn (GetSize dest _ _) =
    freeIn dest

instance Pretty CallKernel where
  ppr (Map k) = ppr k
  ppr (AnyKernel k) = ppr k
  ppr (MapTranspose bt dest destoffset src srcoffset num_arrays size_x size_y in_size out_size) =
    text "mapTranspose" <>
    parens (ppr bt <> comma </>
            ppMemLoc dest destoffset <> comma </>
            ppMemLoc src srcoffset <> comma </>
            ppr num_arrays <> comma <+>
            ppr size_x <> comma <+>
            ppr size_y <> comma <+>
            ppr in_size <> comma <+>
            ppr out_size)
    where ppMemLoc base offset =
            ppr base <+> text "+" <+> ppr offset

instance FreeIn CallKernel where
  freeIn (Map k) = freeIn k
  freeIn (AnyKernel k) = freeIn k
  freeIn (MapTranspose _ dest destoffset src srcoffset num_arrays size_x size_y in_size out_size) =
    freeIn [dest, src] <> freeIn [destoffset, srcoffset] <> freeIn num_arrays <>
    freeIn [size_x, size_y] <> freeIn [in_size, out_size]

instance FreeIn Kernel where
  freeIn kernel = freeIn (kernelBody kernel) <>
                  freeIn [kernelNumGroups kernel, kernelGroupSize kernel]

instance Pretty MapKernel where
  ppr kernel =
    text "mapKernel" <+> brace
    (text "uses" <+> brace (commasep $ map ppr $ mapKernelUses kernel) </>
     text "body" <+> brace (ppr (mapKernelThreadNum kernel) <+>
                            text "<- get_thread_number()" </>
                            ppr (mapKernelBody kernel)))

instance Pretty Kernel where
  ppr kernel =
    text "kernel" <+> brace
    (text "groups" <+> brace (ppr $ kernelNumGroups kernel) </>
     text "group_size" <+> brace (ppr $ kernelGroupSize kernel) </>
     text "local_memory" <+> brace (commasep $
                                    map ppLocalMemory $
                                    kernelLocalMemory kernel) </>
     text "uses" <+> brace (commasep $ map ppr $ kernelUses kernel) </>
     text "body" <+> brace (ppr $ kernelBody kernel))
    where ppLocalMemory (name, Left size) =
            ppr name <+> parens (ppr size <+> text "bytes")
          ppLocalMemory (name, Right size) =
            ppr name <+> parens (ppr size <+> text "bytes (const)")

instance FreeIn MapKernel where
  freeIn kernel =
    mapKernelThreadNum kernel `S.delete` freeIn (mapKernelBody kernel)

data KernelOp = GetGroupId VName Int
              | GetLocalId VName Int
              | GetLocalSize VName Int
              | GetGlobalSize VName Int
              | GetGlobalId VName Int
              | GetLockstepWidth VName
              | Atomic AtomicOp
              | Barrier
              | MemFence
              deriving (Show)

-- Atomic operations return the value stored before the update.
-- This value is stored in the first VName.
data AtomicOp = AtomicAdd VName VName (Count Bytes) Exp
              | AtomicSMax VName VName (Count Bytes) Exp
              | AtomicSMin VName VName (Count Bytes) Exp
              | AtomicUMax VName VName (Count Bytes) Exp
              | AtomicUMin VName VName (Count Bytes) Exp
              | AtomicAnd VName VName (Count Bytes) Exp
              | AtomicOr VName VName (Count Bytes) Exp
              | AtomicXor VName VName (Count Bytes) Exp
              | AtomicCmpXchg VName VName (Count Bytes) Exp Exp
              | AtomicXchg VName VName (Count Bytes) Exp
              deriving (Show)

instance FreeIn AtomicOp where
  freeIn (AtomicAdd _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicSMax _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicSMin _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicUMax _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicUMin _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicAnd _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicOr _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicXor _ arr i x) = freeIn arr <> freeIn i <> freeIn x
  freeIn (AtomicCmpXchg _ arr i x y) = freeIn arr <> freeIn i <> freeIn x <> freeIn y
  freeIn (AtomicXchg _ arr i x) = freeIn arr <> freeIn i <> freeIn x

instance Pretty KernelOp where
  ppr (GetGroupId dest i) =
    ppr dest <+> text "<-" <+>
    text "get_group_id" <> parens (ppr i)
  ppr (GetLocalId dest i) =
    ppr dest <+> text "<-" <+>
    text "get_local_id" <> parens (ppr i)
  ppr (GetLocalSize dest i) =
    ppr dest <+> text "<-" <+>
    text "get_local_size" <> parens (ppr i)
  ppr (GetGlobalSize dest i) =
    ppr dest <+> text "<-" <+>
    text "get_global_size" <> parens (ppr i)
  ppr (GetGlobalId dest i) =
    ppr dest <+> text "<-" <+>
    text "get_global_id" <> parens (ppr i)
  ppr (GetLockstepWidth dest) =
    ppr dest <+> text "<-" <+>
    text "get_lockstep_width()"
  ppr Barrier =
    text "barrier()"
  ppr MemFence =
    text "mem_fence()"
  ppr (Atomic (AtomicAdd old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_add" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicSMax old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_smax" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicSMin old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_smin" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicUMax old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_umax" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicUMin old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_umin" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicAnd old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_and" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicOr old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_or" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicXor old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_xor" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])
  ppr (Atomic (AtomicCmpXchg old arr ind x y)) =
    ppr old <+> text "<-" <+> text "atomic_cmp_xchg" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x, ppr y])
  ppr (Atomic (AtomicXchg old arr ind x)) =
    ppr old <+> text "<-" <+> text "atomic_xchg" <>
    parens (commasep [ppr arr <> brackets (ppr ind), ppr x])

instance FreeIn KernelOp where
  freeIn (Atomic op) = freeIn op
  freeIn _ = mempty

brace :: Doc -> Doc
brace body = text " {" </> indent 2 body </> text "}"