{-# LANGUAGE CPP
           , DataKinds
           , PolyKinds
           , GADTs
           , TypeOperators
           , EmptyCase
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.Value where

import           Language.Hakaru.Syntax.IClasses
import           Language.Hakaru.Syntax.Datum
import           Language.Hakaru.Types.HClasses
import           Language.Hakaru.Types.DataKind
import           Language.Hakaru.Types.Coercion
import           Language.Hakaru.Types.Sing

import           Data.STRef

import qualified Data.Vector                     as V
import qualified Data.Number.LogFloat            as LF
import           Data.Number.Natural

import qualified System.Random.MWC               as MWC

-- | Runtime values of the Hakaru evaluator, indexed by their Hakaru type.
--
-- Numeric constructors are strict in their payload; 'VProb' stores its
-- probability in log-domain as an 'LF.LogFloat'.
data Value :: Hakaru -> * where
     VNat     ::                !Natural -> Value 'HNat
     VInt     ::                !Integer -> Value 'HInt
     VProb    :: {-# UNPACK #-} !LF.LogFloat -> Value 'HProb
     VReal    :: {-# UNPACK #-} !Double -> Value 'HReal

     -- User-defined data types, stored as a 'Datum' whose leaves are
     -- themselves 'Value's.
     VDatum   :: !(Datum Value (HData' t)) -> Value (HData' t)

     -- Lambdas/closures are treated as values and represented directly as
     -- Haskell functions over 'Value's.
     -- N.B., the type below is larger than is correct (the original
     -- explanation of why was truncated here -- TODO: restore it).
     VLam     :: (Value a -> Value b) -> Value (a ':-> b)

     -- A measure is represented as a sampler: given the current importance
     -- weight and an MWC generator, it runs in IO and yields either Nothing
     -- (NOTE(review): presumably a rejected sample -- confirm) or a sample
     -- paired with an updated weight.
     VMeasure :: (Value 'HProb ->
                  MWC.GenIO    ->
                  IO (Maybe (Value a, Value 'HProb))
                 ) -> Value ('HMeasure a)
     -- Arrays are boxed vectors of element values.
     VArray   :: {-# UNPACK #-} !(V.Vector (Value a)) -> Value ('HArray a)

instance Eq1 Value where
    -- Structural equality for first-order values. Two values are equal only
    -- when they use the same constructor and their payloads compare equal;
    -- closures ('VLam') and measures ('VMeasure') are never equal, since
    -- Haskell functions cannot be compared.
    eq1 x y =
        case (x, y) of
        (VNat   m, VNat   n) -> m == n
        (VInt   m, VInt   n) -> m == n
        (VProb  m, VProb  n) -> m == n
        (VReal  m, VReal  n) -> m == n
        (VDatum m, VDatum n) -> m == n
        (VArray m, VArray n) -> m == n
        _                    -> False

instance Eq (Value a) where
    -- Ordinary equality simply defers to the indexed 'eq1' instance.
    x == y = eq1 x y

instance Show1 Value where
    -- Numeric payloads and arrays are shown with their own 'Show' instances
    -- (no constructor prefix); datums delegate to 'showsPrec1'. Functions
    -- and measures have no printable content, so they render as fixed
    -- placeholders.
    showsPrec1 prec val =
        case val of
        VNat   n   -> showsPrec  prec n
        VInt   i   -> showsPrec  prec i
        VProb  pr  -> showsPrec  prec pr
        VReal  r   -> showsPrec  prec r
        VDatum d   -> showsPrec1 prec d
        VLam   _   -> showString "<function>"
        VMeasure _ -> showString "<measure>"
        VArray xs  -> showsPrec  prec xs

instance Show (Value a) where
    -- Both methods forward to the indexed 'Show1' instance.
    showsPrec prec v = showsPrec1 prec v
    show v           = show1 v

instance Coerce Value where
    -- Apply a chain of primitive coercions front to back: the head
    -- coercion is applied first, then the rest of the chain.
    coerceTo coercion v =
        case coercion of
        CNil       -> v
        CCons c cs -> coerceTo cs (primCoerceTo c v)

    -- Undo a chain of primitive coercions: the tail of the chain is undone
    -- first, then the head coercion is reversed.
    coerceFrom coercion v =
        case coercion of
        CNil       -> v
        CCons c cs -> primCoerceFrom c (coerceFrom cs v)

instance PrimCoerce Value where
    -- Widening direction: nat -> int, prob -> real, nat -> prob, int -> real.
    -- Nat -> prob goes through a 'Double' before entering log-domain.
    primCoerceTo (Signed HRing_Int)            (VNat  n) = VInt  (fromNatural n)
    primCoerceTo (Signed HRing_Real)           (VProb p) = VReal (LF.fromLogFloat p)
    primCoerceTo (Continuous HContinuous_Prob) (VNat  n) =
        VProb (LF.logFloat (fromIntegral (fromNatural n) :: Double))
    primCoerceTo (Continuous HContinuous_Real) (VInt  i) = VReal (fromIntegral i)

    -- Narrowing direction, the reverse of the above. These conversions are
    -- unchecked: 'unsafeNatural' is applied to the integer as-is, and the
    -- continuous-to-discrete cases truncate with 'floor'.
    primCoerceFrom (Signed HRing_Int)            (VInt  i) = VNat  (unsafeNatural i)
    primCoerceFrom (Signed HRing_Real)           (VReal r) = VProb (LF.logFloat r)
    primCoerceFrom (Continuous HContinuous_Prob) (VProb p) =
        VNat (unsafeNatural (floor (LF.fromLogFloat p :: Double)))
    primCoerceFrom (Continuous HContinuous_Real) (VReal r) = VInt  (floor r)


-- | Apply a curried two-argument function value to its first argument,
-- producing a Haskell function that awaits the second.
--
-- The outer 'VLam' is applied to the first argument; the result, itself a
-- function-typed value, is then unwrapped.
lam2 :: Value (a ':-> b ':-> c) -> (Value a -> Value b -> Value c)
lam2 (VLam outer) firstArg = unwrapLam (outer firstArg)
  where
    -- Strip the 'VLam' constructor from a function-typed value.
    unwrapLam :: Value (x ':-> y) -> Value x -> Value y
    unwrapLam (VLam inner) = inner

-- | Enumerate the half-open range @[lo, hi)@ of a discrete value.
--
-- Only 'VNat' and 'VInt' endpoints are supported; the 'HDiscrete' witness
-- is not inspected. When @lo >= hi@ the range is empty and @[]@ is returned.
-- Any other pair of values raises an 'error'.
enumFromUntilValue
    :: (HDiscrete a)
    -> Value a
    -> Value a
    -> [Value a]
-- 'takeWhile (< hi)' drops the closing endpoint. Unlike 'init', it is
-- total: it yields [] for an empty range (lo > hi) instead of crashing with
-- "Prelude.init: empty list".
enumFromUntilValue _ (VNat lo) (VNat hi) =
    map VNat (takeWhile (< hi) (enumFromTo lo hi))
enumFromUntilValue _ (VInt lo) (VInt hi) =
    map VInt (takeWhile (< hi) (enumFromTo lo hi))
enumFromUntilValue _ _         _         =
    error "Tried to iterate over a non-iterable value"

-- | Mutable reducer state, indexed by the state-thread parameter @s@ of the
-- 'STRef's it contains and by the Hakaru type it accumulates.
-- NOTE(review): presumably used to evaluate Hakaru reductions -- confirm
-- against the evaluator that consumes it.
data VReducer :: * -> Hakaru -> * where
     -- A single mutable cell holding the accumulated value.
     VRed_Num    :: STRef s (Value a)
                 -> VReducer s a
     -- Reducer for the unit type; it carries no state.
     VRed_Unit   :: VReducer s HUnit
     -- A pair of reducers, together with singletons for both component
     -- types.
     VRed_Pair   :: Sing a
                 -> Sing b
                 -> VReducer s a
                 -> VReducer s b
                 -> VReducer s (HPair a b)
     -- A vector of element reducers for an array-typed result.
     VRed_Array  :: V.Vector (VReducer s a)
                 -> VReducer s ('HArray a)