{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Knead.Shape.Orphan where

import qualified Data.Array.Knead.Expression as Expr

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape
         (ZeroBased(ZeroBased), Range(Range), Shifted(Shifted),
          Enumeration(Enumeration))

import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))

import Prelude2010
import Prelude ()



-- | Push the 'MultiValue.T' wrapper inward: a multi-value of a zero-based
-- shape becomes a zero-based shape whose size is a multi-value.
unzipZeroBased :: MultiValue.T (ZeroBased n) -> ZeroBased (MultiValue.T n)
unzipZeroBased (MultiValue.Cons sh) = fmap MultiValue.Cons sh

-- | Project the size out of a zero-based shape.  The projection is lifted
-- with 'Expr.lift1', so it works for any 'Expr.Value' instance.
zeroBasedSize :: (Expr.Value val) => val (ZeroBased n) -> val n
zeroBasedSize =
   Expr.lift1 (\sh -> Shape.zeroBasedSize (unzipZeroBased sh))

-- | Wrap a size into a zero-based shape, lifted with 'Expr.lift1' so it
-- works for any 'Expr.Value' instance.  Counterpart of 'zeroBasedSize'.
zeroBased :: (Expr.Value val) => val n -> val (ZeroBased n)
zeroBased =
   Expr.lift1 $ \size ->
      case size of
         MultiValue.Cons n -> MultiValue.Cons (ZeroBased n)

-- A zero-based shape is represented by the representation of its size;
-- every method forwards to the size and re-wraps the result.
instance (MultiValue.C n) => MultiValue.C (ZeroBased n) where
   type Repr f (ZeroBased n) = ZeroBased (MultiValue.Repr f n)
   cons sh = zeroBased (MultiValue.cons (Shape.zeroBasedSize sh))
   undef = zeroBased MultiValue.undef
   zero = zeroBased MultiValue.zero
   phis bb x = fmap zeroBased (MultiValue.phis bb (zeroBasedSize x))
   addPhis bb x y =
      MultiValue.addPhis bb (zeroBasedSize x) (zeroBasedSize y)

-- Decomposing a @ZeroBased pn@ pattern yields a 'ZeroBased' wrapped
-- around the decomposition of the size pattern @pn@.
type instance
   MultiValue.Decomposed f (ZeroBased pn) =
      ZeroBased (MultiValue.Decomposed f pn)
-- The pattern tuple of @ZeroBased pn@ is a 'ZeroBased' wrapped around
-- the pattern tuple of @pn@.
type instance
   MultiValue.PatternTuple (ZeroBased pn) =
      ZeroBased (MultiValue.PatternTuple pn)

-- Compose a zero-based shape by composing its size and re-wrapping it.
instance (MultiValue.Compose n) => MultiValue.Compose (ZeroBased n) where
   type Composed (ZeroBased n) = ZeroBased (MultiValue.Composed n)
   compose = zeroBased . MultiValue.compose . Shape.zeroBasedSize

-- Decompose a zero-based shape by decomposing its size against the size
-- pattern, keeping the 'ZeroBased' wrapper.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (ZeroBased pn) where
   decompose pat =
      fmap (MultiValue.decompose (Shape.zeroBasedSize pat)) . unzipZeroBased

-- Compose a zero-based shape of expressions by composing the size
-- expression and lifting 'zeroBased' over the result.
instance (Expr.Compose n) => Expr.Compose (ZeroBased n) where
   type Composed (ZeroBased n) = ZeroBased (Expr.Composed n)
   compose = Expr.lift1 zeroBased . Expr.compose . Shape.zeroBasedSize

-- Decompose a zero-based shape expression by extracting its size and
-- decomposing that against the size pattern.
instance (Expr.Decompose pn) => Expr.Decompose (ZeroBased pn) where
   decompose (ZeroBased p) sh =
      ZeroBased (Expr.decompose p (zeroBasedSize sh))

-- The memory layout of a zero-based shape is exactly that of its size.
instance (MultiMem.C n) => MultiMem.C (ZeroBased n) where
   type Struct (ZeroBased n) = MultiMem.Struct n
   decompose s = zeroBased <$> MultiMem.decompose s
   compose sh = MultiMem.compose (zeroBasedSize sh)



-- | A 'Range' whose lower and upper bounds are both the given value.
singletonRange :: n -> Range n
singletonRange bound = Range bound bound

-- | Push the 'MultiValue.T' wrapper inward onto both bounds of a 'Range'.
unzipRange :: MultiValue.T (Range n) -> Range (MultiValue.T n)
unzipRange (MultiValue.Cons rng) =
   case rng of
      Range lower upper ->
         Range (MultiValue.Cons lower) (MultiValue.Cons upper)

-- | Combine two multi-value bounds (lower, then upper) into a multi-value
-- 'Range'; the inverse direction of 'unzipRange'.
zipRange :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Range n)
zipRange lower upper =
   case (lower, upper) of
      (MultiValue.Cons l, MultiValue.Cons u) -> MultiValue.Cons (Range l u)

