{-# LANGUAGE CPP,
             BangPatterns,
             DataKinds,
             FlexibleContexts,
             GADTs,
             KindSignatures,
             ScopedTypeVariables,
             RankNTypes,
             TypeOperators #-}

----------------------------------------------------------------
--                                                    2016.06.23
-- |
-- Module      :  Language.Hakaru.CodeGen.Flatten
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  zsulliva@indiana.edu
-- Stability   :  experimental
-- Portability :  GHC-only
--
--   Flatten takes Hakaru ABTs and C vars and returns a CStatement
-- assigning the var to the flattened ABT.
--
----------------------------------------------------------------


module Language.Hakaru.CodeGen.Flatten
  ( flattenABT
  , flattenVar
  , flattenTerm )
  where

import Language.Hakaru.CodeGen.CodeGenMonad
import Language.Hakaru.CodeGen.AST
import Language.Hakaru.CodeGen.Types

import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.TypeOf (typeOf)
import Language.Hakaru.Syntax.Datum hiding (Ident)
import qualified Language.Hakaru.Syntax.Prelude as HKP
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.Sing

import           Control.Monad.State.Strict
import           Control.Monad (replicateM)
import           Data.Number.Natural
import           Data.Ratio
import qualified Data.List.NonEmpty as NE
import qualified Data.Sequence      as S
import qualified Data.Foldable      as F
import qualified Data.Traversable   as T


#if __GLASGOW_HASKELL__ < 710
import           Data.Functor
#endif

import Prelude hiding (log,exp,sqrt)


opComment :: String -> CStat
opComment opStr = CComment $ concat [space," ",opStr," ",space]
  where size  = (50 - (length opStr)) `div` 2 - 8
        space = replicate size '-'

--------------------------------------------------------------------------------
--                                 Top Level                                  --
--------------------------------------------------------------------------------
{-

flattening an ABT will produce a continuation that takes a CExpr representing
a location where the value of the ABT should be stored. Return type of the
the continuation is CodeGen Bool, where the computed bool is whether or not
there is a Reject inside the ABT. Therefore it is only needed when computing
mochastic values

-}

flattenABT
  :: ABT Term abt
  => abt '[] a
  -> (CExpr -> CodeGen ())
flattenABT abt = caseVarSyn abt flattenVar flattenTerm

-- note that variables will find their values in the state of the CodeGen monad
flattenVar
  :: Variable (a :: Hakaru)
  -> (CExpr -> CodeGen ())
flattenVar v = \loc ->
  do v' <- CVar <$> lookupIdent v
     putStat . CExpr . Just $ loc .=. v'

flattenTerm
  :: ABT Term abt
  => Term abt a
  -> (CExpr -> CodeGen ())
flattenTerm (NaryOp_ t s)    = flattenNAryOp t s
flattenTerm (Literal_ x)     = flattenLit x
flattenTerm (Empty_ _)       = error "TODO: flattenTerm Empty"

flattenTerm (Datum_ d)       = flattenDatum d
flattenTerm (Case_ c bs)     = flattenCase c bs

flattenTerm (Array_ s e)     = flattenArray s e

-- SCon can contain mochastic terms
flattenTerm (x :$ ys)        = flattenSCon x ys

---------------------
-- Mochastic Terms --
---------------------
flattenTerm (Reject_ _)      = \loc -> putExprStat (mdataPtrWeight loc .=. (intE 0)) -- fail to draw a sample
flattenTerm (Superpose_ wes) = flattenSuperpose wes





--------------------------------------------------------------------------------
--                                  SCon                                      --
--------------------------------------------------------------------------------

flattenSCon
  :: ( ABT Term abt )
  => SCon args a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())

flattenSCon Let_ =
  \(expr :* body :* End) ->
    \loc -> do
      caseBind body $ \v@(Variable _ _ typ) body'->
        do ident <- createIdent v
           declare typ ident
           flattenABT expr (CVar ident)
           flattenABT body' loc

