{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{- |
Exponential curve with controllable decay, specified by its half-life.
-}
module Synthesizer.LLVM.Generator.Exponential2 (
   Parameter,
   parameter,
   parameterPlain,
   causalP,

   ParameterPacked,
   parameterPacked,
   parameterPackedPlain,
   causalPackedP,
   ) where

import qualified Synthesizer.LLVM.CausalParameterized.Process as CausalP
import qualified Synthesizer.LLVM.Simple.Value as Value
import qualified Synthesizer.LLVM.Parameter as Param
import qualified Synthesizer.LLVM.Frame.SerialVector as Serial
import qualified Synthesizer.LLVM.CausalParameterized.Functional as F

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Class (MakeValueTuple, ValueTuple, Undefined, undefTuple, )

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, IsArithmetic, IsPrimitive, IsFloating, IsSized, SizeOf,
    CodeGenFunction, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal.Number ((:*:), )

import Foreign.Storable (Storable, )
import qualified Foreign.Storable
-- import qualified Foreign.Storable.Record as Store
import qualified Foreign.Storable.Traversable as Store

import qualified Control.Applicative as App
import qualified Data.Foldable as Fold
import qualified Data.Traversable as Trav
import Control.Applicative (liftA2, (<*>), )
import Control.Arrow (arr, (^<<), (&&&), )
import Control.Monad (liftM2, )

import qualified Algebra.Transcendental as Trans
-- import qualified Algebra.Field as Field
-- import qualified Algebra.Ring as Ring

import NumericPrelude.Numeric
import NumericPrelude.Base


{- |
Per-step multiplier for the scalar exponential curve of 'causalP'.

'parameterPlain' fills it with @0.5 ** recip halfLife@,
so multiplying by it @halfLife@ times halves the curve.
-}
newtype Parameter a = Parameter a
   deriving (Show, Storable)


-- Map the function over the single wrapped value.
instance Functor Parameter where
   {-# INLINE fmap #-}
   fmap f p =
      case p of
         Parameter x -> Parameter (f x)

-- 'Parameter' behaves like an identity applicative:
-- 'pure' just wraps, and application unwraps both sides.
instance App.Applicative Parameter where
   {-# INLINE pure #-}
   pure = Parameter
   {-# INLINE (<*>) #-}
   (<*>) (Parameter g) (Parameter x) = Parameter $ g x

-- A 'Parameter' holds exactly one element, so folding maps just that element.
instance Fold.Foldable Parameter where
   {-# INLINE foldMap #-}
   foldMap f (Parameter x) = f x

-- Run the wrapped action and wrap its result again.
instance Trav.Traversable Parameter where
   {-# INLINE sequenceA #-}
   sequenceA = \(Parameter action) -> fmap Parameter action


-- Phi handling is delegated to the generic Foldable/Traversable-based
-- helpers of "LLVM.Extra.Class".
instance (Phi a) => Phi (Parameter a) where
   addPhis = Class.addPhisFoldable
   phis = Class.phisTraversable

-- Undefined 'Parameter', delegated to 'Class.undefTuplePointed'
-- (presumably 'App.pure' applied to the element's 'undefTuple' -- not verified here).
instance Undefined a => Undefined (Parameter a) where
   undefTuple = Class.undefTuplePointed

-- Zero 'Parameter', delegated to 'Class.zeroTuplePointed'
-- (presumably 'App.pure' applied to the element's zero -- not verified here).
instance Class.Zero a => Class.Zero (Parameter a) where
   zeroTuple = Class.zeroTuplePointed

-- A 'Parameter' uses the same memory struct as the value it wraps;
-- the newtype is only unwrapped before and rewrapped after memory access.
instance (Memory.C a) => Memory.C (Parameter a) where
   type Struct (Parameter a) = Memory.Struct a
   decompose = Memory.decomposeNewtype Parameter
   compose = Memory.composeNewtype (\(Parameter x) -> x)
   load = Memory.loadNewtype Parameter
   store = Memory.storeNewtype (\(Parameter x) -> x)


{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (Parameter a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (Parameter a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- The LLVM value representation of @Parameter a@ is a 'Parameter' holding
-- the value representation of @a@; conversion goes through
-- 'Class.valueTupleOfFunctor' (presumably via the Functor instance).
instance (MakeValueTuple a) => MakeValueTuple (Parameter a) where
   type ValueTuple (Parameter a) = Parameter (Class.ValueTuple a)
   valueTupleOf = Class.valueTupleOfFunctor


-- Flattening and unfolding go field-wise through the Traversable instance.
instance (Value.Flatten a) => Value.Flatten (Parameter a) where
   type Registers (Parameter a) = Parameter (Value.Registers a)
   unfoldCode = Value.unfoldCodeTraversable
   flattenCode = Value.flattenCodeTraversable


-- A vector of parameters: same size as the wrapped vector, elements wrapped
-- in 'Parameter'; element access uses the Traversable-based helpers.
instance (Vector.Simple v) => Vector.Simple (Parameter v) where
   type Size (Parameter v) = Vector.Size v
   type Element (Parameter v) = Parameter (Vector.Element v)
   extract = Vector.extractTraversable
   shuffleMatch = Vector.shuffleMatchTraversable

-- Element insertion delegates to the generic Traversable-based helper.
instance (Vector.C v) => Vector.C (Parameter v) where
   insert  = Vector.insertTraversable


{- |
Code-generating counterpart of 'parameterPlain':
turns a half-life given as an LLVM value into the per-step 'Parameter'.
-}
parameter ::
   (Trans.C a, SoV.TranscendentalConstant a, IsFloating a) =>
   Value a ->
   CodeGenFunction r (Parameter (Value a))
parameter halfLife = Value.unlift1 parameterPlain halfLife

{- |
Per-step multiplier @0.5 ** recip halfLife@:
applying it @halfLife@ times halves the curve.
-}
parameterPlain ::
   (Trans.C a) =>
   a -> Parameter a
parameterPlain halfLife = Parameter decay
   where decay = 0.5 ** recip halfLife


{- |
Scalar exponential curve starting at @initial@.

The loop state is the current curve value.
Each step emits that state and multiplies it by the incoming 'Parameter'
to obtain the next state.
This presumes 'CausalP.loop' threads the state as the second pair component,
as the usage here suggests.
-}
causalP ::
   (Storable a, MakeValueTuple a, ValueTuple a ~ al,
    Memory.C al, A.PseudoRing al) =>
   Param.T p a ->
   CausalP.T p (Parameter al) al
causalP initial =
   CausalP.loop initial $
      arr snd &&& CausalP.zipWithSimple (\(Parameter k) x -> A.mul k x)


{- |
Parameters of the vectorised exponential curve of 'causalPackedP'.

For vector size @n@, 'parameterPackedPlain' sets

* 'ppFeedback' to @n@ copies of @0.5 ** (n / halfLife)@, and
* 'ppCurrent' to an iteration of @(k *)@ starting at @1@,
  with @k = 0.5 ** recip halfLife@.

Like 'Parameter', it derives 'Show', so plain (non-LLVM) parameter values
can be inspected.
-}
data ParameterPacked a =
   ParameterPacked
      { ppFeedback :: a
         -- ^ Multiplies the loop state to advance it by one whole vector.
      , ppCurrent :: a
         -- ^ Multiplies the loop state to produce the current output vector.
      }
   deriving (Show)


-- Apply the function to both fields.
-- The lazy pattern keeps 'fmap' non-strict in its argument,
-- just like accessing the fields through their selectors.
instance Functor ParameterPacked where
   {-# INLINE fmap #-}
   fmap f ~(ParameterPacked feedback current) =
      ParameterPacked (f feedback) (f current)

-- Field-wise applicative: 'pure' fills both fields with the same value,
-- application combines corresponding fields.
-- Lazy patterns keep '<*>' non-strict in both arguments.
instance App.Applicative ParameterPacked where
   {-# INLINE pure #-}
   pure x = ParameterPacked {ppFeedback = x, ppCurrent = x}
   {-# INLINE (<*>) #-}
   (<*>)
         ~(ParameterPacked gFeedback gCurrent)
         ~(ParameterPacked xFeedback xCurrent) =
      ParameterPacked (gFeedback xFeedback) (gCurrent xCurrent)

-- Folding is derived from the Traversable instance
-- ('ppFeedback' first, then 'ppCurrent').
instance Fold.Foldable ParameterPacked where
   {-# INLINE foldMap #-}
   foldMap f = Trav.foldMapDefault f

-- Sequence 'ppFeedback' before 'ppCurrent'.
-- The lazy pattern mirrors the original selector-based, non-strict access.
instance Trav.Traversable ParameterPacked where
   {-# INLINE sequenceA #-}
   sequenceA ~(ParameterPacked feedback current) =
      liftA2 ParameterPacked feedback current


-- Phi handling is delegated to the generic Foldable/Traversable-based
-- helpers of "LLVM.Extra.Class".
instance (Phi a) => Phi (ParameterPacked a) where
   addPhis = Class.addPhisFoldable
   phis = Class.phisTraversable

-- Undefined 'ParameterPacked', delegated to 'Class.undefTuplePointed'
-- (presumably 'App.pure' applied to the element's 'undefTuple' -- not verified here).
instance Undefined a => Undefined (ParameterPacked a) where
   undefTuple = Class.undefTuplePointed

-- Zero 'ParameterPacked', delegated to 'Class.zeroTuplePointed'
-- (presumably 'App.pure' applied to the element's zero -- not verified here).
instance Class.Zero a => Class.Zero (ParameterPacked a) where
   zeroTuple = Class.zeroTuplePointed


{-
storeParameter ::
   Storable a => Store.Dictionary (ParameterPacked a)
storeParameter =
   Store.run $
   liftA2 ParameterPacked
      (Store.element ppFeedback)
      (Store.element ppCurrent)

instance Storable a => Storable (ParameterPacked a) where
   sizeOf    = Store.sizeOf storeParameter
   alignment = Store.alignment storeParameter
   peek      = Store.peek storeParameter
   poke      = Store.poke storeParameter
-}

-- Host-side storage delegates to the generic helpers of
-- "Foreign.Storable.Traversable", driven by the Foldable/Traversable instances.
instance Storable a => Storable (ParameterPacked a) where
   poke ptr = Store.poke ptr
   peek ptr = Store.peekApplicative ptr
   alignment = Store.alignment
   sizeOf = Store.sizeOf


-- | LLVM struct layout of 'ParameterPacked':
-- element 0 holds 'ppFeedback', element 1 holds 'ppCurrent' (see 'memory').
type ParameterPackedStruct a = LLVM.Struct (a, (a, ()))

{- |
Field-wise memory description of 'ParameterPacked':
'ppFeedback' lives in struct element 0, 'ppCurrent' in element 1.
-}
memory ::
   (Memory.C a) =>
   Memory.Record r (ParameterPackedStruct (Memory.Struct a)) (ParameterPacked a)
memory =
   fmap ParameterPacked (Memory.element ppFeedback TypeNum.d0)
      <*> Memory.element ppCurrent TypeNum.d1

-- All memory operations go through the record description 'memory'.
instance (Memory.C a) => Memory.C (ParameterPacked a) where
   type Struct (ParameterPacked a) = ParameterPackedStruct (Memory.Struct a)
   decompose = Memory.decomposeRecord memory
   compose = Memory.composeRecord memory
   load = Memory.loadRecord memory
   store = Memory.storeRecord memory


{-
instance LLVM.ValueTuple a => LLVM.ValueTuple (ParameterPacked a) where
   buildTuple f = Class.buildTupleTraversable (LLVM.buildTuple f)

instance LLVM.IsTuple a => LLVM.IsTuple (ParameterPacked a) where
   tupleDesc = Class.tupleDescFoldable
-}

-- The LLVM value representation of @ParameterPacked a@ is a 'ParameterPacked'
-- of the value representation of @a@; conversion goes through
-- 'Class.valueTupleOfFunctor' (presumably via the Functor instance).
instance (MakeValueTuple a) => MakeValueTuple (ParameterPacked a) where
   type ValueTuple (ParameterPacked a) = ParameterPacked (Class.ValueTuple a)
   valueTupleOf = Class.valueTupleOfFunctor


-- Flattening and unfolding go field-wise through the Traversable instance.
instance (Value.Flatten a) => Value.Flatten (ParameterPacked a) where
   type Registers (ParameterPacked a) = ParameterPacked (Value.Registers a)
   unfoldCode = Value.unfoldCodeTraversable
   flattenCode = Value.flattenCodeTraversable

-- In the functional interface of "Synthesizer.LLVM.CausalParameterized.Functional",
-- a 'ParameterPacked' is passed as one argument and is not split into its fields;
-- 'F.makeArgs' therefore returns it unchanged.
type instance F.Arguments f (ParameterPacked a) = f (ParameterPacked a)
instance F.MakeArguments (ParameterPacked a) where
   makeArgs = id



{- |
Hand the type-level vector size @n@ of the result to the continuation
as a singleton value, so that it can be used at the value level.
-}
withSize ::
   (TypeNum.Natural n) =>
   (Serial.C v, Serial.Size v ~ n, TypeNum.Positive n) =>
   (TypeNum.Singleton n -> m (param v)) ->
   m (param v)
withSize = ($ TypeNum.singleton)

{- |
Code-generating counterpart of 'parameterPackedPlain'.
For the vector size @n@ of @v@ it computes

* 'ppFeedback': @0.5 ** (n / halfLife)@, spread over the lanes
  with 'Serial.upsample',
* 'ppCurrent': an iteration of @(k *)@ starting at @1@
  with @k = 0.5 ** (1 / halfLife)@, built with 'Serial.iterate'.
-}
parameterPacked ::
   (Serial.C v, Serial.Element v ~ a,
    A.PseudoRing v, A.RationalConstant v,
    A.Transcendental a, A.RationalConstant a) =>
   a -> CodeGenFunction r (ParameterPacked v)
parameterPacked halfLife = withSize $ \n -> do
   let half = A.fromRational' 0.5
       blockLength = A.fromInteger' $ TypeNum.integralFromSingleton n
   blockExponent <- A.fdiv blockLength halfLife
   blockDecay <- A.pow half blockExponent
   feedback <- Serial.upsample blockDecay
   stepExponent <- A.fdiv (A.fromInteger' 1) halfLife
   stepDecay <- A.pow half stepExponent
   current <- Serial.iterate (A.mul stepDecay) (A.fromInteger' 1)
   return $ ParameterPacked feedback current
{-
   Alternative formulation via the plain function, kept for reference:
   Value.unlift1 parameterPackedPlain
-}

{- |
Hand the type-level vector size @n@ of the plain result to the continuation
as a singleton value.
-}
withSizePlain ::
   (TypeNum.Natural n) =>
   (TypeNum.Singleton n -> param (Serial.Plain n a)) ->
   param (Serial.Plain n a)
withSizePlain = ($ TypeNum.singleton)

{- |
'ParameterPacked' for a curve that halves every @halfLife@ samples,
for vectors of size @n@:
'ppFeedback' is @0.5 ** (n / halfLife)@ replicated,
'ppCurrent' iterates @(k *)@ from 'one' with @k = 0.5 ** recip halfLife@.
-}
parameterPackedPlain ::
   (Trans.C a,
    TypeNum.Positive n) =>
   a -> ParameterPacked (Serial.Plain n a)
parameterPackedPlain halfLife =
   withSizePlain $ \n ->
      let stepDecay = 0.5 ** recip halfLife
          blockDecay =
             0.5 ** (fromInteger (TypeNum.integralFromSingleton n) / halfLife)
      in  ParameterPacked
             (Serial.replicate blockDecay)
             (Serial.iteratePlain (stepDecay *) one)


{- |
Vectorised counterpart of 'causalP'.

The loop state is a vector, initially @initial@ replicated to all lanes.
Each step outputs @ppCurrent * state@ and continues with
@ppFeedback * state@ as the next state.
-}
causalPackedP ::
   (IsArithmetic a, SoV.IntegerConstant a,
    Storable a, MakeValueTuple a, ValueTuple a ~ (Value a),
    Memory.FirstClass a, Memory.Stored a ~ am, IsSized am,
    IsPrimitive a, IsPrimitive am,
    TypeNum.Positive (n :*: SizeOf a),
    TypeNum.Positive (n :*: SizeOf am),
    TypeNum.Positive n) =>
   Param.T p a ->
   CausalP.T p (ParameterPacked (Serial.Value n a)) (Serial.Value n a)
causalPackedP initial =
   CausalP.loop (Serial.replicate ^<< initial) $
   CausalP.mapSimple $ \(params, state) -> do
      output <- A.mul (ppCurrent params) state
      next <- A.mul (ppFeedback params) state
      return (output, next)