-- A 'Range' is represented as a 'Range' of bound representations; phi
-- handling is done per bound, lower bound first.
instance (MultiValue.C n) => MultiValue.C (Range n) where
   type Repr f (Range n) = Range (MultiValue.Repr f n)
   cons rng =
      case rng of
         Range from to ->
            zipRange (MultiValue.cons from) (MultiValue.cons to)
   undef = MultiValue.compose (singletonRange MultiValue.undef)
   zero = MultiValue.compose (singletonRange MultiValue.zero)
   phis bb rng =
      case unzipRange rng of
         Range lower upper -> do
            lowerPhi <- MultiValue.phis bb lower
            upperPhi <- MultiValue.phis bb upper
            return (zipRange lowerPhi upperPhi)
   addPhis bb x y =
      case (unzipRange x, unzipRange y) of
         (Range xLower xUpper, Range yLower yUpper) -> do
            MultiValue.addPhis bb xLower yLower
            MultiValue.addPhis bb xUpper yUpper

-- Decomposing a @Range pn@ pattern yields a 'Range' of the decomposed
-- bound patterns.
type instance
   MultiValue.Decomposed f (Range pn) = Range (MultiValue.Decomposed f pn)
-- The pattern tuple of @Range pn@ is a 'Range' of the bound pattern tuples.
type instance
   MultiValue.PatternTuple (Range pn) = Range (MultiValue.PatternTuple pn)

-- Compose each bound separately, then zip them into a multi-value 'Range'.
instance (MultiValue.Compose n) => MultiValue.Compose (Range n) where
   type Composed (Range n) = Range (MultiValue.Composed n)
   compose rng =
      case rng of
         Range from to ->
            zipRange (MultiValue.compose from) (MultiValue.compose to)

-- Decompose each bound of a multi-value 'Range' against the matching
-- bound of the pattern.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Range pn) where
   decompose pat rng =
      case (pat, unzipRange rng) of
         (Range pLower pUpper, Range lower upper) ->
            Range
               (MultiValue.decompose pLower lower)
               (MultiValue.decompose pUpper upper)

-- In memory a 'Range' is a two-field struct: field 0 holds the lower
-- bound, field 1 the upper bound.
instance (MultiMem.C n) => MultiMem.C (Range n) where
   type Struct (Range n) = PairStruct n
   decompose s = uncurry zipRange <$> decomposeGen s
   compose rng =
      case unzipRange rng of
         Range lower upper -> composeGen lower upper



-- | A 'Shifted' shape whose two components are both the given value.
singletonShifted :: n -> Shifted n
singletonShifted x = Shifted x x

-- | Push the 'MultiValue.T' wrapper inward onto both components (start
-- and length) of a 'Shifted' shape.
unzipShifted :: MultiValue.T (Shifted n) -> Shifted (MultiValue.T n)
unzipShifted (MultiValue.Cons sh) =
   case sh of
      Shifted start len ->
         Shifted (MultiValue.Cons start) (MultiValue.Cons len)

-- | Combine multi-value start and length components into a multi-value
-- 'Shifted' shape; the inverse direction of 'unzipShifted'.
zipShifted :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Shifted n)
zipShifted start len =
   case (start, len) of
      (MultiValue.Cons s, MultiValue.Cons l) -> MultiValue.Cons (Shifted s l)

-- A 'Shifted' shape is represented as a 'Shifted' of component
-- representations; phi handling is done per component, start first.
instance (MultiValue.C n) => MultiValue.C (Shifted n) where
   type Repr f (Shifted n) = Shifted (MultiValue.Repr f n)
   cons sh =
      case sh of
         Shifted start len ->
            zipShifted (MultiValue.cons start) (MultiValue.cons len)
   undef = MultiValue.compose (singletonShifted MultiValue.undef)
   zero = MultiValue.compose (singletonShifted MultiValue.zero)
   phis bb sh =
      case unzipShifted sh of
         Shifted start len -> do
            startPhi <- MultiValue.phis bb start
            lenPhi <- MultiValue.phis bb len
            return (zipShifted startPhi lenPhi)
   addPhis bb x y =
      case (unzipShifted x, unzipShifted y) of
         (Shifted xStart xLen, Shifted yStart yLen) -> do
            MultiValue.addPhis bb xStart yStart
            MultiValue.addPhis bb xLen yLen

-- Decomposing a @Shifted pn@ pattern yields a 'Shifted' of the decomposed
-- component patterns.
type instance
   MultiValue.Decomposed f (Shifted pn) =
      Shifted (MultiValue.Decomposed f pn)
-- The pattern tuple of @Shifted pn@ is a 'Shifted' of the component
-- pattern tuples.
type instance
   MultiValue.PatternTuple (Shifted pn) =
      Shifted (MultiValue.PatternTuple pn)

