--------------------------------------------------------------------------------
-- Copyright © 2011 National Institute of Aerospace / Galois, Inc.
--------------------------------------------------------------------------------

{-# LANGUAGE Safe #-}

-- | Implementation of an array that uses type literals to store its length.
-- Elements are kept in a plain list, so no explicit indexing is used for the
-- input data. Supports arbitrary nesting of arrays.

{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}

module Copilot.Core.Type.Array
  ( Array
  , array
  , flatten
  , size
  , Flatten
  , InnerType
  , arrayelems
  ) where

import GHC.TypeLits     (Nat, KnownNat, natVal)
import Data.Proxy       (Proxy (..))

-- | A list of elements of type @t@ tagged with a type-level length @n@.
--
-- The constructor itself performs no length check; within this module the
-- length is only validated by the smart constructor 'array'.
-- NOTE(review): the export list exports the type without its constructor,
-- so outside code presumably has to go through 'array' — confirm no other
-- module re-exports the constructor.
data Array (n :: Nat) t where
  Array :: [t] -> Array n t

-- | An array is rendered exactly like the list of its elements, with no
-- constructor name or length annotation.
instance Show t => Show (Array n t) where
  show arr = case arr of
    Array elems -> show elems

-- | Smart constructor for 'Array'. It wraps the list after checking, at run
-- time, that the list's length equals the type-level length @n@.
--
-- Calls 'error' when the lengths differ. Computing the length forces the
-- whole list spine, so an infinite list never returns.
array :: forall n t. KnownNat n => [t] -> Array n t
array xs
  | datalen == typelen = Array xs
  | otherwise          = error errmsg
  where
    -- Compare as 'Integer'. Narrowing the type-level length to 'Int' would
    -- wrap around for lengths beyond 'maxBound :: Int' (e.g. 2^64 -> 0) and
    -- could accept a list of the wrong length.
    datalen :: Integer
    datalen = toInteger (length xs)

    typelen :: Integer
    typelen = natVal (Proxy :: Proxy n)

    errmsg :: String
    errmsg = "Length of data (" ++ show datalen ++
             ") does not match length of type (" ++ show typelen ++ ")."


-- | The element type at the bottom of a (possibly nested) array type. Every
-- 'Array' layer is peeled off until a non-array type is reached.
type family InnerType ty where
  InnerType (Array _ e) = InnerType e
  InnerType ty          = ty


-- | Conversion of a (possibly nested) array into a flat list.
--
-- The first parameter is the element type of the array being flattened. The
-- second is the element type of the resulting list.
class Flatten a b where
  -- | Flatten an array into a list of elements of type @b@.
  flatten :: Array len a -> [b]

-- | Base case: when the array's elements already have the target type, the
-- array flattens to its element list unchanged.
--
-- NOTE(review): this head overlaps the nested-array instance at
-- @Flatten (Array n x) (Array n x)@. GHC would reject a constraint of that
-- exact shape as overlapping — confirm no caller needs it.
instance Flatten a a where
  flatten arr = case arr of
    Array elems -> elems

-- | Recursive case: flatten each inner array and concatenate the results in
-- order, so nesting of any depth collapses into a single list.
instance Flatten a b => Flatten (Array n a) b where
  flatten (Array xss) = concatMap flatten xss

-- | Folding over an array folds over its underlying element list. All other
-- 'Foldable' methods use their default definitions in terms of 'foldr'.
instance Foldable (Array n) where
  foldr step acc arr = case arr of
    Array elems -> foldr step acc elems


-- | Total number of innermost elements in a (possibly nested) array, i.e. the
-- length of its fully flattened form.
size :: forall a n b. (Flatten a b, b ~ InnerType a) => Array n a -> Int
size arr = length flat
  where
    -- The annotation selects the 'Flatten' instance that targets the
    -- innermost element type.
    flat :: [b]
    flat = flatten arr

-- | The elements of an array as a list, in the order they are stored. Nested
-- arrays are not flattened.
arrayelems :: Array n a -> [a]
arrayelems arr = case arr of
  Array elems -> elems