{-# LANGUAGE OverloadedLists #-}
module CodeGen.Parse.Cases
  ( type2real
  , type2hsreal
  , type2accreal
  , checkFunction
  ) where

import CodeGen.Prelude hiding (char)
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet as S

import CodeGen.Types hiding (prefix)

-- | Real / accreal type helpers. Each one hands a pair of
-- (Haskell FFI type, C type name) to the supplied constructor, e.g. 'CReal'
-- or 'CAccReal' as used in 'signatureAliases'.
--
-- NOTE(review): 'half' always qualifies its Haskell type with the 'TH'
-- prefix, independent of the library being generated — confirm this is the
-- intended behaviour for THC.
uchar, long, char, double, float, int, short, half
  :: (HsRep -> CRep -> x) -> x
uchar  cons = cons "Foreign.C.Types.CUChar"  "unsigned char"
long   cons = cons "Foreign.C.Types.CLong"   "long"
char   cons = cons "Foreign.C.Types.CChar"   "char"
double cons = cons "Foreign.C.Types.CDouble" "double"
float  cons = cons "Foreign.C.Types.CFloat"  "float"
int    cons = cons "Foreign.C.Types.CInt"    "int"
short  cons = cons "Foreign.C.Types.CShort"  "short"
half   cons = cons (HsRep $ prefix TH "Half") "THHalf"

-- | Fully qualify a generated type name: the result is
-- @"Torch.Types." <> lib <> ".C" <> lib <> t@, where @lib@ is @tshow lt@.
prefix :: LibType -> Text -> Text
prefix lt t = "Torch.Types." <> lib <> ".C" <> lib <> t
 where
  lib = tshow lt

-- | Haskell/C type representations for a concrete template type.
--
-- The tensor and storage reps are built from the type name plus a
-- @Tensor@ / @Storage@ suffix. The Haskell side is qualified with 'prefix'
-- for the given library; the C side is the bare suffixed name. The real and
-- accreal reps come from the per-type helpers ('uchar', 'long', ...).
-- 'GenNothing' has no concrete type and yields 'Nothing'.
signatureAliases :: LibType -> TemplateType -> Maybe (CTensor, CReal, CAccReal, CStorage)
signatureAliases lt tt = case tt of
  GenNothing -> Nothing
  GenByte    -> Just (aliases "Byte"   uchar  long)
  GenChar    -> Just (aliases "Char"   char   long)
  GenDouble  -> Just (aliases "Double" double double)
  GenFloat   -> Just (aliases "Float"  float  double)
  GenHalf    -> Just (aliases "Half"   half   float)
  GenInt     -> Just (aliases "Int"    int    long)
  GenLong    -> Just (aliases "Long"   long   long)
  GenShort   -> Just (aliases "Short"  short  long)
 where
  -- Assemble the (tensor, real, accreal, storage) quadruple for one type name.
  aliases
    :: Text
    -> ((HsRep -> CRep -> CReal) -> CReal)
    -> ((HsRep -> CRep -> CAccReal) -> CAccReal)
    -> (CTensor, CReal, CAccReal, CStorage)
  aliases tyName realOf accOf =
    ( suffixed CTensor "Tensor"
    , realOf CReal
    , accOf CAccReal
    , suffixed CStorage "Storage"
    )
   where
    -- Haskell rep goes through 'prefix'; the C rep is the bare suffixed name.
    suffixed :: (HsRep -> CRep -> c) -> Text -> c
    suffixed con sfx =
      let full = tyName <> sfx
      in  con (HsRep (prefix lt full)) (CRep full)


-- | Haskell type of the @real@ element for a template type: the Haskell
-- half of its 'CReal' rep, passed through 'stripModule'.
-- Yields the empty string for 'GenNothing'.
type2real :: LibType -> TemplateType -> Text
type2real lt = maybe "" realName . signatureAliases lt
 where
  -- GenNothing is not expected here ("TemplateType is concrete and should not
  -- have been called"), but it falls back to "" instead of failing.
  realName :: (CTensor, CReal, CAccReal, CStorage) -> Text
  realName (_, CReal hs _, _, _) = stripModule hs

-- | Type-name fragment spliced into generated function names,
-- e.g. @"Byte"@ for 'GenByte'. 'GenNothing' contributes the empty string.
type2hsreal :: TemplateType -> Text
type2hsreal GenByte    = "Byte"
type2hsreal GenChar    = "Char"
type2hsreal GenDouble  = "Double"
type2hsreal GenFloat   = "Float"
type2hsreal GenHalf    = "Half"
type2hsreal GenInt     = "Int"
type2hsreal GenLong    = "Long"
type2hsreal GenShort   = "Short"
type2hsreal GenNothing = ""


-- | Haskell type of the accumulator (@accreal@) for a template type: the
-- Haskell half of its 'CAccReal' rep, passed through 'stripModule'.
-- Yields the empty string for 'GenNothing'.
type2accreal :: LibType -> TemplateType -> Text
type2accreal lt = maybe "" accName . signatureAliases lt
 where
  -- GenNothing is not expected here ("TemplateType is concrete and should not
  -- have been called"), but it falls back to "" instead of failing.
  accName :: (CTensor, CReal, CAccReal, CStorage) -> Text
  accName (_, _, CAccReal hs _, _) = stripModule hs


-- | Math functions that exist only for a subset of template types.
-- A name absent from this map places no restriction (see 'checkMap').
-- The 'LibType' argument is currently ignored.
tensorMathCases :: LibType -> HashMap FunctionName (HashSet TemplateType)
tensorMathCases _ = M.fromList (floatOnly ++ special)
 where
  floatingTypes :: HashSet TemplateType
  floatingTypes = [GenFloat, GenDouble]

  -- Functions implemented only for floating-point reals.
  floatOnly :: [(FunctionName, HashSet TemplateType)]
  floatOnly = map (\fn -> (fn, floatingTypes))
    [ "sigmoid", "log", "lgamma", "log1p", "exp", "erf", "erfinv"
    , "cos", "acos", "cosh", "sin", "asin", "sinh"
    , "tan", "atan", "atan2", "tanh"
    , "pow", "tpow", "sqrt", "rsqrt"
    , "ceil", "floor", "round", "trunc", "frac"
    , "lerp", "mean", "std", "var", "norm", "renorm", "dist"
    , "histc", "bhistc"
    , "meanall", "varall", "stdall", "normall"
    , "linspace", "logspace", "rand", "randn"
    -- cinv doesn't seem to be excluded by the preprocessor, yet is not
    -- implemented for Int. TODO - file issue report?
    , "cinv"
    ]

  -- Functions with an irregular set of supported types.
  special :: [(FunctionName, HashSet TemplateType)]
  special =
    [ ("abs",        [GenShort, GenInt, GenLong, GenFloat, GenDouble])
    , ("logicalall", [GenByte])
    , ("logicalany", [GenByte])
    , ("neg",        [GenFloat, GenDouble, GenLong, GenShort, GenInt])
    ]

-- | Random-sampling functions that exist only for a subset of template types.
-- The 'LibType' argument is currently ignored.
tensorRandomCases :: LibType -> HashMap FunctionName (HashSet TemplateType)
tensorRandomCases _ = M.fromList (floatOnly ++ byteOnly ++ disabled)
 where
  floatOnly :: [(FunctionName, HashSet TemplateType)]
  floatOnly =
    [ (fn, [GenFloat, GenDouble])
    | fn <-
        [ "uniform", "normal", "normal_means", "normal_stddevs"
        , "normal_means_stddevs", "exponential", "standard_gamma"
        , "digamma", "trigamma", "polygamma", "expm1", "dirichlet_grad"
        , "cauchy", "logNormal", "multinomial"
        , "multinomialAliasSetup", "multinomialAliasDraw"
        ]
    ]

  byteOnly :: [(FunctionName, HashSet TemplateType)]
  byteOnly = [ (fn, [GenByte]) | fn <- ["getRNGState", "setRNGState"] ]

  -- bernoulli_Tensor keeps appearing but isn't in TH, so it is mapped to an
  -- empty set (rejected for every type). TODO: find out what is happening.
  disabled :: [(FunctionName, HashSet TemplateType)]
  disabled = [("bernoulli_Tensor", [])]


-- | LAPACK functions, all restricted to floating-point template types.
-- The 'LibType' argument is currently ignored.
--
-- TODO: check lapack bindings - not obvious from source, but there are
-- problems loading the shared library with these functions for Byte.
tensorLapackCases :: LibType -> HashMap FunctionName (HashSet TemplateType)
tensorLapackCases _ =
  M.fromList [ (fn, [GenFloat, GenDouble]) | fn <- lapackFns ]
 where
  lapackFns :: [FunctionName]
  lapackFns =
    [ "gesv", "trtrs", "gels", "syev", "geev", "gesvd", "gesvd2"
    , "getrf", "getrs", "getri"
    , "potrf", "potrs", "potri"
    , "qr", "geqrf", "orgqr", "ormqr", "pstrf"
    , "btrifact", "btrisolve"
    ]

-- | Storage functions with restricted template types. @elementSize@ maps to
-- an empty set, so 'checkMap' rejects it for every type.
-- The 'LibType' argument is ignored.
storageCases :: LibType -> HashMap FunctionName (HashSet TemplateType)
storageCases _ = M.singleton "elementSize" S.empty

-- | Storage-copy restrictions. Only THC has one: @copyCudaHalf@ is limited
-- to 'GenHalf'. Every other library gets an empty table.
storageCopyCases :: LibType -> HashMap FunctionName (HashSet TemplateType)
storageCopyCases = \case
  THC -> M.singleton "copyCudaHalf" (S.singleton GenHalf)
  _   -> M.empty

-- | BLAS restrictions. Only THC restricts these functions, to Float and
-- Double. Every other library gets an empty table.
tensorBlasCases :: LibType -> HashMap FunctionName (HashSet TemplateType)
tensorBlasCases = \case
  THC -> M.fromList [ (fn, S.fromList [GenFloat, GenDouble]) | fn <- blasFns ]
  _   -> M.empty
 where
  blasFns :: [FunctionName]
  blasFns = ["dot", "addmv", "addmm", "addr", "addbmm", "baddbmm"]



-- | Membership check against 'tensorMathCases'.
checkMath :: LibType -> TemplateType -> FunctionName -> Bool
checkMath lt tt fn = checkMap tensorMathCases lt tt fn

-- | Membership check against 'tensorRandomCases'.
checkRandom :: LibType -> TemplateType -> FunctionName -> Bool
checkRandom lt tt fn = checkMap tensorRandomCases lt tt fn

-- | Membership check against 'tensorLapackCases'.
checkLapack :: LibType -> TemplateType -> FunctionName -> Bool
checkLapack lt tt fn = checkMap tensorLapackCases lt tt fn

-- | Membership check against 'storageCases'.
checkStorage :: LibType -> TemplateType -> FunctionName -> Bool
checkStorage lt tt fn = checkMap storageCases lt tt fn

-- | Membership check against 'storageCopyCases'.
checkStorageCopy :: LibType -> TemplateType -> FunctionName -> Bool
checkStorageCopy lt tt fn = checkMap storageCopyCases lt tt fn

-- | Membership check against 'tensorBlasCases'.
checkTensorBlasCases :: LibType -> TemplateType -> FunctionName -> Bool
checkTensorBlasCases lt tt fn = checkMap tensorBlasCases lt tt fn

-- | Check a function name against a per-library case table.
--
-- Returns 'True' when the table has no entry for the name (no restriction
-- recorded). Otherwise returns whether the template type is a member of the
-- recorded set; an empty set therefore rejects every type.
checkMap
  :: (LibType -> HashMap FunctionName (HashSet TemplateType))
  -> LibType
  -> TemplateType
  -> FunctionName
  -> Bool
checkMap cases lt tt n = maybe True (S.member tt) (M.lookup n (cases lt))


-- | Decide whether a binding should be generated for function @fn@ at
-- template type @tt@ for library @lt@.
--
-- Every case table is consulted in turn. The function is rejected as soon as
-- any table lists it without @tt@.
--
-- Warning: a name that appears in none of the tables is accepted by default,
-- so a misspelled or non-existent function yields 'True'.
--
-- TODO: make this safer.
-- (stites): to make this safer I think we need to invert these maps so that we
--           are given function names instead of doing membership checks.
checkFunction :: LibType -> TemplateType -> FunctionName -> Bool
checkFunction lt tt fn = all (\check -> check lt tt fn) checks
 where
  checks :: [LibType -> TemplateType -> FunctionName -> Bool]
  checks =
    [ checkMath
    , checkRandom
    , checkLapack
    , checkStorage
    , checkStorageCopy
    , checkTensorBlasCases
    ]

-- | Ad-hoc smoke check: for a few (template type, function) pairs under TH,
-- print the 'checkFunction' verdict, one 'Bool' per line.
test :: IO ()
test = mapM_ (print . uncurry (checkFunction TH)) samples
 where
  samples :: [(TemplateType, FunctionName)]
  samples =
    [ (GenByte,  "logicalany")
    , (GenFloat, "logicalany")
    , (GenByte,  "multinomial")
    , (GenFloat, "multinomial")
    ]