-- Compose start and length separately, then zip them into a multi-value
-- 'Shifted' shape.
instance (MultiValue.Compose n) => MultiValue.Compose (Shifted n) where
   type Composed (Shifted n) = Shifted (MultiValue.Composed n)
   compose sh =
      case sh of
         Shifted start len ->
            zipShifted (MultiValue.compose start) (MultiValue.compose len)

-- Decompose start and length of a multi-value 'Shifted' shape against the
-- matching components of the pattern.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Shifted pn) where
   decompose pat sh =
      case (pat, unzipShifted sh) of
         (Shifted pStart pLen, Shifted start len) ->
            Shifted
               (MultiValue.decompose pStart start)
               (MultiValue.decompose pLen len)

-- In memory a 'Shifted' shape is a two-field struct: field 0 holds the
-- start component, field 1 the length component.
instance (MultiMem.C n) => MultiMem.C (Shifted n) where
   type Struct (Shifted n) = PairStruct n
   decompose s = uncurry zipShifted <$> decomposeGen s
   compose sh =
      case unzipShifted sh of
         Shifted start len -> composeGen start len



-- | Memory layout used for both 'Range' and 'Shifted': an LLVM struct with
-- two fields, each of type @'MultiMem.Struct' n@.  The field list is spelled
-- as nested pairs terminated by @()@ (presumably the llvm-tf encoding of a
-- struct field list -- confirm against "LLVM.Core").
type PairStruct n = LLVM.Struct (MultiMem.Struct n, (MultiMem.Struct n, ()))

-- | Read both fields of a 'PairStruct' (field 0 first, then field 1) and
-- convert each from its memory representation into a multi-value.
decomposeGen ::
   (MultiMem.C n) =>
   LLVM.Value (PairStruct n) ->
   LLVM.CodeGenFunction r (MultiValue.T n, MultiValue.T n)
decomposeGen pair = do
   x0 <- MultiMem.decompose =<< LLVM.extractvalue pair TypeNum.d0
   x1 <- MultiMem.decompose =<< LLVM.extractvalue pair TypeNum.d1
   return (x0, x1)

-- | Build a 'PairStruct' from two multi-values.  Both are first converted to
-- their memory representation; then the first is inserted at field 0 and the
-- second at field 1 of an initially undefined struct.
composeGen ::
   (MultiMem.C n) =>
   MultiValue.T n -> MultiValue.T n ->
   LLVM.CodeGenFunction r (LLVM.Value (PairStruct n))
composeGen x0 x1 =
   MultiMem.compose x0 >>= \s0 ->
   MultiMem.compose x1 >>= \s1 ->
   LLVM.insertvalue (LLVM.value LLVM.undef) s0 TypeNum.d0 >>= \partial ->
   LLVM.insertvalue partial s1 TypeNum.d1



-- 'Enumeration' has unit representation, so every method delegates to the
-- corresponding unit helper of "LLVM.Extra.Multi.Value".
instance (Enum enum, Bounded enum) => MultiValue.C (Enumeration enum) where
   type Repr f (Enumeration enum) = ()
   cons e = MultiValue.consUnit e
   undef = MultiValue.undefUnit
   zero = MultiValue.zeroUnit
   phis bb x = MultiValue.phisUnit bb x
   addPhis bb x y = MultiValue.addPhisUnit bb x y

-- Both the decomposed form and the pattern tuple of an 'Enumeration' are the
-- 'Enumeration' type itself.
type instance MultiValue.Decomposed f (Enumeration enum) = Enumeration enum
type instance MultiValue.PatternTuple (Enumeration enum) = Enumeration enum

-- Composing an 'Enumeration' is just lifting it with 'MultiValue.cons'.
instance
      (Enum enum, Bounded enum) => MultiValue.Compose (Enumeration enum) where
   type Composed (Enumeration enum) = Enumeration enum
   compose e = MultiValue.cons e

-- Decomposition ignores the multi-value and returns the 'Enumeration'
-- constructor after matching the pattern.
instance MultiValue.Decompose (Enumeration enum) where
   decompose Enumeration = const Enumeration


-- Composing an 'Enumeration' expression is just lifting it with 'Expr.cons'.
instance (Enum enum, Bounded enum) => Expr.Compose (Enumeration enum) where
   type Composed (Enumeration enum) = Enumeration enum
   compose e = Expr.cons e

-- Decomposition ignores the expression and returns the 'Enumeration'
-- constructor after matching the pattern.
instance Expr.Decompose (Enumeration enum) where
   decompose Enumeration = const Enumeration

-- 'Enumeration' is stored as an empty LLVM struct; all memory operations
-- delegate to the unit helpers of "LLVM.Extra.Multi.Value.Memory".
instance (Enum enum, Bounded enum) => MultiMem.C (Enumeration enum) where
   type Struct (Enumeration enum) = LLVM.Struct ()
   load ptr = MultiMem.loadUnit ptr
   store x ptr = MultiMem.storeUnit x ptr
   decompose s = MultiMem.decomposeUnit s
   compose x = MultiMem.composeUnit x