-- Lambdas produce functions and then return a function pointer
flattenSCon Lam_ = undefined
  -- \(body :* End) ->
  --   \loc ->
  --     do coalesceLambda body $ \vars body' ->
  --        let varMs = foldMap11 (\v -> [mkVarDecl v =<< createIdent v]) vars
  --        in  do funcId <- genIdent' "fn"
  --               argDecls <- sequence varMs

  --               cg <- get
  --               let m       = putStat . CReturn . Just =<< flattenABT body'
  --                   (_,cg') = runState m $ cg { statements = []
  --                                             , declarations = [] }
  --               put $ cg' { statements   = statements cg
  --                         , declarations = declarations cg }

  --               extDeclare . CFunDefExt $ functionDef (typeOf body')
  --                                                     funcId
  --                                                     argDecls
  --                                                     (reverse $ declarations cg')
  --                                                     (reverse $ statements cg')
  -- -- do at top level
  -- where coalesceLambda
  --         :: ( ABT Term abt )
  --         => abt '[x] a
  --         -> (forall (ys :: [Hakaru]) b. List1 Variable ys -> abt '[] b -> r)
  --         -> r
  --       coalesceLambda abt k =
  --         caseBind abt $ \v abt' ->
  --           caseVarSyn abt' (const (k (Cons1 v Nil1) abt')) $ \term ->
  --             case term of
  --               (Lam_ :$ body :* End) ->
  --                 coalesceLambda body $ \vars abt'' -> k (Cons1 v vars) abt''
  --               _ -> k (Cons1 v Nil1) abt'

  --       mkVarDecl :: Variable (a :: Hakaru) -> Ident -> CodeGen CDecl
  --       mkVarDecl (Variable _ _ SInt)  = return . typeDeclaration SInt
  --       mkVarDecl (Variable _ _ SNat)  = return . typeDeclaration SNat
  --       mkVarDecl (Variable _ _ SProb) = return . typeDeclaration SProb
  --       mkVarDecl (Variable _ _ SReal) = return . typeDeclaration SReal
  --       mkVarDecl (Variable _ _ (SArray t)) = \i -> do extDeclare $ arrayStruct t
  --                                                      return $ arrayDeclaration t i
  --       mkVarDecl (Variable _ _ d@(SData _ _)) = \i -> do extDeclare $ datumStruct d
  --                                                         return $ datumDeclaration d i
  --       mkVarDecl v = error $ "flattenSCon.Lam_.mkVarDecl cannot handle vars of type " ++ show v


flattenSCon (PrimOp_ op) = flattenPrimOp op

flattenSCon (ArrayOp_ op) = flattenArrayOp op

flattenSCon (Summate _ sr) =
  \(lo :* hi :* body :* End) ->
    \loc ->
      caseBind body $ \v body' ->
        do loId <- genIdent
           hiId <- genIdent
           declare (typeOf lo) loId
           declare (typeOf hi) hiId
           let loE = CVar loId
               hiE = CVar hiId
           flattenABT lo loE
           flattenABT hi hiE

           iterI <- createIdent v
           declare SNat iterI

           accI <- genIdent' "acc"
           let semiT = sing_HSemiring sr
           declare semiT accI
           assign accI (case semiT of
                          SProb -> negInfinityE
                          SReal -> floatE 0
                          _     -> intE 0)

           let accVar  = CVar accI
               iterVar = CVar iterI


           putStat $ opComment "Summate"
           -- logSumExp for probabilities
           reductionCG CAddOp
                       accI
                       (iterVar .=. loE)
                       (iterVar .<. hiE)
                       (CUnary CPostIncOp iterVar) $
             do tmpId <- genIdent
                declare (typeOf body') tmpId
                let tmpE = CVar tmpId
                flattenABT body' tmpE
                case semiT of
                  SProb -> logSumExpCG (S.fromList [accVar,tmpE]) accVar
                  _     -> putStat . CExpr . Just $ (accVar .+=. tmpE)

           putExprStat (loc .=. accVar)


flattenSCon (Product _ sr) =
  \(lo :* hi :* body :* End) ->
    \loc ->
      caseBind body $ \v body' ->
        do loId <- genIdent
           hiId <- genIdent
           declare (typeOf lo) loId
           declare (typeOf hi) hiId
           let loE = CVar loId
               hiE = CVar hiId
           flattenABT lo loE
           flattenABT hi hiE

           iterI <- createIdent v
           declare SNat iterI

           accI <- genIdent' "acc"
           let semiT = sing_HSemiring sr
           declare semiT accI
           assign accI (case semiT of
                          SProb -> floatE 0
                          SReal -> floatE 1
                          _     -> intE 1)

           let accVar  = CVar accI
               iterVar = CVar iterI

           putStat $ opComment "Product"
           reductionCG (case semiT of
                          SProb -> CAddOp
                          _     -> CMulOp)
                       accI
                       (iterVar .=. loE)
                       (iterVar .<. hiE)
                       (CUnary CPostIncOp iterVar) $
             do tmpId <- genIdent
                declare (typeOf body') tmpId
                let tmpE = CVar tmpId
                flattenABT body' tmpE
                putExprStat $ case semiT of
                                SProb -> CAssign CAddAssOp accVar
                                _ -> CAssign CMulAssOp accVar
                            $ tmpE

           putExprStat (loc .=. accVar)


--------------------
-- SCon Coersions --
--------------------

-- at this point, only nonrecusive coersions are implemented
flattenSCon (CoerceTo_ ctyp) =
  \(e :* End) ->
    \loc ->
       do eId <- genIdent
          let eT = typeOf e
              eE = CVar eId
          declare eT eId
          flattenABT e eE
          putExprStat . (CAssign CAssignOp loc) =<< coerceToType ctyp eT eE
  where coerceToType
          :: Coercion a b
          -> Sing (c :: Hakaru)
          -> CExpr
          -> CodeGen CExpr
        coerceToType (CCons p rest) typ =
          \e ->  primitiveCoerce p typ e >>= coerceToType rest typ
        coerceToType CNil            _  = return . id

        primitiveCoerce
          :: PrimCoercion a b
          -> Sing (c :: Hakaru)
          -> CExpr
          -> CodeGen CExpr
        primitiveCoerce (Signed HRing_Int)            SNat  = nat2int
        primitiveCoerce (Signed HRing_Real)           SProb = prob2real
        primitiveCoerce (Continuous HContinuous_Prob) SNat  = nat2prob
        primitiveCoerce (Continuous HContinuous_Real) SInt  = int2real
        primitiveCoerce (Continuous HContinuous_Real) SNat  = int2real
        primitiveCoerce a b = error $ "flattenSCon CoerceTo_: cannot preform coersion "
                                    ++ show a
                                    ++ " to "
                                    ++ show b


        -- implementing ONLY functions found in Hakaru.Syntax.AST
        nat2int,nat2prob,prob2real,int2real
          :: CExpr -> CodeGen CExpr
        nat2int   = return
        nat2prob  = \n -> do ident <- genIdent' "p"
                             declare SProb ident
                             assign ident . log1p $ n .-. (intE 1)
                             return (CVar ident)
        prob2real = \p -> do ident <- genIdent' "r"
                             declare SReal ident
                             assign ident $ (expm1 p) .+. (intE 1)
                             return (CVar ident)
        int2real  = return . CCast doubleDecl


-----------------------------------
-- SCons in the Stochastic Monad --
-----------------------------------

flattenSCon (MeasureOp_ op) = flattenMeasureOp op

flattenSCon Dirac           =
  \(e :* End) ->
    \loc ->
       do sId <- genIdent' "samp"
          declare (typeOf e) sId
          let sE = CVar sId
          flattenABT e sE
          putExprStat $ mdataPtrWeight loc .=. (floatE 0)
          putExprStat $ mdataPtrSample loc .=. sE

flattenSCon MBind           =
  \(ma :* b :* End) ->
    \loc ->
      caseBind b $ \v@(Variable _ _ typ) mb ->
        do -- first
           mId <- genIdent' "m"
           declare (typeOf ma) mId
           let mE = CVar mId
           flattenABT ma (address mE)

           -- assign that sample to var
           vId <- createIdent v
           declare typ vId
           assign vId (mdataSample mE)
           flattenABT mb loc
           putExprStat $ mdataPtrWeight loc .+=. (mdataWeight mE)

-- for now plats make use of a global sample
flattenSCon Plate           =
  \(size :* b :* End) ->
    \loc ->
      caseBind b $ \v@(Variable _ _ typ) body ->
        do sizeId <- genIdent' "s"
           declare SNat sizeId
           let sizeE = CVar sizeId
           flattenABT size sizeE
           putExprStat $   (arrayPtrData . mdataPtrSample $ loc)
                       .=. (CCast (mkPtrDecl . buildType $ typ)
                                  (mkUnary "malloc"
                                  (sizeE .*. (CSizeOfType . mkDecl . buildType $ typ))))



           weightId <- genIdent' "w"
           declare SProb weightId
           let weightE = CVar weightId
           assign weightId (floatE 0)

           itId <- createIdent v
           declare SNat itId
           let itE = CVar itId
               currInd  = indirect $ (CMember (mdataSample loc) (Ident "data") True) .+. itE

           sampId <- genIdent' "samp"
           declare (typeOf $ body) sampId
           let sampE = CVar sampId

           reductionCG CAddOp
                       weightId
                       (itE .=. (intE 0))
                       (itE .<. sizeE)
                       (CUnary CPostIncOp itE)
                       (do flattenABT body (address sampE)
                           putExprStat (currInd .=. (mdataSample sampE))
                           putExprStat (weightE .+=. (mdataWeight sampE)))

           putExprStat $ mdataPtrWeight loc .=. weightE


-----------------------------------
-- SCon's that arent implemented --
-----------------------------------

flattenSCon x               = \_ -> \_ -> error $ "TODO: flattenSCon: " ++ show x





--------------------------------------------------------------------------------
--                                 NaryOps                                    --
--------------------------------------------------------------------------------

flattenNAryOp :: ABT Term abt
              => NaryOp a
              -> S.Seq (abt '[] a)
              -> (CExpr -> CodeGen ())
flattenNAryOp op args =
  \loc ->
    do es <- T.forM args $ \a ->
               do aId <- genIdent
                  let aE = CVar aId
                  declare (typeOf a) aId
                  _ <- flattenABT a aE
                  return aE
       case op of
         And -> boolNaryOp op es loc
         Or  -> boolNaryOp op es loc
         Xor -> boolNaryOp op es loc
         Iff -> boolNaryOp op es loc

         (Sum HSemiring_Prob) -> logSumExpCG es loc

         _ -> let opE = F.foldr (binaryOp op) (S.index es 0) (S.drop 1 es)
              in  putExprStat (loc .=. opE)


  where boolNaryOp op' es' loc' =
          let indexOf x = CMember x (Ident "index") True
              es''      = fmap indexOf es'
              expr      = F.foldr (binaryOp op')
                                  (S.index es'' 0)
                                  (S.drop 1 es'')
          in  putExprStat ((indexOf loc') .=. expr)


--------------------------------------
-- LogSumExp for NaryOp Add [SProb] --
--------------------------------------
{-

Special for addition of probabilities we have a logSumExp. This will compute the
sum of the probabilities safely. Just adding the exp(a . prob) would make us
loose any of the safety from underflow that we got from storing prob in the log
domain

-}

-- the tree traversal is a depth first search
logSumExp :: S.Seq CExpr -> CExpr
logSumExp es = mkCompTree 0 1

  where lastIndex  = S.length es - 1

        compIndices :: Int -> Int -> CExpr -> CExpr -> CExpr
        compIndices i j = CCond ((S.index es i) .>. (S.index es j))

        mkCompTree :: Int -> Int -> CExpr
        mkCompTree i j
          | j == lastIndex = compIndices i j (logSumExp' i) (logSumExp' j)
          | otherwise      = compIndices i j
                               (mkCompTree i (succ j))
                               (mkCompTree j (succ j))

        diffExp :: Int -> Int -> CExpr
        diffExp a b = expm1 ((S.index es a) .-. (S.index es b))

        -- given the max index, produce a logSumExp expression
        logSumExp' :: Int -> CExpr
        logSumExp' 0 = S.index es 0
          .+. (log1p $ foldr (\x acc -> diffExp x 0 .+. acc)
                            (diffExp 1 0)
                            [2..S.length es - 1]
                    .+. (intE $ fromIntegral lastIndex))
        logSumExp' i = S.index es i
          .+. (log1p $ foldr (\x acc -> if i == x
                                       then acc
                                       else diffExp x i .+. acc)
                            (diffExp 0 i)
                            [1..S.length es - 1]
                    .+. (intE $ fromIntegral lastIndex))


-- | logSumExpCG creates global functions for every n-ary logSumExp function
-- this reduces code size
logSumExpCG :: S.Seq CExpr -> (CExpr -> CodeGen ())
logSumExpCG seqE =
  let size   = S.length $ seqE
      name   = "logSumExp" ++ (show size)
      funcId = Ident name
  in \loc -> do-- reset the names so that the function is the same for each arity
       cg <- get
       put (cg { freshNames = suffixes })
       argIds <- replicateM size genIdent
       let decls = fmap (typeDeclaration SProb) argIds
           vars  = fmap CVar argIds
       extDeclare . CFunDefExt $ functionDef SProb
                                             funcId
                                             decls
                                             []
                                             [CReturn . Just $ logSumExp $ S.fromList vars ]
       cg' <- get
       put (cg' { freshNames = freshNames cg })
       putExprStat $ loc .=. (CCall (CVar funcId) (F.toList seqE))


--------------------------------------------------------------------------------
--                                  Literals                                  --
--------------------------------------------------------------------------------

flattenLit
  :: Literal a
  -> (CExpr -> CodeGen ())
flattenLit lit =
  \loc ->
    case lit of
      (LNat x)  -> putExprStat $ loc .=. (intE $ fromIntegral x)
      (LInt x)  -> putExprStat $ loc .=. (intE x)
      (LReal x) -> putExprStat $ loc .=. (floatE $ fromRational x)
      (LProb x) -> let rat = fromNonNegativeRational x
                       x'  = (fromIntegral $ numerator rat)
                           / (fromIntegral $ denominator rat)
                       xE  = log1p (floatE x' .-. intE 1)
                   in putExprStat (loc .=. xE)

--------------------------------------------------------------------------------
--                                Array and ArrayOps                          --
--------------------------------------------------------------------------------


flattenArray
  :: (ABT Term abt)
  => (abt '[] 'HNat)
  -> (abt '[ 'HNat ] a)
  -> (CExpr -> CodeGen ())
flattenArray arity body =
  \loc ->
    caseBind body $ \v@(Variable _ _ typ) body' ->
      let arityE = arraySize loc
          dataE  = arrayData loc in
      do flattenABT arity arityE
         putExprStat $   dataE
                     .=. (CCast (mkPtrDecl . buildType $ typ)
                                (mkUnary "malloc"
                                  (arityE .*. (CSizeOfType . mkDecl . buildType $ typ))))

         itId  <- createIdent v
         declare SNat itId
         let itE     = CVar itId
             currInd = indirect (dataE .+. itE)

         putStat $ opComment "Create Array"
         forCG (itE .=. (intE 0))
               (itE .<. arityE)
               (CUnary CPostIncOp itE)
               (flattenABT body' currInd)


--------------
-- ArrayOps --
--------------

flattenArrayOp
  :: ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs
     )
  => ArrayOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())


flattenArrayOp (Index _)  =
  \(arr :* ind :* End) ->
    \loc ->
      do arrId <- genIdent' "arr"
         indId <- genIdent
         let arrE = CVar arrId
             indE = CVar indId
         declare (typeOf arr) arrId
         declare SNat indId
         flattenABT arr arrE
         flattenABT ind indE
         let valE = indirect ((CMember arrE (Ident "data") True) .+. indE)
         putExprStat (loc .=. valE)

flattenArrayOp (Size _)   =
  \(arr :* End) ->
    \loc ->
      do arrId <- genIdent' "arr"
         declare (typeOf arr) arrId
         let arrE = CVar arrId
         flattenABT arr arrE
         putExprStat (loc .=. (CMember arrE (Ident "size") True))

flattenArrayOp (Reduce _) = error "TODO: flattenArrayOp"
  -- \(fun :* base :* arr :* End) ->
  -- do funE  <- flattenABT fun
  --    baseE <- flattenABT base
  --    arrE  <- flattenABT arr
  --    accI  <- genIdent' "acc"
  --    iterI <- genIdent' "iter"

  --    let sizeE = CMember arrE (Ident "size") True
  --        iterE = CVar iterI
  --        accE  = CVar accI
  --        cond  = iterE .<. sizeE
  --        inc   = CUnary CPostIncOp iterE

  --    declare (typeOf base) accI
  --    declare SInt iterI
  --    assign accI baseE
  --    forCG (iterE .=. (intE 0)) cond inc $
  --      assign accI $ CCall funE [accE]

  --    return accE


--------------------------------------------------------------------------------
--                                 Datum and Case                             --
--------------------------------------------------------------------------------
{-

Datum are sums of products of types. This maps to a C structure. flattenDatum
will produce a literal of some datum type. This will also produce a global
struct representing that datum which will be needed for the C compiler.

-}


flattenDatum
  :: (ABT Term abt)
  => Datum (abt '[]) (HData' a)
  -> (CExpr -> CodeGen ())
flattenDatum (Datum _ typ code) =
  \loc ->
    do extDeclare $ datumStruct typ
       assignDatum code loc

datumNames :: [String]
datumNames = filter (\n -> not $ elem (head n) ['0'..'9']) names
  where base = ['0'..'9'] ++ ['a'..'z']
        names = [[x] | x <- base] `mplus` (do n <- names
                                              [n++[x] | x <- base])

assignDatum
  :: (ABT Term abt)
  => DatumCode xss (abt '[]) c
  -> CExpr
  -> CodeGen ()
assignDatum code ident =
  let index     = getIndex code
      indexExpr = CMember ident (Ident "index") True
  in  do putExprStat (indexExpr .=. (intE index))
         sequence_ $ assignSum code ident
  where getIndex :: DatumCode xss b c -> Integer
        getIndex (Inl _)    = 0
        getIndex (Inr rest) = succ (getIndex rest)

assignSum
  :: (ABT Term abt)
  => DatumCode xs (abt '[]) c
  -> CExpr
  -> [CodeGen ()]
assignSum code ident = fst $ runState (assignSum' code ident) datumNames

assignSum'
  :: (ABT Term abt)
  => DatumCode xs (abt '[]) c
  -> CExpr
  -> State [String] [CodeGen ()]
assignSum' (Inr rest) topIdent =
  do (_:names) <- get
     put names
     assignSum' rest topIdent
assignSum' (Inl prod) topIdent =
  do (name:_) <- get
     return $ assignProd prod topIdent (CVar . Ident $ name)

assignProd
  :: (ABT Term abt)
  => DatumStruct xs (abt '[]) c
  -> CExpr
  -> CExpr
  -> [CodeGen ()]
assignProd dstruct topIdent sumIdent =
  fst $ runState (assignProd' dstruct topIdent sumIdent) datumNames

assignProd'
  :: (ABT Term abt)
  => DatumStruct xs (abt '[]) c
  -> CExpr
  -> CExpr
  -> State [String] [CodeGen ()]
assignProd' Done _ _ = return []
assignProd' (Et (Konst d) rest) topIdent (CVar sumIdent) =
  do (name:names) <- get
     put names
     let varName  = CMember (CMember (CMember topIdent
                                              (Ident "sum")
                                              True)
                                     sumIdent
                                     True)
                            (Ident name)
                            True
     rest' <- assignProd' rest topIdent (CVar sumIdent)
     return $ [flattenABT d varName] ++ rest'
assignProd' _ _ _  = error $ "TODO: assignProd Ident"


----------
-- Case --
----------

-- currently we can only match on boolean values
flattenCase
  :: forall abt a b
  .  (ABT Term abt)
  => abt '[] a
  -> [Branch a abt b]
  -> (CExpr -> CodeGen ())

flattenCase c (Branch (PDatum _ (PInl PDone)) trueB:Branch (PDatum _ (PInr (PInl PDone))) falseB:[]) =
  \loc ->
    do cId <- genIdent
       declare (typeOf c) cId
       let cE = (CVar cId)
       flattenABT c cE

       cg <- get
       let trueM    = flattenABT trueB loc
           falseM   = flattenABT falseB loc
           (_,cg')  = runState trueM $ cg { statements = [] }
           (_,cg'') = runState falseM $ cg' { statements = [] }
       put $ cg'' { statements = statements cg }

       putStat $ CIf ((CMember cE (Ident "index") True) .==. (intE 0))
                     (CCompound . fmap CBlockStat . reverse . statements $ cg')
                     Nothing
       putStat $ CIf ((CMember cE (Ident "index") True) .==. (intE 1))
                     (CCompound . fmap CBlockStat . reverse . statements $ cg'')
                     Nothing


flattenCase _ _ = error "TODO: flattenCase"


--------------------------------------------------------------------------------
--                                     PrimOp                                 --
--------------------------------------------------------------------------------

flattenPrimOp
  :: ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs)
  => PrimOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())


flattenPrimOp Pi =
  \End ->
    \loc -> let piE = log1p ((CVar . Ident $ "M_PI") .-. (intE 1)) in
      putExprStat (loc .=. piE)

flattenPrimOp Not =
  \(a :* End) ->
    \_ ->
      -- this is currently incorrect, need to use memcpy to preserve value of
      -- 'a'
      do tmpId <- genIdent' "not"
         declare sBool tmpId
         let tmpE = CVar tmpId
         flattenABT a tmpE
         let datumIndex = CMember tmpE (Ident "index") True
         putExprStat $ datumIndex .=. (CCond (datumIndex .==. (intE 1))
                                             (intE 0)
                                             (intE 1))

flattenPrimOp RealPow =
  \(base :* power :* End) ->
    \loc ->
      do baseId <- genIdent
         powerId <- genIdent
         declare SProb baseId
         declare SReal powerId
         let baseE     = CVar baseId
             powerE = CVar powerId
         flattenABT base baseE -- first argument is a Prob
         flattenABT power powerE
         let realPow = CCall (CVar . Ident $ "pow")
                             [ expm1 baseE .+. (intE 1), powerE]
         putExprStat $ loc .=. (log1p (realPow .-. (intE 1)))

flattenPrimOp (NatPow baseTyp) =
  \(base :* power :* End) ->
    \loc ->
      let sBase = sing_HSemiring baseTyp in
      do baseId <- genIdent
         powerId <- genIdent
         declare sBase baseId
         declare SReal powerId
         let baseE     = CVar baseId
             powerE = CVar powerId
         flattenABT base baseE
         flattenABT power powerE
         let powerOf x y = CCall (CVar . Ident $ "pow") [x,y]
             value = case sBase of
                       SProb -> log1p $ (powerOf (expm1 baseE .+. (intE 1)) powerE)
                                  .-. (intE 1)
                       _     -> powerOf baseE powerE
         putExprStat $ loc .=. value

flattenPrimOp (NatRoot baseTyp) =
  \(base :* root :* End) ->
    \loc ->
      let sBase = sing_HRadical baseTyp in
      do baseId <- genIdent
         rootId <- genIdent
         declare sBase baseId
         declare SReal rootId
         let baseE = CVar baseId
             rootE = CVar rootId
         flattenABT base baseE
         flattenABT root rootE
         let powerOf x y = CCall (CVar . Ident $ "pow") [x,y]
             recipE = (floatE 1) ./. rootE
             value = case sBase of
                       SProb -> log1p $ (powerOf (expm1 baseE .+. (intE 1)) recipE)
                                      .-. (intE 1)
                       _     -> powerOf baseE recipE
         putExprStat $ loc .=. value

flattenPrimOp (Recip t) =
  \(a :* End) ->
    \loc ->
      do aId <- genIdent
         declare (typeOf a) aId
         let aE = CVar aId
         flattenABT a aE
         case t of
           HFractional_Real -> putExprStat $ loc .=. ((intE 1) ./. aE)
           HFractional_Prob -> putExprStat $ loc .=. (CUnary CMinOp aE)

-- | exp : real -> prob, because of this we can just turn it into a prob without taking
--   its log, which would give us an exp in the log-domain
flattenPrimOp Exp = \(a :* End) -> flattenABT a

flattenPrimOp (Equal _) =
  \(a :* b :* End) ->
    \loc ->
      do aId <- genIdent
         bId <- genIdent
         let aE = CVar aId
             bE = CVar bId
             aT = typeOf a
             bT = typeOf b
         declare aT aId
         declare bT bId
         flattenABT a aE
         flattenABT b bE

         -- special case for booleans
         let aE' = case aT of
                     (SData _ (SPlus SDone (SPlus SDone SVoid))) -> (CMember aE (Ident "index") True)
                     _ -> aE
         let bE' = case bT of
                     (SData _ (SPlus SDone (SPlus SDone SVoid))) -> (CMember bE (Ident "index") True)
                     _ -> bE

         putExprStat $   (CMember loc (Ident "index") True)
                     .=. (CCond (aE' .==. bE') (intE 0) (intE 1))


flattenPrimOp (Less _) =
  \(a :* b :* End) ->
    \loc ->
      do aId <- genIdent
         bId <- genIdent
         let aE = CVar aId
             bE = CVar bId
         declare (typeOf a) aId
         declare (typeOf b) bId
         flattenABT a aE
         flattenABT b bE
         putExprStat $ (CMember loc (Ident "index") True)
                     .=. (CCond (aE .<. bE) (intE 0) (intE 1))

flattenPrimOp (Negate HRing_Real) =
 \(a :* End) ->
   \loc ->
     do negId <- genIdent' "neg"
        declare SReal negId
        let negE = CVar negId
        flattenABT a negE
        putExprStat $ loc .=. (CUnary CMinOp $ negE)


flattenPrimOp t  = \_ -> error $ "TODO: flattenPrimOp: " ++ show t


--------------------------------------------------------------------------------
--                           MeasureOps and Superpose                         --
--------------------------------------------------------------------------------

{-

The sections contains operations in the stochastic monad. See also
(Dirac, MBind, and Plate) found in SCon. Also see Reject found at the top level.

Remember in the C runtime. Measures are housed in a measure function, which
takes an `struct mdata` location. The MeasureOp attempts to store a value at
that location and returns 0 if it fails and 1 if it succeeds in that task.

The functions uniformCG, normalCG, and gammaCG are primitives that will generate
functions and call them (similar to logSumExpCG). The reduce code size and make
samplers a little more readable.

TODO: add inline pragmas to uniformCG, normalCG, and gammaCG
-}

uniformFun :: CFunDef
uniformFun = CFunDef [CTypeSpec CVoid]
                     (CDeclr Nothing [CDDeclrIdent funcId])
                     [typeDeclaration SReal loId
                     ,typeDeclaration SReal hiId
                     ,typePtrDeclaration (SMeasure SReal) mId]
                     (seqCStat $ comment ++[assW,assS,CReturn Nothing])
  where r          = CCast doubleDecl rand
        rMax       = CCast doubleDecl (CVar . Ident $ "RAND_MAX")
        (mId,mE)   = let ident = Ident "mdata" in (ident,CVar ident)
        (loId,loE) = let ident = Ident "lo" in (ident,CVar ident)
        (hiId,hiE) = let ident = Ident "hi" in (ident,CVar ident)
        value      = (loE .+. ((r ./. rMax) .*. (hiE .-. loE)))
        comment = fmap CComment
          ["uniform :: real -> real -> *(mdata real) -> ()"
          ,"------------------------------------------------"]
        assW       = CExpr . Just $ mdataPtrWeight mE .=. (floatE 0)
        assS       = CExpr . Just $ mdataPtrSample mE .=. value
        funcId     = Ident "uniform"


uniformCG :: CExpr -> CExpr -> (CExpr -> CodeGen ())
uniformCG aE bE =
  \loc -> do
    reserveName "uniform"
    extDeclare . CFunDefExt $ uniformFun
    putExprStat $ CCall (CVar . Ident $ "uniform") [aE,bE,loc]


{-
  This is very cryptic, but I assure you it is only building an AST for the
  Marsaglia Polar Method
-}

normalFun :: CFunDef
normalFun = CFunDef [CTypeSpec CVoid]
                    (CDeclr Nothing [CDDeclrIdent (Ident "normal")])
                    [typeDeclaration SReal aId
                    ,typeDeclaration SProb bId
                    ,typePtrDeclaration (SMeasure SReal) mId]
                    (CCompound $ comment ++ decls ++ stmts)

  where r      = CCast doubleDecl rand
        rMax   = CCast doubleDecl (CVar . Ident $ "RAND_MAX")
        (aId,aE) = let ident = Ident "a" in (ident,CVar ident)
        (bId,bE) = let ident = Ident "b" in (ident,CVar ident)
        (qId,qE) = let ident = Ident "q" in (ident,CVar ident)
        (uId,uE) = let ident = Ident "u" in (ident,CVar ident)
        (vId,vE) = let ident = Ident "v" in (ident,CVar ident)
        (rId,rE) = let ident = Ident "r" in (ident,CVar ident)
        (mId,mE) = let ident = Ident "mdata" in (ident,CVar ident)
        draw xE = CExpr . Just $ xE .=. (((r ./. rMax) .*. (floatE 2)) .-. (floatE 1))
        body = seqCStat [draw uE
                        ,draw vE
                        ,CExpr . Just $ qE .=. ((uE .*. uE) .+. (vE .*. vE))]
        polar = CWhile (qE .>. (floatE 1)) body True
        setR  = CExpr . Just $ rE .=. (sqrt (((CUnary CMinOp (floatE 2)) .*. log qE) ./. qE))
        finalValue = aE .+. (uE .*. rE .*. bE)
        comment = fmap (CBlockStat . CComment)
          ["normal :: real -> real -> *(mdata real) -> ()"
          ,"Marsaglia Polar Method"
          ,"-----------------------------------------------"]
        decls = fmap (CBlockDecl . typeDeclaration SReal) [uId,vId,qId,rId]
        stmts = fmap CBlockStat [polar,setR, assW, assS,CReturn Nothing]
        assW = CExpr . Just $ mdataPtrWeight mE .=. (floatE 0)
        assS = CExpr . Just $ mdataPtrSample mE .=. finalValue


normalCG :: CExpr -> CExpr -> (CExpr -> CodeGen ())
normalCG aE bE =
  \loc -> do
    reserveName "normal"
    extDeclare . CFunDefExt $ normalFun
    putExprStat $ CCall (CVar . Ident $ "normal") [aE,bE,loc]

{-
  This method is from Marsaglia and Tsang "a simple method for generating gamma variables"
-}
gammaFun :: CFunDef
gammaFun = CFunDef [CTypeSpec CVoid]
                   (CDeclr Nothing [CDDeclrIdent (Ident "gamma")])
                   [typeDeclaration SProb aId
                   ,typeDeclaration SProb bId
                   ,typePtrDeclaration (SMeasure SProb) mId]
                   (CCompound $ comment ++ decls ++ stmts)
  where (aId,aE) = let ident = Ident "a" in (ident,CVar ident)
        (bId,bE) = let ident = Ident "b" in (ident,CVar ident)
        (cId,cE) = let ident = Ident "c" in (ident,CVar ident)
        (dId,dE) = let ident = Ident "d" in (ident,CVar ident)
        (xId,xE) = let ident = Ident "x" in (ident,CVar ident)
        (vId,vE) = let ident = Ident "v" in (ident,CVar ident)
        (uId,uE) = let ident = Ident "u" in (ident,CVar ident)
        (mId,mE) = let ident = Ident "mdata" in (ident,CVar ident)
        comment = fmap (CBlockStat . CComment)
          ["gamma :: real -> prob -> *(mdata prob) -> ()"
          ,"Marsaglia and Tsang 'a simple method for generating gamma variables'"
          ,"--------------------------------------------------------------------"]
        decls = fmap CBlockDecl $ (fmap (typeDeclaration SReal) [dId,cId,vId])
                               ++ (fmap (typeDeclaration (SMeasure SReal)) [uId,xId])
        stmts = fmap CBlockStat $ [assD,assC,outerWhile]
        xS = mdataSample xE
        uS = mdataSample uE
        assD = CExpr . Just $ dE .=. (aE .-. ((floatE 1) ./. (floatE 3)))
        assC = CExpr . Just $ cE .=. ((floatE 1) ./. (sqrt ((floatE 9) .*. dE)))
        outerWhile = CWhile (intE 1) (seqCStat [innerWhile,assV,assU,exit]) False
        innerWhile = CWhile (vE .<=. (floatE 0)) (seqCStat [assX,assVIn]) True
        assX = CExpr . Just $ CCall (CVar . Ident $ "normal") [(floatE 0),(floatE 1),address xE]
        assVIn = CExpr . Just $ vE .=. ((floatE 1) .+. (cE .*. xS))
        assV = CExpr . Just $ vE .=. (vE .*. vE .*. vE)
        assU = CExpr . Just $ CCall (CVar . Ident $ "uniform") [(floatE 0),(floatE 1),address uE]
        exitC1 = uS .<. ((floatE 1) .-. ((floatE 0.331 .*. (xS .*. xS) .*. (xS .*. xS))))
        exitC2 = (log uS) .<. (((floatE 0.5) .*. (xS .*. xS)) .+. (dE .*. ((floatE 1.0) .-. vE .+. (log vE))))
        assW = CExpr . Just $ mdataPtrWeight mE .=. (floatE 0)
        assS = CExpr . Just $ mdataPtrSample mE .=. (log (dE .*. vE)) .+. bE
        exit = CIf (exitC1 .||. exitC2) (seqCStat [assW,assS,CReturn Nothing]) Nothing


gammaCG :: CExpr -> CExpr -> (CExpr -> CodeGen ())
gammaCG aE bE =
  \loc -> do
     extDeclare $ mdataStruct SReal
     mapM_ reserveName ["uniform","normal","gamma"]
     mapM_ (extDeclare . CFunDefExt) [uniformFun,normalFun,gammaFun]
     putExprStat $ CCall (CVar . Ident $ "gamma") [aE,bE,loc]


flattenMeasureOp
  :: forall abt typs args a .
     ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs )
  => MeasureOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())


flattenMeasureOp Uniform =
  \(a :* b :* End) ->
    \loc ->
      do (aId:bId:[]) <- replicateM 2 genIdent
         let aE = CVar aId
             bE = CVar bId
         declare SReal aId
         declare SReal bId
         flattenABT a aE
         flattenABT b bE
         uniformCG aE bE loc


flattenMeasureOp Normal  =
  \(a :* b :* End) ->
    \loc ->
      do (aId:bId:[]) <- replicateM 2 genIdent
         let aE = CVar aId
             bE = CVar bId
         declare SReal aId
         declare SReal bId
         flattenABT a aE
         flattenABT b bE
         normalCG aE (exp bE) loc


flattenMeasureOp Gamma =
  \(a :* b :* End) ->
    \loc ->
      do (aId:bId:[]) <- replicateM 2 genIdent
         let aE = CVar aId
             bE = CVar bId
         declare SReal aId
         declare SReal bId
         flattenABT a aE
         flattenABT b bE
         gammaCG (exp aE) bE loc


flattenMeasureOp Beta =
  \(a :* b :* End) -> flattenABT (HKP.beta'' a b)


flattenMeasureOp Categorical = \(arr :* End) ->
  \loc ->
    do arrId <- genIdent
       declare (typeOf arr) arrId
       let arrE = CVar arrId
       flattenABT arr arrE

       itId <- genIdent' "it"
       declare SInt itId
       let itE = CVar itId

       wSumId <- genIdent' "ws"
       declare SProb wSumId
       let wSumE = CVar wSumId
       assign wSumId (log (intE 0))

       let currE = indirect (arrayData arrE .+. itE)
           cond  = itE .<. (arraySize arrE)
           inc   = CUnary CPostIncOp itE

       isPar <- isParallel
       mkSequential

       -- first calculate the max weight
       forCG (itE .=. (intE 0)) cond inc $
         logSumExpCG (S.fromList [wSumE,currE]) wSumE

       -- draw number from uniform(0, weightSum)
       rId <- genIdent' "r"
       declare SReal rId
       let r    = CCast doubleDecl rand
           rMax = CCast doubleDecl (CVar . Ident $ "RAND_MAX")
           rE = CVar rId
       assign rId ((r ./. rMax) .*. (exp wSumE))

       assign wSumId (log (intE 0))
       assign itId (intE 0)
       whileCG (intE 1)
         $ do stat <- runCodeGenBlock $
                        do putExprStat $ mdataPtrWeight loc .=. (intE 0)
                           putExprStat $ mdataPtrSample loc .=. itE
                           putStat CBreak
              putStat $ CIf (rE .<. (exp wSumE)) stat Nothing
              logSumExpCG (S.fromList [wSumE,currE]) wSumE
              putExprStat $ CUnary CPostIncOp itE

       when isPar mkParallel


flattenMeasureOp x = error $ "TODO: flattenMeasureOp: " ++ show x

---------------
-- Superpose --
---------------

flattenSuperpose
    :: (ABT Term abt)
    => NE.NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
    -> (CExpr -> CodeGen ())

-- do we need to normalize?
flattenSuperpose pairs =
  let pairs' = NE.toList pairs in

  if length pairs' == 1
  then \loc -> let (w,m) = head pairs' in
         do mId <- genIdent
            wId <- genIdent
            declare (typeOf m) mId
            declare SProb wId
            let mE = address . CVar $ mId
                wE = CVar wId
            flattenABT w wE
            flattenABT m mE
            putExprStat $ mdataPtrWeight loc .=. ((mdataPtrWeight mE) .+. wE)
            putExprStat $ mdataPtrSample loc .=. (mdataPtrSample mE)

  else \loc ->
         do wEs <- forM pairs' $ \(w,_) ->
                     do wId <- genIdent' "w"
                        declare SProb wId
                        let wE = CVar wId
                        flattenABT w wE
                        return wE

            wSumId <- genIdent' "ws"
            declare SProb wSumId
            let wSumE = CVar wSumId
            logSumExpCG (S.fromList wEs) wSumE

            -- draw number from uniform(0, weightSum)
            rId <- genIdent' "r"
            declare SReal rId
            let r    = CCast doubleDecl rand
                rMax = CCast doubleDecl (CVar . Ident $ "RAND_MAX")
                rE = CVar rId
            assign rId ((r ./. rMax) .*. (exp wSumE))

            -- an iterator for picking a measure
            itId <- genIdent' "it"
            declare SProb itId
            let itE = CVar itId
            assign itId (log (intE 0))

            -- an output measure to assign to
            outId <- genIdent' "out"
            declare (typeOf . snd . head $ pairs') outId
            let outE = address $ CVar outId

            outLabel <- genIdent' "exit"

            forM_ (zip wEs pairs')
              $ \(wE,(_,m)) ->
                  do logSumExpCG (S.fromList [itE,wE]) itE
                     stat <- runCodeGenBlock (flattenABT m outE >> putStat (CGoto outLabel))
                     putStat $ CIf (rE .<. (exp itE)) stat Nothing

            putStat $ CLabel outLabel (CExpr Nothing)
            putExprStat $ mdataPtrWeight loc .=. ((mdataPtrWeight outE) .+. wSumE)
            putExprStat $ mdataPtrSample loc .=. (mdataPtrSample outE)