{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE OverloadedStrings #-}

module Strongweak.Strengthen where

import Data.Foldable qualified as Foldable
import Data.Int
import Data.List.NonEmpty ( NonEmpty( (:|) ) )
import Data.Proxy ( Proxy(..) )
import Data.Word
import GHC.TypeNats ( Natural, KnownNat, natVal )
import Type.Reflection ( Typeable, typeRep )

import Data.Validation
import Data.Vector.Sized qualified as Vector
import Data.Vector.Sized ( Vector )
import Prettyprinter
import Prettyprinter.Render.String
import Refined ( Refined, refine, Predicate )

{- | Weak-to-strong conversion: any @w@ can be "strengthened" into an @s@ by
asserting some properties of it.

For example, you may strengthen some 'Natural' @n@ into a 'Word8' by asserting
@0 <= n <= 255@.

The functional dependency @s -> w@ restricts every strong type to exactly one
corresponding weak representation.
-}
class Strengthen w s | s -> w where
    -- | Attempt to strengthen a weak value. Failures are reported as a
    --   non-empty collection of 'StrengthenError's.
    strengthen :: w -> Validation (NonEmpty StrengthenError) s

-- | Describes why a strengthening attempt failed.
data StrengthenError
    = StrengthenErrorBase String String String String
      -- ^ A primitive failure. Fields, in order: weak type, strong type,
      --   weak value, message.
    | StrengthenErrorField String String String String String String StrengthenError
      -- ^ A failure located inside a field of a datatype. Fields, in order:
      --   weak datatype name, strong datatype name,
      --   weak constructor name, strong constructor name,
      --   weak field name, strong field name,
      --   and the underlying error.

-- | Shows an error by laying out its 'Pretty' rendering with
--   'defaultLayoutOptions'. The precedence argument is ignored.
instance Show StrengthenError where
    showsPrec _ err = renderShowS (layoutPretty defaultLayoutOptions (pretty err))

-- | Human-readable rendering of a strengthening error.
--
-- A base error is printed over three lines:
--
-- > <weak type> -> <strong type>
-- > <weak value> -> FAIL
-- > <message>
--
-- A field error prints the dotted path
-- @weakDatatype.weakConstructor.weakField@, then the nested error on the next
-- line, indented by one column. Only the weak-side names are printed; the
-- strong-side names are ignored.
instance Pretty StrengthenError where
    pretty (StrengthenErrorBase wt st wv msg) =
        vsep
            [ hsep [pretty wt, "->", pretty st]
            , hsep [pretty wv, "->", "FAIL"]
            , pretty msg
            ]
    pretty (StrengthenErrorField dw _ds cw _cs sw _ss err) =
        nest 1 (path <> line <> pretty err)
      where
        -- weak datatype, constructor and field names joined by "."
        path = hcat (punctuate "." [pretty dw, pretty cw, pretty sw])

-- | Fail with a single 'StrengthenErrorBase' explaining why the weak value
--   could not be strengthened into @s@.
--
-- The weak and strong type names come from their 'Typeable' representations,
-- and the weak value itself is rendered with 'show'. The strong type @s@ is the
-- first type variable, so it can be supplied by type application.
strengthenErrorBase
    :: forall s w. (Typeable w, Show w, Typeable s)
    => w -> String -> Validation (NonEmpty StrengthenError) s
strengthenErrorBase w msg = Failure (singleError :| [])
  where
    singleError = StrengthenErrorBase weakType strongType (show w) msg
    weakType    = show (typeRep @w)
    strongType  = show (typeRep @s)

-- | Render a non-empty collection of errors as a vertical bulleted list, one
--   @-@ item per error. @indent 0@ starts a new alignment column after the
--   bullet, so continuation lines of a multi-line error line up under its
--   first line.
strengthenErrorPretty :: NonEmpty StrengthenError -> Doc a
strengthenErrorPretty errs =
    vsep [ "-" <+> indent 0 (pretty err) | err <- Foldable.toList errs ]

-- | Strengthen each element of a list. Traversal runs in the 'Validation'
--   applicative, so failures from the individual elements are combined into
--   the resulting error collection.
instance Strengthen w s => Strengthen [w] [s] where
    strengthen ws = traverse strengthen ws

-- | Obtain a sized vector by asserting the length of a plain list.
--
-- Fails when the list's length differs from the type-level length @n@. The
-- error message reports both the expected and the actual length.
instance (KnownNat n, Typeable a, Show a) => Strengthen [a] (Vector n a) where
    strengthen w =
        case Vector.fromList w of
          Just v  -> Success v
          Nothing ->
            -- Previously a "TODO" placeholder message; report the mismatch.
            strengthenErrorBase w $
                "incorrect length: expected " <> show (natVal (Proxy @n))
                <> ", got " <> show (length w)

-- | Obtain a refined type by checking the value against its refinement
--   predicate. A failed check is reported using the 'show' of the
--   'RefineException'.
instance (Predicate p a, Typeable a, Show a) => Strengthen a (Refined p a) where
    strengthen a = either (strengthenErrorBase a . show) Success (refine a)

-- Strengthen 'Natural's into Haskell's bounded unsigned numeric types by
-- checking them against the target type's bounds (see 'strengthenBounded').
instance Strengthen Natural Word8  where strengthen = strengthenBounded @Word8
instance Strengthen Natural Word16 where strengthen = strengthenBounded @Word16
instance Strengthen Natural Word32 where strengthen = strengthenBounded @Word32
instance Strengthen Natural Word64 where strengthen = strengthenBounded @Word64

-- Strengthen 'Integer's into Haskell's bounded signed numeric types by
-- checking them against the target type's bounds (see 'strengthenBounded').
instance Strengthen Integer Int8   where strengthen = strengthenBounded @Int8
instance Strengthen Integer Int16  where strengthen = strengthenBounded @Int16
instance Strengthen Integer Int32  where strengthen = strengthenBounded @Int32
instance Strengthen Integer Int64  where strengthen = strengthenBounded @Int64

-- | Strengthen an integral value into a bounded integral type @b@, succeeding
--   only when the value lies within @['minBound', 'maxBound']@ of @b@.
--
-- The range check is done in 'Integer', so it is valid for any pairing of
-- source type @n@ and target type @b@. Converting @b@'s bounds into @n@
-- instead can silently wrap (e.g. 'Word64' bounds into 'Int8'). It can also
-- throw an arithmetic underflow ('Natural' source with a signed target).
strengthenBounded
    :: forall b n
    .  (Integral b, Bounded b, Show b, Typeable b, Integral n, Show n, Typeable n)
    => n -> Validation (NonEmpty StrengthenError) b
strengthenBounded n
  | nI >= minB && nI <= maxB = Success (fromIntegral n)  -- in range: cannot wrap
  | otherwise =
      strengthenErrorBase n $
        "not well bounded, require: " <> show minB <> " <= n <= " <> show maxB
  where
    nI   = toInteger n
    minB = toInteger (minBound @b)
    maxB = toInteger (maxBound @b)