module CodeGen.Render.Function
  ( renderSig
  , haskellSig
  , mkHsname
  , SigType(..)
  ) where

import CodeGen.Prelude
import CodeGen.Types
import CodeGen.Parse.Cases (type2hsreal)
import Control.Arrow ((&&&))
import qualified CodeGen.Render.C as C (render)
import qualified CodeGen.Render.Haskell as Hs (render)
import qualified Data.Char as Ch (toUpper)
import qualified Data.Text as T

data SigType
  = IsFun
  | IsFunPtr
  deriving (Eq, Ord, Show)

ffiPrefix :: SigType -> Text
ffiPrefix = \case
  IsFun    -> "c_"
  IsFunPtr -> "p_"

isPtr :: SigType -> Bool
isPtr f = f == IsFunPtr

comment :: LibType -> SigType -> Text -> [Arg] -> Parsable -> Text
comment lt t hsname args retType = T.intercalate " "
  $  [ "-- |" , hsname , ":", if isPtr t then "Pointer to function :" else "" ]
  <> map argName args
  <> (if null args then [] else ["->"])
  <> [C.render retType]

foreignCall :: Text -> FilePath -> Text
foreignCall cname headerFile = T.intercalate "\""
  [ "foreign import ccall "
  , T.pack headerFile <> cname
  , ""
  ]

haskellSig :: LibType -> Text -> SigType -> TemplateType -> [Arg] -> Parsable -> Text
haskellSig lt hsname st tt args retType = T.intercalate ""
  [ hsname
  , " :: "
  , if isPtr st then "FunPtr (" else ""
  , T.intercalate " -> " typeSignature, retArrow
  , if isPtr st then ")" else ""
  ]
 where
  typeSignature :: [Text]
  typeSignature = case args of
    [Arg (CType CVoid) _] -> []
    args' -> mapMaybe (Hs.render FunctionParam tt . argType) args'

  retArrow :: Text
  retArrow = case Hs.render ReturnValue tt retType of
    Nothing  -> ""
    Just ret -> if null typeSignature then ret else (" -> " <> ret)


mkCname
  :: SigType -> LibType -> ModuleSuffix -> TemplateType -> CodeGenType -> Maybe (LibType, Text) -> Text -> Text
mkCname st lt ms tt cgt mpref funname
  = (if isPtr st then " &" else " ")
  <> identifier
  <> funname
 where
  identifier :: Text
  identifier = case cgt of
    ConcreteFiles -> ""
    GenericFiles ->
      case lt of
        TH     -> "TH"        <> type2hsreal tt <> textSuffix ms <> "_"
        THNN   -> "THNN_"     <> type2hsreal tt
        THCUNN -> "THNN_Cuda" <> type2hsreal tt
        THC -> case mpref of
          -- THC is the only library that has this
          Nothing       -> prefix lt  (isTHCTensor lt ) <> type2hsreal tt <> textSuffix ms <> "_"
          Just (lt', t) -> prefix lt' (isTHCTensor lt') <> type2hsreal tt <> t <> "_"
         where
          isTHCTensor :: LibType -> Bool
          isTHCTensor lt =
            (  textSuffix ms == "Tensor"
            || textSuffix ms == "Storage"
            || textSuffix ms == "TensorMath")

-- | render a haskell function name.
mkHsname :: LibType -> SigType -> Maybe (LibType, Text) -> Text -> Text
mkHsname lt st mpref funname =
  case lt of
    THCUNN -> ffiPrefix st <> funname
    _ -> case mpref of
      Nothing       -> ffiPrefix st <> funname
      Just (lt', _) -> ffiPrefix st <> (if lt' == lt then funname else newName lt')
 where
   newName :: LibType -> Text
   newName lt' = (T.toLower (tshow lt') <>) $ uncurry T.cons $ ((Ch.toUpper . T.head) &&& T.tail) funname


-- | Render a single function signature.
renderSig
  :: SigType
  -> LibType
  -> CodeGenType
  -> FilePath
  -> TemplateType
  -> ModuleSuffix
  -> FileSuffix
  -> (Maybe (LibType, Text), Text, Parsable, [Arg])
  -> Text
renderSig st lt cgt headerFile tt ms fs (mpref, name, retType, args) =
  T.intercalate "\n"
    [ comment lt st hsname args retType
    , foreignCall cname headerFile
    , implementation
    ]
 where
  cname, hsname :: Text
  cname = mkCname st lt ms tt cgt mpref name
  hsname = mkHsname lt st mpref name

  implementation :: Text
  implementation =
    case (lt, cgt, fs, st) of
      -- NOTE: TH and THC functions differ in the THC State. TH does have a concept of THState, which
      -- is unused. Here we render some function aliases which will allow us to maintain unified
      -- backpack signatures.
      --
      -- NOTE2: In the event that we render generic functions from the TH
      -- library which _does not include THTensorRandom_, we want to post-fix these names with a @_@
      -- and use the alias to match the backpack signatures.
      (TH, GenericFiles, _, IsFun) ->
        T.intercalate "\n"
          [ "  " <> (haskellSig lt (mkAliasRefName hsname) st tt args retType)
          , ""
          , "-- | alias of " <> mkAliasRefName hsname <> " with unused argument (for CTHState) to unify backpack signatures."
          , haskellSig lt hsname st tt thArgs retType
          , hsname <> " = const " <> mkAliasRefName hsname
          ]
        where
          -- | TH only (and even then only generic TH files).
          --
          -- 'reference' implying "original haskell function" and alias implying
          -- "backpack-compatible function" as well as "c-native function"
          mkAliasRefName :: Text -> Text
          mkAliasRefName = (<> "_")

          thArgs :: [Arg]
          thArgs =
            if length args == 1 && argType (head args) == (CType CVoid)
            then [statePtr]
            else statePtr:args

          statePtr :: Arg
          statePtr = Arg (Ptr (TenType (Pair (State, TH)))) "state"

      _ -> "  " <> (haskellSig lt hsname st tt args retType)