{-
Part of the code in this file comes from the parameterized-utils package:

Copyright (c) 2013-2022 Galois Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

  * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

  * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in
    the documentation and/or other materials provided with the
    distribution.

  * Neither the name of Galois, Inc. nor the names of its contributors
    may be used to endorse or promote products derived from this
    software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- |
-- Module      :   Grisette.Internal.Utils.Parameterized
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Utils.Parameterized
  ( -- * Unsafe axiom
    unsafeAxiom,

    -- * Unparameterized type
    SomeNatRepr (..),
    SomePositiveNatRepr (..),

    -- * Runtime representation of type-level natural numbers
    NatRepr,
    withKnownNat,
    natValue,
    mkNatRepr,
    mkPositiveNatRepr,
    natRepr,
    decNat,
    predNat,
    incNat,
    addNat,
    subNat,
    divNat,
    halfNat,

    -- * Proof of KnownNat
    KnownProof (..),
    hasRepr,
    withKnownProof,
    unsafeKnownProof,
    knownAdd,

    -- * Proof of (<=) for type-level natural numbers
    LeqProof (..),
    withLeqProof,
    unsafeLeqProof,
    testLeq,
    leqRefl,
    leqSucc,
    leqTrans,
    leqZero,
    leqAdd2,
    leqAdd,
    leqAddPos,
  )
where

import Data.Typeable (Proxy (Proxy), type (:~:) (Refl))
import GHC.TypeNats
  ( Div,
    KnownNat,
    Nat,
    SomeNat (SomeNat),
    natVal,
    someNatVal,
    type (+),
    type (-),
    type (<=),
  )
import Numeric.Natural (Natural)
import Unsafe.Coerce (unsafeCoerce)

-- | Assert a proof of equality between two types.
--
-- This is unsafe if used improperly: it fabricates a 'Refl' by coercing
-- @a ':~:' a@ to @a ':~:' b@ without any check, so use it with caution!
unsafeAxiom :: forall a b. a :~: b
unsafeAxiom = unsafeCoerce (Refl @a)
-- Kept out of line so the optimiser cannot inline the coercion and reason
-- from the fabricated equality (parameterized-utils marks its axiom the same
-- way).
{-# NOINLINE unsafeAxiom #-}

-- | Run a computation that needs @'KnownNat' n@, using the runtime value
-- stored in the given 'NatRepr' as the evidence.
--
-- The instance is obtained from the stored 'Natural' via 'someNatVal' and then
-- identified with @n@ through 'unsafeAxiom'. This is only sound if the
-- 'NatRepr' really represents @n@.
withKnownNat :: forall n r. NatRepr n -> ((KnownNat n) => r) -> r
withKnownNat (NatRepr nVal) v =
  case someNatVal nVal of
    SomeNat (Proxy :: Proxy n') ->
      -- Identify the fresh existential index with the caller's @n@.
      case unsafeAxiom :: n :~: n' of
        Refl -> v

-- | Runtime witness for a type-level natural number: a newtype around the
-- 'Natural' that the index denotes. The constructor is not in the module's
-- export list, so the value/index correspondence is maintained by the
-- functions in this module. It can be used for performing dynamic checks on
-- type-level natural numbers.
--
-- NOTE(review): the index does not occur on the right-hand side, so its role
-- is inferred as phantom and 'Data.Coerce.coerce' can still change it from
-- outside this module. A @type role NatRepr nominal@ annotation (requires
-- RoleAnnotations) would close that hole. Confirm whether this is intended.
newtype NatRepr (k :: Nat) = NatRepr Natural

-- | The underlying runtime natural number value of a type-level natural number.
natValue :: NatRepr n -> Natural
natValue (NatRepr n) = n

-- | Internal existential wrapper around a 'NatRepr' whose type index is
-- hidden and which carries no 'KnownNat' evidence. 'mkNatRepr' matches on it
-- to give a freshly built 'NatRepr' an abstract index.
data SomeNatReprHelper where
  SomeNatReprHelper :: forall (k :: Nat). NatRepr k -> SomeNatReprHelper

-- | A 'NatRepr' whose type index is hidden, packaged together with the
-- 'KnownNat' instance for that index. Pattern matching recovers both.
data SomeNatRepr where
  SomeNatRepr ::
    forall (k :: Nat).
    (KnownNat k) =>
    NatRepr k ->
    SomeNatRepr

-- | Turn a @Natural@ into the corresponding @NatRepr@, packaged with the
-- 'KnownNat' constraint for its (hidden) type index.
mkNatRepr :: Natural -> SomeNatRepr
mkNatRepr n = case SomeNatReprHelper (NatRepr n) of
  -- Matching on the helper gives the index an abstract, existential type, for
  -- which 'withKnownNat' then supplies the 'KnownNat' instance.
  SomeNatReprHelper natRepr -> withKnownNat natRepr $ SomeNatRepr natRepr

-- | A 'NatRepr' whose type index is hidden, packaged together with the
-- 'KnownNat' instance for that index and evidence that the index is at
-- least 1.
data SomePositiveNatRepr where
  SomePositiveNatRepr ::
    forall (k :: Nat).
    (KnownNat k, 1 <= k) =>
    NatRepr k ->
    SomePositiveNatRepr

-- | Turn a @Natural@ into the corresponding @NatRepr@, packaged with the
-- 'KnownNat' constraint and a proof that it is greater than 0.
--
-- Calls 'error' when given 0.
mkPositiveNatRepr :: Natural -> SomePositiveNatRepr
mkPositiveNatRepr 0 = error "mkPositiveNatRepr: 0 is not a positive number"
mkPositiveNatRepr n = case mkNatRepr n of
  SomeNatRepr (natRepr :: NatRepr k) ->
    -- Sound: the 0 case is rejected above, so the value is at least 1.
    case unsafeLeqProof @1 @k of
      LeqProof -> SomePositiveNatRepr natRepr

-- | Construct a runtime representation of a type-level natural number when its
-- runtime value is known (read from the 'KnownNat' instance via 'natVal').
natRepr :: forall n. (KnownNat n) => NatRepr n
natRepr = NatRepr (natVal (Proxy @n))

-- | Decrement a 'NatRepr' by 1.
--
-- The @1 <= n@ constraint rules out decrementing the representation of 0.
decNat :: (1 <= n) => NatRepr n -> NatRepr (n - 1)
decNat (NatRepr n) = NatRepr (n - 1)

-- | Predecessor of a 'NatRepr': given the representation of @n + 1@, produce
-- the representation of @n@.
predNat :: NatRepr (n + 1) -> NatRepr n
predNat (NatRepr n) = NatRepr (n - 1)

-- | Increment a 'NatRepr' by 1.
incNat :: NatRepr n -> NatRepr (n + 1)
incNat (NatRepr n) = NatRepr (n + 1)

-- | Addition of two 'NatRepr's.
addNat :: NatRepr m -> NatRepr n -> NatRepr (m + n)
addNat (NatRepr m) (NatRepr n) = NatRepr (m + n)

-- | Subtraction of two 'NatRepr's.
--
-- The @n <= m@ constraint rules out subtracting a larger number from a
-- smaller one.
subNat :: (n <= m) => NatRepr m -> NatRepr n -> NatRepr (m - n)
subNat (NatRepr m) (NatRepr n) = NatRepr (m - n)

-- | Division of two 'NatRepr's (rounding down, via 'div').
--
-- The @1 <= n@ constraint rules out division by zero.
divNat :: (1 <= n) => NatRepr m -> NatRepr n -> NatRepr (Div m n)
divNat (NatRepr m) (NatRepr n) = NatRepr (m `div` n)

-- | Half of a 'NatRepr': given the representation of @n + n@, produce the
-- representation of @n@.
halfNat :: NatRepr (n + n) -> NatRepr n
halfNat (NatRepr n) = NatRepr (n `div` 2)

-- | @'KnownProof' n@ is inhabited only when @n@ has a known runtime value.
-- Matching on its sole constructor brings a @'KnownNat' n@ instance into
-- scope.
data KnownProof (k :: Nat) where
  KnownProof :: forall (k :: Nat). (KnownNat k) => KnownProof k

-- | Introduces the 'KnownNat' constraint when it's proven: matching the
-- 'KnownProof' brings the instance into scope for the continuation.
withKnownProof :: KnownProof n -> ((KnownNat n) => r) -> r
withKnownProof p r = case p of KnownProof -> r

-- | Construct a 'KnownProof' given the runtime value.
--
-- __Note:__ This function is unsafe, as it does not check that the runtime
-- representation is consistent with the type-level representation.
-- You should ensure the consistency yourself or the program can crash or
-- generate incorrect results.
unsafeKnownProof :: Natural -> KnownProof n
unsafeKnownProof nVal = hasRepr (NatRepr nVal)

-- | Construct a 'KnownProof' given the runtime representation.
--
-- The instance is obtained from the stored 'Natural' via 'someNatVal' and then
-- identified with @n@ through 'unsafeAxiom'. This is only sound if the
-- 'NatRepr' really represents @n@.
hasRepr :: forall n. NatRepr n -> KnownProof n
hasRepr (NatRepr nVal) =
  case someNatVal nVal of
    SomeNat (Proxy :: Proxy n') ->
      -- Identify the fresh existential index with the caller's @n@.
      case unsafeAxiom :: n :~: n' of
        Refl -> KnownProof

-- | Adding two type-level natural numbers with known runtime values gives a
-- type-level natural number with a known runtime value.
knownAdd :: forall m n. KnownProof m -> KnownProof n -> KnownProof (m + n)
knownAdd KnownProof KnownProof =
  -- Matching both proofs brings the operands' KnownNat instances into scope,
  -- so their runtime values can be summed to represent @m + n@.
  hasRepr @(m + n) (NatRepr (natVal (Proxy @m) + natVal (Proxy @n)))

-- | @'LeqProof' m n@ is inhabited only when @m <= n@. Matching on its sole
-- constructor brings the @m <= n@ constraint into scope.
data LeqProof (m :: Nat) (n :: Nat) where
  LeqProof :: forall (m :: Nat) (n :: Nat). (m <= n) => LeqProof m n

-- | Introduces the @m <= n@ constraint when it's proven: matching the
-- 'LeqProof' brings the constraint into scope for the continuation.
withLeqProof :: LeqProof m n -> ((m <= n) => r) -> r
withLeqProof p r = case p of LeqProof -> r

-- | Construct a 'LeqProof'.
--
-- __Note:__ This function is unsafe, as it does not check that the left-hand
-- side is less than or equal to the right-hand side.
-- You should ensure the consistency yourself or the program can crash or
-- generate incorrect results.
unsafeLeqProof :: forall m n. LeqProof m n
-- A genuine proof of @0 <= 0@ is re-typed as @m <= n@ without any check.
unsafeLeqProof = unsafeCoerce (LeqProof @0 @0)
-- Kept out of line, like 'unsafeAxiom', so the optimiser cannot inline the
-- coercion and reason from the unchecked constraint.
{-# NOINLINE unsafeLeqProof #-}

-- | Checks if a 'NatRepr' is less than or equal to another 'NatRepr'.
--
-- Returns @'Just'@ a proof of @m <= n@ when the first runtime value does not
-- exceed the second, and @'Nothing'@ otherwise. The check happens at runtime
-- on the stored values, which is what justifies the 'unsafeLeqProof'.
testLeq :: NatRepr m -> NatRepr n -> Maybe (LeqProof m n)
testLeq (NatRepr m) (NatRepr n)
  | m <= n = Just unsafeLeqProof
  | otherwise = Nothing

-- | Apply reflexivity to 'LeqProof'.
--
-- The argument only fixes @n@ and is never inspected. The constraint
-- @n <= n@ is discharged by the type checker, so no unsafe proof is needed.
leqRefl :: f n -> LeqProof n n
leqRefl _ = LeqProof

-- | A natural number is less than or equal to its successor.
--
-- The argument only fixes @n@ and is never inspected. The proof is produced
-- unchecked via 'unsafeLeqProof'.
leqSucc :: f n -> LeqProof n (n + 1)
leqSucc _ = unsafeLeqProof

-- | Apply transitivity to 'LeqProof'.
--
-- The input proofs are not inspected. The result is produced unchecked via
-- 'unsafeLeqProof'.
leqTrans :: LeqProof a b -> LeqProof b c -> LeqProof a c
leqTrans _ _ = unsafeLeqProof

-- | Zero is less than or equal to any natural number.
--
-- The proof is produced unchecked via 'unsafeLeqProof'.
leqZero :: LeqProof 0 n
leqZero = unsafeLeqProof

-- | Add both sides of two inequalities.
--
-- The input proofs are not inspected. The result is produced unchecked via
-- 'unsafeLeqProof'.
leqAdd2 :: LeqProof xl xh -> LeqProof yl yh -> LeqProof (xl + yl) (xh + yh)
leqAdd2 _ _ = unsafeLeqProof

-- | Produce proof that adding a value to the larger element in an 'LeqProof'
-- is larger.
--
-- Neither argument is inspected (the second only fixes @o@). The result is
-- produced unchecked via 'unsafeLeqProof'.
leqAdd :: LeqProof m n -> f o -> LeqProof m (n + o)
leqAdd _ _ = unsafeLeqProof

-- | Adding two positive natural numbers is positive.
--
-- The @1 <= m@ and @1 <= n@ constraints are the premises. The arguments only
-- fix @m@ and @n@ and are never inspected. The result is produced unchecked
-- via 'unsafeLeqProof'.
leqAddPos :: (1 <= m, 1 <= n) => p m -> q n -> LeqProof 1 (m + n)
leqAddPos _ _ = unsafeLeqProof