{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.IO.Codec.Serialise
-- Copyright   : [2012..2021] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Orphan instance for binary serialisation of 'Array'
--

module Data.Array.Accelerate.IO.Codec.Serialise ()
  where

import Codec.Serialise
import Codec.Serialise.Decoding
import Codec.Serialise.Encoding

import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Representation.Shape                   ( ShapeR(..) )
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Type
import qualified Data.Array.Accelerate.Representation.Array         as R
import qualified Data.Array.Accelerate.Representation.Shape         as R

import Data.Array.Accelerate.IO.Data.ByteString


-- | Binary serialisation of Accelerate arrays via the @serialise@ package.
--
-- An array is written as a single list whose length is the rank of the
-- array plus the number of scalar payload arrays in its struct-of-arrays
-- representation (see 'encodedFieldCount'). The list contains, in order:
--
--   1. one 'Int' per dimension of the shape, emitted by walking the nested
--      pair representation of the shape from left to right; then
--
--   2. one byte string per scalar component of the element type, in the
--      order the components appear in the 'TypeR' of the element.
--
-- Decoding reads the same fields back in the same order and rebuilds the
-- array with 'fromByteStrings'. The list length is checked against the
-- length expected for the requested @sh@ and @e@, so decoding fails at the
-- header if the payload was written for a different rank or element
-- structure.
--
instance (Shape sh, Elt e) => Serialise (Array sh e) where
  {-# INLINE encode #-}
  encode arr@(Array arrR)
    =  encodeListLen (fromIntegral (encodedFieldCount (shapeR @sh) (eltR @e)))
    <> encodeShapeR (shapeR @sh) (R.shape arrR)
    <> encodeArrayR (eltR @e) (toByteStrings arr)
    where
      -- Emit every extent of the shape as an 'Int'.
      encodeShapeR :: ShapeR t -> t -> Encoding
      encodeShapeR ShapeRz          ()       = mempty
      encodeShapeR (ShapeRsnoc shR) (sh, sz) = encodeShapeR shR sh <> encodeInt sz

      -- Emit one byte string per scalar component of the element type.
      encodeArrayR :: TypeR t -> ByteStrings t -> Encoding
      encodeArrayR TupRunit           ()       = mempty
      encodeArrayR (TupRpair aR1 aR2) (a1, a2) = encodeArrayR aR1 a1 <> encodeArrayR aR2 a2
      encodeArrayR (TupRsingle aR)    a        = scalar aR a
        where
          -- Each helper below pattern matches on the type witness only to
          -- refine @ByteStrings t@ to a concrete 'ByteString'; every leaf
          -- ends in 'encodeBytes'.
          scalar :: ScalarType t -> ByteStrings t -> Encoding
          scalar (SingleScalarType t) = single t
          scalar (VectorScalarType t) = vector t

          -- A SIMD vector is handled exactly like its component type.
          vector :: VectorType t -> ByteStrings t -> Encoding
          vector (VectorType _ t)
            | SingleArrayDict <- singleArrayDict t
            = single t

          single :: SingleType t -> ByteStrings t -> Encoding
          single (NumSingleType t) = num t

          num :: NumType t -> ByteStrings t -> Encoding
          num (IntegralNumType t) = integral t
          num (FloatingNumType t) = floating t

          integral :: IntegralType t -> ByteStrings t -> Encoding
          integral TypeInt    = encodeBytes
          integral TypeInt8   = encodeBytes
          integral TypeInt16  = encodeBytes
          integral TypeInt32  = encodeBytes
          integral TypeInt64  = encodeBytes
          integral TypeWord   = encodeBytes
          integral TypeWord8  = encodeBytes
          integral TypeWord16 = encodeBytes
          integral TypeWord32 = encodeBytes
          integral TypeWord64 = encodeBytes

          floating :: FloatingType t -> ByteStrings t -> Encoding
          floating TypeHalf   = encodeBytes
          floating TypeFloat  = encodeBytes
          floating TypeDouble = encodeBytes

  {-# INLINE decode #-}
  decode = do
    -- Reject payloads whose list header does not match what 'encode'
    -- writes for this rank and element type.
    decodeListLenOf (encodedFieldCount (shapeR @sh) (eltR @e))
    sh <- decodeShapeR (shapeR @sh)
    bs <- decodeArrayR (eltR @e)
    -- Force the rebuilt array to weak head normal form before returning.
    return $! fromByteStrings (toElt sh) bs
    where
      -- Read back one 'Int' per dimension, mirroring 'encodeShapeR'.
      decodeShapeR :: ShapeR t -> Decoder s t
      decodeShapeR ShapeRz          = return ()
      decodeShapeR (ShapeRsnoc shR) = do
        sh <- decodeShapeR shR
        sz <- decodeInt
        return (sh, sz)

      -- Read back one byte string per scalar component, mirroring
      -- 'encodeArrayR'.
      decodeArrayR :: TypeR t -> Decoder s (ByteStrings t)
      decodeArrayR TupRunit           = return ()
      decodeArrayR (TupRpair aR1 aR2) = do
        a1 <- decodeArrayR aR1
        a2 <- decodeArrayR aR2
        return (a1, a2)
      decodeArrayR (TupRsingle aR) = scalar aR
        where
          -- As in 'encode', the matches below exist only to refine
          -- @ByteStrings t@; every leaf ends in 'decodeBytes'.
          scalar :: ScalarType t -> Decoder s (ByteStrings t)
          scalar (SingleScalarType t) = single t
          scalar (VectorScalarType t) = vector t

          vector :: VectorType t -> Decoder s (ByteStrings t)
          vector (VectorType _ t)
            | SingleArrayDict <- singleArrayDict t
            = single t

          single :: SingleType t -> Decoder s (ByteStrings t)
          single (NumSingleType t) = num t

          num :: NumType t -> Decoder s (ByteStrings t)
          num (IntegralNumType t) = integral t
          num (FloatingNumType t) = floating t

          integral :: IntegralType t -> Decoder s (ByteStrings t)
          integral TypeInt    = decodeBytes
          integral TypeInt8   = decodeBytes
          integral TypeInt16  = decodeBytes
          integral TypeInt32  = decodeBytes
          integral TypeInt64  = decodeBytes
          integral TypeWord   = decodeBytes
          integral TypeWord8  = decodeBytes
          integral TypeWord16 = decodeBytes
          integral TypeWord32 = decodeBytes
          integral TypeWord64 = decodeBytes

          floating :: FloatingType t -> Decoder s (ByteStrings t)
          floating TypeHalf   = decodeBytes
          floating TypeFloat  = decodeBytes
          floating TypeDouble = decodeBytes

-- | Number of items in the list written for an array of the given shape and
-- element representation: one per dimension ('R.rank') plus one per scalar
-- component of the element type (unit contributes none, a pair contributes
-- the sum of its halves).
--
-- Shared by 'encode' and 'decode' so that the length written and the length
-- expected cannot drift apart.
--
encodedFieldCount :: ShapeR sh -> TypeR e -> Int
encodedFieldCount shR eR = R.rank shR + payloads eR
  where
    payloads :: TypeR t -> Int
    payloads TupRunit           = 0
    payloads TupRsingle{}       = 1
    payloads (TupRpair eR1 eR2) = payloads eR1 + payloads eR2