{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.V.Linear.Internal.V
  ( V(..)
  , FunN
  , theLength
  , elim
  , make
  , iterate
  -- * Type-level utilities
  , caseNat
  ) where

import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import GHC.Exts (Constraint, proxy#)
import GHC.TypeLits
import Prelude
  ( Eq
  , Ord
  , Int
  , Bool(..)
  , Either(..)
  , Maybe(..)
  , fromIntegral
  , error
  , (-))
import qualified Prelude as Prelude
import Prelude.Linear.Internal
import qualified Unsafe.Linear as Unsafe

{- Developers Note

See the "Developers Note" in Data.V.Linear for an explanation of this module
structure.

-}

-- # Type Definitions
-------------------------------------------------------------------------------

-- | @V n a@ is a vector of @n@ elements of type @a@, stored in a boxed
-- 'Vector'. The length @n@ exists only at the type level: the functions in
-- this module (such as 'make', 'cons' and 'split') preserve the invariant that
-- the underlying 'Vector' has exactly @n@ elements. The constructor is
-- exported, so code that uses it directly must uphold that invariant itself.
newtype V (n :: Nat) (a :: Type) = V (Vector a)
  deriving (Eq, Ord, Prelude.Functor)
  -- Using vector rather than, say, 'Array' (or directly 'Array#') because it
  -- offers many convenience functions. Since all these unsafeCoerces probably
  -- kill the fusion rules, it may be worth going lower level, since I
  -- probably have to write my own fusion anyway. Therefore, starting from
  -- Vectors at the moment.

-- | @FunN n a b@ is the type of a curried linear function that takes @n@
-- arguments of type @a@ and returns a @b@. For example, @FunN 0 a b@ is @b@
-- and @FunN 2 a b@ reduces to @a %1-> a %1-> b@.
type family FunN (n :: Nat) (a :: Type) (b :: Type) :: Type where
  FunN 0 a b = b
  FunN n a b = a %1-> FunN (n - 1) a b

-- # API
-------------------------------------------------------------------------------

-- | The type-level natural @n@ as an 'Int', read from its 'KnownNat'
-- dictionary.
--
-- @n@ does not occur in the result type, so callers select it with a type
-- application, e.g. @theLength \@3@.
--
-- NOTE(review): 'fromIntegral' from 'Integer' to 'Int' wraps around when @n@
-- exceeds the range of 'Int'.
theLength :: forall n. KnownNat n => Int
theLength = fromIntegral (natVal' @n (proxy# @_))

-- | Split a non-empty vector into its first element and a vector of the
-- remaining @n-1@ elements.
--
-- The work is done by the non-linear helper @split'@, wrapped with
-- 'Unsafe.toLinear': every element of the input ends up in exactly one
-- component of the result (the head in the first, the tail in the second).
-- 'Vector.head' is partial. Only the @1 <= n@ constraint, together with the
-- length invariant of 'V', rules out the empty case.
split :: 1 <= n => V n a %1-> (# a, V (n-1) a #)
split = Unsafe.toLinear split'
  where
    split' :: 1 <= n => V n a -> (# a, V (n-1) a #)
    split' (V xs) = (# Vector.head xs, V (Vector.tail xs) #)

-- | Discard an empty vector and return the second argument unchanged.
--
-- Assuming the length invariant of 'V', a @V 0 a@ holds no elements, so
-- dropping it loses nothing. 'Unsafe.toLinear' is only needed because the
-- lambda ignores its (linear) first argument.
consumeV :: V 0 a %1-> b %1-> b
consumeV = Unsafe.toLinear (\_ -> id)

-- | An unchecked proof that @n@ equals @0@, manufactured by coercing 'Refl'.
--
-- Only sound when the caller has already established that @n@ is @0@ by other
-- means, e.g. by inspecting the runtime value of @n@ as 'caseNat' does.
unsafeZero :: n :~: 0
unsafeZero = Unsafe.coerce Refl

-- | An unchecked proof that @1 <=? n@ is @'True@ (i.e. @n@ is non-zero),
-- manufactured by coercing 'Refl'.
--
-- Only sound when the caller has already established that @n@ is not @0@,
-- e.g. by inspecting the runtime value of @n@ as 'caseNat' does.
unsafeNonZero :: (1 <=? n) :~: 'True
unsafeNonZero = Unsafe.coerce Refl

-- | A reified constraint: building a 'Dict' captures the evidence for @c@,
-- and pattern matching on it brings @c@ back into scope.
--
-- Same as in the constraints library, but it's just as easy to avoid a
-- dependency here.
data Dict (c :: Constraint) where
  Dict :: forall (c :: Constraint). c => Dict c

-- | From @KnownNat n@ and @1 <= n@, produce a 'KnownNat' dictionary for
-- @n-1@.
--
-- GHC does not solve @KnownNat (n-1)@ from @KnownNat n@ by itself. So the
-- value @n - 1@ is computed at runtime and turned into a 'KnownNat' dictionary
-- for some fresh @p@ via 'someNatVal'. That dictionary is then coerced to the
-- @n-1@ index.
--
-- 'someNatVal' returns 'Nothing' only for negative input, which @1 <= n@ rules
-- out, so the 'Nothing' branch is unreachable.
predNat :: forall n. (1 <= n, KnownNat n) => Dict (KnownNat (n-1))
predNat = case someNatVal (natVal' @n (proxy# @_) - 1) of
  Just (SomeNat (_ :: Proxy p)) -> Unsafe.coerce (Dict @(KnownNat p))
  Nothing -> error "Vector.pred: n-1 is necessarily a Nat, if 1<=n"

-- | Decide at runtime whether the type-level natural @n@ is @0@, and return
-- the matching type-level evidence: @n :~: 0@ on the 'Left',
-- @(1 <=? n) :~: 'True@ on the 'Right'.
--
-- Both proofs are unchecked coercions ('unsafeZero' / 'unsafeNonZero'). Their
-- soundness rests entirely on the runtime test below.
caseNat :: forall n. KnownNat n => Either (n :~: 0) ((1 <=? n) :~: 'True)
caseNat =
  -- Test the 'Integer' returned by natVal' directly. Going through
  -- 'theLength' first narrows it to 'Int', which wraps around for huge @n@
  -- (on a 64-bit platform @2^64@ becomes @0@) and would hand out a bogus
  -- @n :~: 0@ proof.
  case natVal' @n (proxy# @_) of
    0 -> Left $ unsafeZero @n
    _ -> Right $ unsafeNonZero @n
{-# INLINE caseNat #-}

-- | Unfold one step of 'FunN'. When @1 <= n@, @FunN n a b@ is, by the second
-- equation of 'FunN', equal to @a %1-> FunN (n-1) a b@.
--
-- GHC cannot reduce the closed type family for a variable @n@: it would need
-- to know that @n@ is apart from @0@, which a @1 <= n@ constraint does not
-- give it. So the equation is applied with an unchecked coercion.
expandFunN :: forall n a b. (1 <= n) => FunN n a b %1-> a %1-> FunN (n-1) a b
expandFunN k = Unsafe.coerce k

-- | Fold one step of 'FunN'. This is the inverse of 'expandFunN': when
-- @1 <= n@, a function @a %1-> FunN (n-1) a b@ is, by the definition of
-- 'FunN', a @FunN n a b@.
--
-- As in 'expandFunN', GHC cannot reduce 'FunN' for a variable @n@, so the
-- equation is applied with an unchecked coercion.
contractFunN :: (1 <= n) => (a %1-> FunN (n-1) a b) %1-> FunN n a b
contractFunN k = Unsafe.coerce k

-- TODO: consider using template haskell to make this expression more efficient.
-- | This is like pattern-matching on an n-tuple: @elim v f@ passes the
-- elements of @v@, in order, as the @n@ arguments of the curried linear
-- function @f@. It will eventually be polymorphic the same way as a case
-- expression.
--
-- The recursion is on @n@:
--
-- * an empty vector is discarded and @f@ (which is then just a @b@) is
--   returned;
-- * otherwise @f@ is applied to the head, and the recursion continues on the
--   tail.
elim :: forall n a b. KnownNat n => V n a %1-> FunN n a b %1-> b
elim xs f =
  case caseNat @n of
    Left Refl -> consumeV xs f
    Right Refl -> elimS (split xs) f
  where
    elimS :: 1 <= n => (# a, V (n-1) a #) %1-> FunN n a b %1-> b
    elimS (# x, xs' #) g = case predNat @n of
      -- 'predNat' supplies the KnownNat (n-1) needed by the recursive call.
      Dict -> elim xs' (expandFunN @n @a @b g x)

-- XXX: This can probably be improved a lot.
-- | Build a @'V' n a@ from @n@ separate arguments: @make \@3 x y z@ is the
-- vector holding @x@, @y@ and @z@, in that order.
--
-- For @n = 0@ this is the empty vector. Otherwise the first argument @t@ is
-- taken, a recursive 'make' at @n-1@ collects the remaining arguments, and
-- its result is post-composed (via 'continue') with @'cons' t@.
make :: forall n a. KnownNat n => FunN n a (V n a)
make = case caseNat @n of
  Left Refl -> V Vector.empty
  Right Refl -> contractFunN @n @a @(V n a) prepend
    -- This 'where' belongs to the 'Right' alternative, so @prepend@ can use
    -- the @1 <= n@ evidence obtained from matching on 'Refl'.
    where
      prepend :: a %1-> FunN (n-1) a (V n a)
      prepend t = case predNat @n of
        Dict -> continue @(n-1) @a @(V (n-1) a) (cons t) (make @(n-1) @a)

-- | Prepend an element to a vector, increasing its type-level length by one.
--
-- Both arguments are used exactly once in the result; 'Unsafe.toLinear2' is
-- needed because the underlying 'Vector.cons' is not linearly typed.
cons :: forall n a. a %1-> V (n-1) a %1-> V n a
cons = Unsafe.toLinear2 $ \x (V v) -> V (Vector.cons x v)

-- | Post-compose a linear function onto an @n@-ary 'FunN': @continue f g@
-- takes the same @n@ arguments as @g@ and returns @f@ applied to @g@'s result.
--
-- For @n = 0@ the 'FunN's are plain values and this is just function
-- application (@id@ on @f@). Otherwise the outer layer is unfolded with
-- 'expandFunN', the recursion continues at @n-1@ under that argument, and the
-- layer is folded back with 'contractFunN'.
continue
  :: forall n a b c. KnownNat n
  => (b %1-> c) %1-> FunN n a b %1-> FunN n a c
continue = case caseNat @n of
  Left Refl -> id
  Right Refl ->
    \f t -> contractFunN @n @a @c (continueS f (expandFunN @n @a @b t))
    -- This 'where' belongs to the 'Right' alternative, so @continueS@ can use
    -- the @1 <= n@ evidence obtained from matching on 'Refl'.
    where
      continueS
        :: (KnownNat n, 1 <= n)
        => (b %1-> c) %1-> (a %1-> FunN (n-1) a b) %1-> (a %1-> FunN (n-1) a c)
      continueS f' x a = case predNat @n of
        Dict -> continue @(n-1) @a @b f' (x a)

-- | @iterate dup x@ builds a vector of @n@ elements (@n >= 1@) from the seed
-- @x@ by repeatedly splitting it with @dup@.
--
-- At each step @dup@ yields a pair. The first component becomes the next
-- element and the second is the seed for the rest of the vector. The last
-- element is the remaining seed itself, so @dup@ is called @n-1@ times.
iterate :: forall n a. (KnownNat n, 1 <= n) => (a %1-> (a, a)) -> a %1-> V n a
iterate dup init = go @n init
  where
    go :: forall m. (KnownNat m, 1 <= m) => a %1-> V m a
    go a =
      case predNat @m of
        Dict -> case caseNat @(m-1) of
          Prelude.Left Refl ->
            -- Here m - 1 == 0, hence m == 1: a single element, the seed.
            case pr1 @m Refl of
              Refl -> (make @m @a :: a %1-> V m a) a
          Prelude.Right Refl ->
            dup a & \(a', a'') ->
              a' `cons` go @(m-1) a''

    -- An unsafe cast to prove the simple equality.
    pr1 :: forall k. 0 :~: (k - 1) -> k :~: 1
    pr1 Refl = Unsafe.coerce Refl