{-# LANGUAGE BangPatterns,
             CPP,
             OverloadedStrings,
             DataKinds,
             FlexibleContexts,
             GADTs,
             KindSignatures,
             RankNTypes,
             ScopedTypeVariables #-}

----------------------------------------------------------------
--                                                    2016.06.23
-- |
-- Module      :  Language.Hakaru.CodeGen.Wrapper
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  zsulliva@indiana.edu
-- Stability   :  experimental
-- Portability :  GHC-only
--
--   The purpose of the wrapper is to intelligently wrap CStatements
-- into CFunctions and CPrograms to be printed by 'hkc'.
--
----------------------------------------------------------------


module Language.Hakaru.CodeGen.Wrapper
  ( wrapProgram
  , PrintConfig(..)
  ) where

import           Language.Hakaru.Syntax.ABT
import           Language.Hakaru.Syntax.AST
import           Language.Hakaru.Syntax.IClasses
import           Language.Hakaru.Syntax.TypeCheck
import           Language.Hakaru.Syntax.TypeOf (typeOf)
import           Language.Hakaru.Types.Sing
import           Language.Hakaru.CodeGen.CodeGenMonad
import           Language.Hakaru.CodeGen.Flatten
import           Language.Hakaru.CodeGen.Types
import           Language.Hakaru.CodeGen.AST
import           Language.Hakaru.Types.DataKind (Hakaru(..))

import           Control.Monad.State.Strict
import           Prelude            as P hiding (unlines,exp)


#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative
#endif


-- | wrapProgram is the top level C codegen. Depending on the type a program
--   will have a different construction. It emits, in the CodeGenMonad, a
--   standalone C file: the CPP includes chosen by 'header', followed by the
--   struct declarations and functions the program needs, and sometimes a main.
--
--   * A function-typed program becomes one C function (see
--     'flattenTopLambda'), named after the given output name, or after a
--     freshly generated identifier when no name is given.
--
--   * Any other program, when no output name is given, becomes a @main@
--     that computes and prints its value (see 'mainFunction').
--
--   * Naming the output of a program that is not a function is not supported
--     yet and raises an error.
wrapProgram
  :: TypedAST (TrivialABT Term) -- ^ Some Hakaru ABT
  -> Maybe String               -- ^ Maybe an output name
  -> PrintConfig                -- ^ how a generated main prints its result
  -> CodeGen ()
wrapProgram tast@(TypedAST typ _) mn pc =
  do mapM_ (extDeclare . CPPExt) (header typ)
     baseCG
  where baseCG = case (tast,mn) of
               ( TypedAST (SFun _ _) abt, Just name ) ->
                 do reserveName name
                    flattenTopLambda abt $ Ident name

               ( TypedAST (SFun _ _) abt, Nothing   ) ->
                 flattenTopLambda abt =<< genIdent' "fn"

               -- TODO: support this by defining a named C function that
               -- returns the value, e.g.
               --   do reserveName name
               --      defineFunction typ' (Ident name) []
               --        (putStat . CReturn . Just =<< flattenABT abt)
               ( TypedAST _ _,                  Just name ) ->
                 error $ "wrapProgram: cannot name the output \"" ++ name
                      ++ "\" of a program that is not a function"

               ( TypedAST typ'       abt, Nothing   ) ->
                 mainFunction pc typ' abt


-- | The C headers a generated file includes. Every program gets stdlib.h,
--   stdio.h and math.h; a measure additionally gets time.h, listed first.
header :: Sing (a :: Hakaru) -> [Preprocessor]
header typ = measureIncludes typ ++ fmap PPInclude ["stdlib.h", "stdio.h", "math.h"]
  where measureIncludes :: Sing (b :: Hakaru) -> [Preprocessor]
        measureIncludes (SMeasure _) = [PPInclude "time.h"]
        measureIncludes _            = []



--------------------------------------------------------------------------------
--                             A Main Function                                --
--------------------------------------------------------------------------------
{-

Create a standalone C program for a Hakaru ABT. The program also prints the
computed value to stdout; for a measure it repeatedly draws samples and prints
them in an infinite loop.

-}

mainFunction
  :: ABT Term abt
  => PrintConfig
  -> Sing (a :: Hakaru)    -- ^ type of program
  -> abt '[] (a :: Hakaru) -- ^ Hakaru ABT
  -> CodeGen ()

-- when measure, compile to a sampler
mainFunction pc typ@(SMeasure t) abt =
  let ident   = Ident "measure"
      funId   = Ident "main"
      mdataId = Ident "mdata"
  in  do reserveName "measure"
         reserveName "mdata"
         reserveName "main"

         extDeclare . mdataStruct $ t

         -- defined a measure function that returns mdata
         funCG CVoid ident [mdataPtrDeclaration t mdataId] $
           do flattenABT abt (CVar mdataId)
              putStat (CReturn Nothing)

  --        -- need to set seed?
  --        -- srand(time(NULL));

         printf pc typ (CVar ident)
         putStat . CReturn . Just $ intE 0

         !cg <- get
         extDeclare . CFunDefExt $ functionDef SInt
                                               funId
                                               []
                                               (P.reverse $ declarations cg)
                                               (P.reverse $ statements cg)
  -- where isSArray (SArray _) = True
  --       isSArray _          = False
  --       mkArrayStruct :: Sing (a :: Hakaru) -> CExtDecl
  --       mkArrayStruct (SArray t) = arrayStruct t
  --       mkArrayStruct _          = error "Not Array"
  --       getArrayType :: Sing (b :: Hakaru) -> [CTypeSpec]
  --       getArrayType (SArray t) = case buildType t of
  --                                   [] -> error "wrapper: this shouldn't happen"
  --                                   t  -> t
  --       getArrayType _          = error "Not Array"
  --       getPlateArity :: ABT Term abt => Term abt a -> abt '[] 'HNat
  --       getPlateArity (Plate :$ arity :* _ :* End) = arity
  --       getPlateArity _ = error "mainFunction not a plate"

-- just a computation
mainFunction pc typ abt =
  let resId = Ident "result"
      resE  = CVar resId
      funId = Ident "main"
  in  do reserveName "result"
         reserveName "main"

         declare typ resId
         flattenABT abt resE

         printf pc typ resE
         putStat . CReturn . Just $ intE 0

         cg <- get
         extDeclare . CFunDefExt $ functionDef SInt
                                              funId
                                              []
                                              (P.reverse $ declarations cg)
                                              (P.reverse $ statements cg)


--------------------------------------------------------------------------------
--                               Printing Values                              --
--------------------------------------------------------------------------------
{-

In HKC the PrintConfig is parsed from the command line. By default, weights
are not shown and probabilities are printed as normal real values.

-}

-- | Options controlling how a generated C program prints its result.
data PrintConfig = PrintConfig
  { showWeights   :: Bool -- ^ for measures, also print each sample's weight
  , showProbInLog :: Bool -- ^ print Prob values (and weights) in log space,
                          --   formatted as @exp(x)@, instead of as plain reals
  } deriving Show


-- | Emit C code that prints a value of the given Hakaru type.
--
--   * Measures: the 'CExpr' is the sampler function. An infinite loop calls
--     it on a local mdata struct and prints every sample whose weight is
--     positive, optionally preceded by that weight.
--   * Arrays: the elements are printed on one line between @[ @ and @]@,
--     using a sequential loop.
--   * Anything else: the value is printed once, followed by a newline.
--
--   Prob values and weights are held in log space (they are passed through
--   exp() for normal display), so unless 'showProbInLog' is set they are
--   converted before printing. This is done the same way for scalars,
--   measure samples and array elements (see 'printableArg').
printf
  :: PrintConfig
  -> Sing (a :: Hakaru) -- ^ Hakaru type to be printed
  -> CExpr              -- ^ CExpr representing value
  -> CodeGen ()

printf pc mt@(SMeasure t) sampleFunc =
  do mId <- genIdent' "m"
     declare mt mId
     let mE         = CVar mId
         getSampleS = CExpr . Just $ CCall sampleFunc [address mE]
         -- weights are kept in log space, like Prob values
         weightArgs
           | not (showWeights pc) = []
           | showProbInLog pc     = [mdataWeight mE]
           | otherwise            = [exp $ mdataWeight mE]
         printSampleE = CExpr . Just
                      $ CCall (CVar . Ident $ "printf")
                      $ [ stringE $ printfText pc mt "\n" ]
                        ++ weightArgs
                        ++ [ printableArg pc t (mdataSample mE) ]
         -- only print samples whose weight, exp(log weight), is positive
         wrapSampleFunc = CCompound
           [ CBlockStat getSampleS
           , CBlockStat $ CIf ((exp $ mdataWeight mE) .>. (floatE 0))
                              printSampleE
                              Nothing ]
     putStat $ CWhile (intE 1) wrapSampleFunc False

printf pc (SArray t) arg =
  do iterId <- genIdent' "it"
     declare SInt iterId
     let iter     = CVar iterId
         dataPtr  = CMember arg (Ident "data") True
         sizeVar  = CMember arg (Ident "size") True
         cond     = iter .<. sizeVar
         inc      = CUnary CPostIncOp iter
         currInd  = indirect (dataPtr .+. iter)
         -- convert each element exactly like a scalar of type t, so the
         -- printed value agrees with the format from 'printfText'
         loopBody = putExprStat $ CCall (CVar . Ident $ "printf")
                                        [ stringE $ printfText pc t " "
                                        , printableArg pc t currInd ]

     putString "[ "
     mkSequential -- cant print arrays in parallel
     forCG (iter .=. (intE 0)) cond inc loopBody
     putString "]\n"
  where putString s = putExprStat $ CCall (CVar . Ident $ "printf")
                                          [stringE s]

-- Scalars and every other type: print once, followed by a newline.
printf pc typ arg =
  putExprStat $ CCall (CVar . Ident $ "printf")
                      [ stringE $ printfText pc typ "\n"
                      , printableArg pc typ arg ]

-- | Prepare a value of the given type to be passed to printf. Prob values
--   are held in log space, so unless 'showProbInLog' is set they are passed
--   through exp(); values of every other type are passed unchanged.
printableArg :: PrintConfig -> Sing (a :: Hakaru) -> CExpr -> CExpr
printableArg pc SProb e | not (showProbInLog pc) = exp e
printableArg _  _     e = e



-- | Build the printf format for a value of the given type. The result
--   prepends the conversion specifier(s) to the supplied suffix (callers
--   pass a separator such as "\n" or " ").
--
--   For a measure the format describes one printed sample: the weight (when
--   'showWeights' is set) followed by the sample's own format. For an array
--   it is the format of a single element.
printfText :: PrintConfig -> Sing (a :: Hakaru) -> (String -> String)
printfText cfg typ =
  case typ of
    SInt       -> ("%d" ++)
    SNat       -> ("%d" ++)
    SProb
      | showProbInLog cfg -> ("exp(%.15f)" ++)
      | otherwise         -> ("%.15f" ++)
    SReal      -> ("%.17f" ++)
    SMeasure t
      | showWeights cfg -> (weightFormat ++) . printfText cfg t
      | otherwise       -> printfText cfg t
    SArray t   -> printfText cfg t
    SFun _ _   -> id
    SData _ _  -> ("TODO: printft datum" ++)
  where
    weightFormat :: String
    weightFormat
      | showProbInLog cfg = "exp(%.15f) "
      | otherwise         = "%.15f "


--------------------------------------------------------------------------------
--                           Wrapping   Lambdas                               --
--------------------------------------------------------------------------------
{-

Lambdas become functions in C. The Hakaru ABT only allows one argument for each
lambda, so at the top level of a Hakaru program that is a function there may be
several nested lambdas. In C, however, we can and should coalesce these into one
function with several arguments. This is what flattenTopLambda is for.

-}


-- | Compile a (possibly curried) Hakaru function into a single C function
--   called @name@. The leading chain of 'Lam_' binders becomes the C
--   parameter list. The remaining body is compiled into a freshly generated
--   local (prefix "out"), which the function returns.
--
--   The body is generated with empty statement and declaration lists, so it
--   does not mix with the caller's code. Afterwards the caller's lists are
--   restored, and the finished function is emitted as an external
--   declaration.
flattenTopLambda
  :: ABT Term abt
  => abt '[] a
  -> Ident
  -> CodeGen ()
flattenTopLambda abt name =
    coalesceLambda abt $ \vars abt' ->
    let typ  = typeOf abt'
        -- compile the body into a fresh variable and return it
        body = do outId <- genIdent' "out"
                  declare typ outId
                  let outE = CVar outId
                  flattenABT abt' outE
                  putStat . CReturn . Just $ outE
    in  do argDecls <- sequence $ foldMap11 (\v -> [mkVarDecl v =<< createIdent v]) vars
           cg <- get
           -- run the body against empty statement/declaration lists, keep the
           -- rest of its resulting state, then restore the caller's lists
           let (_,cg') = runState body $ cg { statements   = []
                                             , declarations = [] }
           put $ cg' { statements   = statements cg
                     , declarations = declarations cg }
           extDeclare . CFunDefExt
             $ functionDef typ name
                               argDecls
                               (P.reverse $ declarations cg')
                               (P.reverse $ statements cg')
  where -- | Collect the variables of the outermost run of 'Lam_' nodes and
        --   pass them, together with the first non-lambda body, to @k@.
        --   A term that is not a lambda yields no variables.
        coalesceLambda
          :: ABT Term abt
          => abt '[] a
          -> ( forall (ys :: [Hakaru]) b
             . List1 Variable ys -> abt '[] b -> r)
          -> r
        coalesceLambda abt_ k =
          caseVarSyn abt_ (const (k Nil1 abt_)) $ \term ->
            case term of
              (Lam_ :$ body :* End) ->
                caseBind body $ \v body' ->
                  coalesceLambda body' $ \vars body'' -> k (Cons1 v vars) body''
              _ -> k Nil1 abt_

        -- | Build the C parameter declaration for a lambda-bound variable.
        --   Array and datum parameters also emit their struct definition.
        --   Parameters of any other type (e.g. measures or functions) are
        --   rejected.
        mkVarDecl :: Variable (a :: Hakaru) -> Ident -> CodeGen CDecl
        mkVarDecl (Variable _ _ SInt)  = return . typeDeclaration SInt
        mkVarDecl (Variable _ _ SNat)  = return . typeDeclaration SNat
        mkVarDecl (Variable _ _ SProb) = return . typeDeclaration SProb
        mkVarDecl (Variable _ _ SReal) = return . typeDeclaration SReal
        mkVarDecl (Variable _ _ (SArray t)) = \i -> do extDeclare $ arrayStruct t
                                                       return $ arrayDeclaration t i
        mkVarDecl (Variable _ _ d@(SData _ _)) = \i -> do extDeclare $ datumStruct d
                                                          return $ datumDeclaration d i
        mkVarDecl v = error $ "flattenTopLambda.mkVarDecl cannot handle vars of type " ++ show v