{-# LANGUAGE TypeFamilies #-}
{- |
Turn operations on symbolic arrays into executable operations
on physical arrays.
-}
module Data.Array.Knead.Symbolic.Render (
   run,
   MarshalExp(..),
   MapFilter(..),
   FilterOuter(..),
   Scatter(..),
   ScatterMaybe(..),
   MapAccumLSimple(..),
   MapAccumLSequence(..),
   MapAccumL(..),
   FoldOuterL(..),
   AddDimension(..),
   ) where

import qualified Data.Array.Knead.Symbolic.Render.Basic as Render
import qualified Data.Array.Knead.Symbolic.Render.Argument as Arg
import qualified Data.Array.Knead.Symbolic.PhysicalParametric as PhysP
import qualified Data.Array.Knead.Symbolic.Physical as Phys
import qualified Data.Array.Knead.Symbolic.Private as Core
import qualified Data.Array.Knead.Shape as Shape
import Data.Array.Knead.Symbolic.PhysicalParametric
         (MapFilter, FilterOuter,
          MapAccumLSimple, MapAccumLSequence, MapAccumL, FoldOuterL,
          Scatter, ScatterMaybe, AddDimension)

import qualified LLVM.DSL.Render.Run as Run
import LLVM.DSL.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue

import Prelude2010
import Prelude ()



-- | Values that 'run' knows how to render.
--
-- An instance explains how a symbolic value of type @f@ is turned into
-- its executable counterpart of type @'Plain' f@.
class C f where
   -- | The plain (rendered) counterpart of @f@.
   type Plain f
   -- | Renderer from @f@ to @'Plain' f@. It is polymorphic in the
   -- parameter type @p@, which only needs a 'Marshal.C' instance.
   build :: Marshal.C p => Run.T IO p f (Plain f)

-- | A symbolic array renders to an action that produces a physical array
-- with the same shape type @sh@ and element type @a@ (via 'PhysP.render').
instance
   (Storable.C a,
    Shape.C sh, Marshal.C sh) =>
      C (Core.Array sh a) where
   type Plain (Core.Array sh a) = IO (Phys.Array sh a)
   build = Run.Cons PhysP.render

-- | A 'MapFilter' renders to an action that produces a physical array
-- indexed by the sequence shape @n@ with elements of type @b@
-- (via 'PhysP.mapFilter').
instance
   (Storable.C b, MultiValue.C b,
    Shape.Sequence n, Marshal.C n) =>
      C (MapFilter n a b) where
   type Plain (MapFilter n a b) = IO (Phys.Array n b)
   build = Run.Cons PhysP.mapFilter

-- | A 'FilterOuter' renders to an action that produces a physical array
-- whose shape pairs the sequence shape @n@ with the inner shape @sh@
-- (via 'PhysP.filterOuter').
instance
   (Storable.C a, MultiValue.C a,
    Shape.C sh, Marshal.C sh,
    Shape.Sequence n, Marshal.C n) =>
      C (FilterOuter n sh a) where
   type Plain (FilterOuter n sh a) = IO (Phys.Array (n,sh) a)
   build = Run.Cons PhysP.filterOuter

-- | A 'Scatter' renders to an action that produces a physical array
-- of the target shape @sh1@ (via 'PhysP.scatter').
instance
   (Storable.C a, MultiValue.C a,
    Shape.C sh1, Marshal.C sh1,
    Shape.C sh0, Marshal.C sh0) =>
      C (Scatter sh0 sh1 a) where
   type Plain (Scatter sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = Run.Cons PhysP.scatter

-- | A 'ScatterMaybe' renders to an action that produces a physical array
-- of the target shape @sh1@ (via 'PhysP.scatterMaybe').
instance
   (Storable.C a, MultiValue.C a,
    Shape.C sh1, Marshal.C sh1,
    Shape.C sh0, Marshal.C sh0) =>
      C (ScatterMaybe sh0 sh1 a) where
   type Plain (ScatterMaybe sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = Run.Cons PhysP.scatterMaybe

-- | A 'MapAccumLSimple' renders to an action that produces a physical
-- array of shape @(sh,n)@ with elements of type @b@. The accumulator
-- type @acc@ only needs 'MultiValue.C' because it does not appear in the
-- result (via 'PhysP.mapAccumLSimple').
instance
   (MultiValue.C acc,
    Storable.C b, MultiValue.C b,
    Storable.C a, MultiValue.C a,
    Shape.C n, Marshal.C n,
    Shape.C sh, Marshal.C sh) =>
      C (MapAccumLSimple sh n acc a b) where
   type Plain (MapAccumLSimple sh n acc a b) = IO (Phys.Array (sh,n) b)
   build = Run.Cons PhysP.mapAccumLSimple

-- | A 'MapAccumLSequence' renders to an action that yields a value of type
-- @final@ together with a physical array of shape @n@
-- (via 'PhysP.mapAccumLSequence').
instance
   (MultiValue.C acc,
    Storable.C b, MultiValue.C b,
    Storable.C a, MultiValue.C a,
    Storable.C final, MultiValue.C final,
    Shape.C n, Marshal.C n) =>
      C (MapAccumLSequence n acc final a b) where
   type Plain (MapAccumLSequence n acc final a b) = IO (final, Phys.Array n b)
   build = Run.Cons PhysP.mapAccumLSequence

-- | A 'MapAccumL' renders to an action that yields two physical arrays:
-- one of shape @sh@ holding @final@ values, and one of shape @(sh,n)@
-- holding @b@ values (via 'PhysP.mapAccumL').
instance
   (MultiValue.C acc,
    Storable.C b, MultiValue.C b,
    Storable.C a, MultiValue.C a,
    Storable.C final, MultiValue.C final,
    Shape.C n, Marshal.C n,
    Shape.C sh, Marshal.C sh) =>
      C (MapAccumL sh n acc final a b) where
   type Plain (MapAccumL sh n acc final a b) =
            IO (Phys.Array sh final, Phys.Array (sh,n) b)
   build = Run.Cons PhysP.mapAccumL

-- | A 'FoldOuterL' renders to an action that produces a physical array
-- of shape @sh@ with elements of type @a@ (via 'PhysP.foldOuterL').
instance
   (Storable.C b, MultiValue.C b,
    Storable.C a, MultiValue.C a,
    Shape.C sh, Marshal.C sh,
    Shape.C n, Marshal.C n) =>
      C (FoldOuterL n sh a b) where
   type Plain (FoldOuterL n sh a b) = IO (Phys.Array sh a)
   build = Run.Cons PhysP.foldOuterL

-- | An 'AddDimension' renders to an action that produces a physical array
-- whose shape @(sh,n)@ extends @sh@ by @n@ (via 'PhysP.addDimension').
instance
   (Storable.C b, MultiValue.C b,
    Shape.C n, Marshal.C n,
    Shape.C sh, Marshal.C sh) =>
      C (AddDimension sh n a b) where
   type Plain (AddDimension sh n a b) = IO (Phys.Array (sh,n) b)
   build = Run.Cons PhysP.addDimension


-- | A scalar expression renders to an action that yields its value.
-- The result is transferred through 'Render.storable', which is why
-- 'Storable.C' is required; see 'MarshalExp' for the 'Marshal.C' variant.
instance
   (MultiValue.C a, Storable.C a) =>
      C (Exp a) where
   type Plain (Exp a) = IO a
   build = Render.storable

-- | Wrapper around a scalar 'Exp' that makes 'run' transfer the result
-- through 'Render.marshal' (requiring 'Marshal.C') instead of
-- 'Render.storable' (requiring 'Storable.C').
newtype MarshalExp a =
   MarshalExp {
      getMarshalExp :: Exp a  -- ^ the wrapped expression
   }

-- | A wrapped expression renders to an action that yields its value.
-- The wrapper is stripped before delegating to 'Render.marshal'.
instance Marshal.C a => C (MarshalExp a) where
   type Plain (MarshalExp a) = IO a
   build = Run.premapDSL (\(MarshalExp expr) -> expr) Render.marshal

-- | A function renders to a plain function. Its argument is converted via
-- 'Argument' (@'PlainArg' arg@), and its result is rendered recursively
-- via 'C'.
instance (C func, Argument arg) => C (arg -> func) where
   type Plain (arg -> func) = PlainArg arg -> Plain func
   build = (Render.*->) buildArg build


-- | Types that may appear as parameters of a function handed to 'run'.
class Argument a where
   -- | The Haskell value the caller supplies for a parameter of type @a@.
   type PlainArg a
   -- | Describes how a @'PlainArg' a@ is passed in as an @a@.
   buildArg :: Arg.T (PlainArg a) a

-- | The unit argument carries no data.
instance Argument () where
   type PlainArg () = ()
   buildArg = Arg.unit

-- | A symbolic array parameter is supplied as a physical array with the
-- same shape and element types.
instance
   (Storable.C a,
    Shape.C sh, Marshal.C sh) =>
      Argument (Core.Array sh a) where
   type PlainArg (Core.Array sh a) = Phys.Array sh a
   buildArg = Arg.array

-- | A scalar expression parameter is supplied as a plain value of type @a@.
instance Marshal.C a => Argument (Exp a) where
   type PlainArg (Exp a) = a
   buildArg = Arg.primitive

-- | A pair of parameters is supplied as a pair of their plain values.
instance (Argument b, Argument a) => Argument (a,b) where
   type PlainArg (a,b) = (PlainArg a, PlainArg b)
   buildArg = Arg.pair buildArg buildArg

-- | A triple of parameters is supplied as a triple of their plain values.
instance
   (Argument c, Argument b, Argument a) =>
      Argument (a,b,c) where
   type PlainArg (a,b,c) = (PlainArg a, PlainArg b, PlainArg c)
   buildArg = Arg.triple buildArg buildArg buildArg



-- | Render a symbolic value or function into its executable counterpart.
--
-- The returned action yields @'Plain' f@. Symbolic arrays become actions
-- that produce physical arrays, and function parameters (see 'Argument')
-- become ordinary Haskell values of the corresponding plain type.
run :: (C f) => f -> IO (Plain f)
run symbolic = Render.run build symbolic