{-# LANGUAGE CPP
, KindSignatures
, DataKinds
, GADTs
, StandaloneDeriving
, ExistentialQuantification
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Types.Coercion
(
PrimCoercion(..)
, PrimCoerce(..)
, Coercion(..)
, singletonCoercion
, signed
, continuous
, Coerce(..)
, singCoerceDom
, singCoerceCod
, singCoerceDomCod
, CoercionMode(..)
, findCoercion
, findEitherCoercion
, Lub(..)
, findLub
, SomeRing(..)
, findRing
, SomeFractional(..)
, findFractional
, ZigZag(..)
, simplifyZZ
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
#endif
import Control.Applicative ((<|>))
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.IClasses
(TypeEq(..), Eq1(..), Eq2(..), JmEq1(..), JmEq2(..))
-- | A single primitive coercion step between two Hakaru types.
--
-- * 'Signed' goes from @'NonNegative' a@ to @a@, given 'HRing' evidence
--   for @a@.
-- * 'Continuous' goes from @'HIntegral' a@ to @a@, given 'HContinuous'
--   evidence for @a@.
data PrimCoercion (a :: Hakaru) (b :: Hakaru) where
    Signed     :: !(HRing a)       -> PrimCoercion (NonNegative a) a
    Continuous :: !(HContinuous a) -> PrimCoercion (HIntegral a) a

deriving instance Show (PrimCoercion a b)
-- | Homogeneous equality, delegated to the heterogeneous instances below.
instance Eq (PrimCoercion a b) where
    (==) = eq1

instance Eq1 (PrimCoercion a) where
    eq1 = eq2

-- | Two primitive coercions are equal exactly when 'jmEq2' relates them.
instance Eq2 PrimCoercion where
    eq2 p q =
        case jmEq2 p q of
        Just _  -> True
        Nothing -> False

-- | Equality of the target index, obtained from 'jmEq2'.
instance JmEq1 (PrimCoercion a) where
    jmEq1 p q = fmap snd (jmEq2 p q)

-- | Heterogeneous equality: succeeds, returning proofs that both indices
-- agree, only when both sides use the same constructor and the evidence
-- they carry is itself 'jmEq1'-equal.
instance JmEq2 PrimCoercion where
    jmEq2 (Signed r1) (Signed r2) = do
        Refl <- jmEq1 r1 r2
        return (Refl, Refl)
    jmEq2 (Continuous c1) (Continuous c2) = do
        Refl <- jmEq1 c1 c2
        return (Refl, Refl)
    jmEq2 _ _ = Nothing
-- | A composite coercion: a (possibly empty) chain of 'PrimCoercion' steps
-- leading from @a@ to @b@.
data Coercion (a :: Hakaru) (b :: Hakaru) where
    -- The empty chain: coerces a type to itself.
    CNil  :: Coercion a a
    -- One primitive step followed by the rest of the chain.
    CCons :: !(PrimCoercion a b) -> !(Coercion b c) -> Coercion a c

infixr 5 `CCons`

deriving instance Show (Coercion a b)
-- | Homogeneous equality, delegated to 'Eq2'.
instance Eq (Coercion a b) where
    (==) = eq1

instance Eq1 (Coercion a) where
    eq1 = eq2

-- | Two chains are equal when they have the same length and their steps
-- are pairwise related by 'jmEq2'.
instance Eq2 Coercion where
    eq2 CNil CNil = True
    eq2 (CCons p ps) (CCons q qs)
        | Just (Refl, Refl) <- jmEq2 p q = eq2 ps qs
    eq2 _ _ = False
-- | Coercion chains form a category.
--
-- 'id' is the empty chain. @later . earlier@ keeps the steps of @earlier@
-- in front and appends the steps of @later@ after them.
instance Category Coercion where
    id = CNil
    later . earlier =
        case earlier of
        CNil         -> later
        CCons p rest -> CCons p (later . rest)
-- | A coercion chain consisting of exactly one primitive step.
singletonCoercion :: PrimCoercion a b -> Coercion a b
singletonCoercion = (`CCons` CNil)

-- | The one-step coercion from @'NonNegative' a@ to the ring @a@.
signed :: (HRing_ a) => Coercion (NonNegative a) a
signed = singletonCoercion (Signed hRing)

-- | The one-step coercion from @'HIntegral' a@ to the continuous type @a@.
continuous :: (HContinuous_ a) => Coercion (HIntegral a) a
continuous = singletonCoercion (Continuous hContinuous)
-- | Hakaru-indexed type constructors whose values can be moved along a
-- single 'PrimCoercion', in either direction.
class PrimCoerce (f :: Hakaru -> *) where
    -- | Move a value from the source of the coercion to its target.
    primCoerceTo   :: PrimCoercion i j -> f i -> f j
    -- | Move a value from the target of the coercion back to its source.
    primCoerceFrom :: PrimCoercion i j -> f j -> f i
-- | Moving a singleton along a primitive coercion.
--
-- The result is built from the evidence stored in the coercion. The input
-- singleton is only compared, via 'jmEq1', against the corresponding
-- endpoint of the coercion. A mismatch raises an error.
instance PrimCoerce (Sing :: Hakaru -> *) where
    primCoerceTo p s =
        case jmEq1 s (singPrimCoerceDom p) of
        Just Refl -> singPrimCoerceCod p
        Nothing   -> error "primCoerceTo@Sing: the impossible happened"

    primCoerceFrom p s =
        case jmEq1 s (singPrimCoerceCod p) of
        Just Refl -> singPrimCoerceDom p
        Nothing   -> error "primCoerceFrom@Sing: the impossible happened"
-- | Hakaru-indexed type constructors whose values can be moved along an
-- entire 'Coercion' chain, in either direction.
class Coerce (f :: Hakaru -> *) where
    -- | Move a value from the start of the chain to its end.
    coerceTo   :: Coercion i j -> f i -> f j
    -- | Move a value from the end of the chain back to its start.
    coerceFrom :: Coercion i j -> f j -> f i
-- | Singletons follow a whole chain by taking one primitive step at a time.
-- Forwards, the first step is applied first. Backwards, the rest of the
-- chain is undone first and the first step last.
instance Coerce (Sing :: Hakaru -> *) where
    coerceTo chain s =
        case chain of
        CNil       -> s
        CCons p ps -> coerceTo ps (primCoerceTo p s)

    coerceFrom chain s =
        case chain of
        CNil       -> s
        CCons p ps -> primCoerceFrom p (coerceFrom ps s)
-- | The singleton of the source type of a primitive coercion, built from
-- the evidence the coercion carries.
singPrimCoerceDom :: PrimCoercion a b -> Sing a
singPrimCoerceDom p =
    case p of
    Signed     r -> sing_NonNegative r
    Continuous c -> sing_HIntegral c

-- | The singleton of the target type of a primitive coercion, built from
-- the evidence the coercion carries.
singPrimCoerceCod :: PrimCoercion a b -> Sing b
singPrimCoerceCod p =
    case p of
    Signed     r -> sing_HRing r
    Continuous c -> sing_HContinuous c
-- | The singleton of the source type of a coercion chain.
--
-- This is read directly off the first primitive step. Returns 'Nothing'
-- for 'CNil', because the empty chain carries no evidence from which to
-- build a singleton.
singCoerceDom :: Coercion a b -> Maybe (Sing a)
singCoerceDom CNil        = Nothing
singCoerceDom (CCons c _) = Just (singPrimCoerceDom c)

-- | The singleton of the target type of a coercion chain.
--
-- This is read directly off the last primitive step. Returns 'Nothing'
-- for 'CNil'.
singCoerceCod :: Coercion a b -> Maybe (Sing b)
singCoerceCod CNil           = Nothing
singCoerceCod (CCons c CNil) = Just (singPrimCoerceCod c)
-- Skip ahead to the last step. No singleton is pushed through the
-- intermediate steps, so no runtime index checks are performed.
singCoerceCod (CCons _ cs)   = singCoerceCod cs

-- | Both endpoint singletons of a coercion chain. Returns 'Nothing' exactly
-- when the chain is 'CNil'.
singCoerceDomCod :: Coercion a b -> Maybe (Sing a, Sing b)
singCoerceDomCod c = do
    dom <- singCoerceDom c
    cod <- singCoerceCod c
    return (dom, cod)
-- | Find the safe coercion from @a@ to @b@, if there is one.
--
-- * Direct steps: @nat -> int@ and @prob -> real@ (via 'signed');
--   @nat -> prob@ and @int -> real@ (via 'continuous').
-- * @nat -> real@ goes through @int@.
-- * Every type coerces to itself via 'CNil'.
-- * Any other pair yields 'Nothing'.
findCoercion :: Sing a -> Sing b -> Maybe (Coercion a b)
findCoercion a b =
    case (a, b) of
    (SNat,  SInt)  -> Just signed
    (SProb, SReal) -> Just signed
    (SNat,  SProb) -> Just continuous
    (SInt,  SReal) -> Just continuous
    (SNat,  SReal) -> Just (continuous . signed)
    _              -> jmEq1 a b >>= \Refl -> Just CNil
-- | For the pair @prob@ and @int@ (in either order), which 'findCoercion'
-- cannot relate, return the common source @nat@ together with the safe
-- coercions from it into each side. Any other pair yields 'Nothing'.
findMixedCoercion
    :: Sing a
    -> Sing b
    -> Maybe (Sing 'HNat, Coercion 'HNat a, Coercion 'HNat b)
findMixedCoercion a b =
    case (a, b) of
    (SProb, SInt) -> Just (SNat, continuous, signed)
    (SInt, SProb) -> Just (SNat, signed, continuous)
    _             -> Nothing
-- | How two types are related by coercions.
--
-- * 'Safe': a coercion from the first type to the second.
-- * 'Unsafe': a coercion from the second type to the first.
-- * 'Mixed': some third type @c@, with coercions from it into both.
data CoercionMode (a :: Hakaru) (b :: Hakaru) where
    Safe   :: Coercion a b -> CoercionMode a b
    Unsafe :: Coercion b a -> CoercionMode a b
    Mixed  :: (Sing c, Coercion c a, Coercion c b) -> CoercionMode a b
-- | Classify how @a@ and @b@ are related, trying in order:
--
-- 1. a coercion from @a@ to @b@ ('Safe');
-- 2. a coercion from @b@ to @a@ ('Unsafe');
-- 3. a common source found by 'findMixedCoercion' ('Mixed').
--
-- Returns 'Nothing' if none of these exists.
findEitherCoercion
    :: Sing a
    -> Sing b
    -> Maybe (CoercionMode a b)
findEitherCoercion a b =
    case findCoercion a b of
    Just c  -> Just (Safe c)
    Nothing ->
        case findCoercion b a of
        Just c  -> Just (Unsafe c)
        Nothing -> Mixed <$> findMixedCoercion a b
-- | A common upper bound of @a@ and @b@, as produced by 'findLub'. It holds
-- the singleton of the bound @c@ and coercions from @a@ and from @b@
-- into it.
data Lub (a :: Hakaru) (b :: Hakaru) where
    Lub :: !(Sing c) -> !(Coercion a c) -> !(Coercion b c) -> Lub a b
-- | Find a type that both @a@ and @b@ coerce into.
--
-- * If @a@ coerces to @b@, the bound is @b@.
-- * Otherwise, if @b@ coerces to @a@, the bound is @a@.
-- * @int@ and @prob@ meet at @real@.
-- * Any other pair yields 'Nothing'.
findLub :: Sing a -> Sing b -> Maybe (Lub a b)
findLub a b =
    ((\up -> Lub b up CNil) <$> findCoercion a b)
    <|> ((\up -> Lub a CNil up) <$> findCoercion b a)
    <|> incomparable a b
  where
    -- The only pair with no coercion in either direction but a common
    -- upper bound.
    incomparable :: Sing x -> Sing y -> Maybe (Lub x y)
    incomparable SInt SProb = Just $ Lub SReal continuous signed
    incomparable SProb SInt = Just $ Lub SReal signed continuous
    incomparable _ _        = Nothing
-- | Some ring @b@, given by its 'HRing' evidence, together with a coercion
-- from @a@ into it.
data SomeRing (a :: Hakaru) where
    SomeRing :: !(HRing b) -> !(Coercion a b) -> SomeRing a
-- | Find a ring that @a@ coerces into.
--
-- @nat@ and @int@ map into @int@; @prob@ and @real@ map into @real@. Every
-- other type yields 'Nothing'.
findRing :: Sing a -> Maybe (SomeRing a)
findRing s =
    case s of
    SNat  -> Just (SomeRing HRing_Int  signed)
    SInt  -> Just (SomeRing HRing_Int  CNil)
    SProb -> Just (SomeRing HRing_Real signed)
    SReal -> Just (SomeRing HRing_Real CNil)
    _     -> Nothing
-- | Some fractional type @b@, given by its 'HFractional' evidence, together
-- with a coercion from @a@ into it.
data SomeFractional (a :: Hakaru) where
    SomeFractional :: !(HFractional b) -> !(Coercion a b) -> SomeFractional a
-- | Find a fractional type that @a@ coerces into.
--
-- @nat@ and @prob@ map into @prob@; @int@ and @real@ map into @real@.
-- Every other type yields 'Nothing'.
findFractional :: Sing a -> Maybe (SomeFractional a)
findFractional s =
    case s of
    SNat  -> Just (SomeFractional HFractional_Prob continuous)
    SInt  -> Just (SomeFractional HFractional_Real continuous)
    SProb -> Just (SomeFractional HFractional_Prob CNil)
    SReal -> Just (SomeFractional HFractional_Real CNil)
    _     -> Nothing
-- | Two coercion chains leaving a shared, existentially hidden source @b@.
-- The first chain ends at @c@ and the second ends at @a@.
data CoerceTo_UnsafeFrom (a :: Hakaru) (c :: Hakaru) where
    CTUF :: !(Coercion b c) -> !(Coercion b a) -> CoerceTo_UnsafeFrom a c

deriving instance Show (CoerceTo_UnsafeFrom a b)
-- | Cancel the longest common prefix of the two chains.
--
-- While both chains start with the same primitive step (according to
-- 'jmEq2'), that step is dropped from both. The pair is returned unchanged
-- once either chain is empty or their first steps differ.
simplifyCTUF :: CoerceTo_UnsafeFrom a c -> CoerceTo_UnsafeFrom a c
simplifyCTUF (CTUF (CCons x xs) (CCons y ys))
    | Just (Refl, Refl) <- jmEq2 x y = simplifyCTUF (CTUF xs ys)
simplifyCTUF done = done
-- | A coercion chain stored back to front: the outermost constructor holds
-- the last step. 'CLin' is the empty chain.
data RevCoercion (a :: Hakaru) (b :: Hakaru) where
    CLin  :: RevCoercion a a
    CSnoc :: !(RevCoercion a b) -> !(PrimCoercion b c) -> RevCoercion a c

deriving instance Show (RevCoercion a b)

-- | Reversed chains also form a category.
--
-- In @later . earlier@, the steps of @later@ are appended after those of
-- @earlier@.
instance Category RevCoercion where
    id = CLin
    later . earlier =
        case later of
        CLin         -> earlier
        CSnoc rest p -> CSnoc (rest . earlier) p
-- | Convert a front-to-back coercion chain into a back-to-front one.
--
-- Steps are moved one at a time onto an accumulator, so this is linear in
-- the length of the chain. The previous version rebuilt the reversed
-- chain for every element, which is quadratic.
toRev :: Coercion a b -> RevCoercion a b
toRev = go CLin
  where
    -- The accumulator holds the steps already consumed, from @a@ to @x@;
    -- the second argument holds the remaining steps, from @x@ to @b@.
    go :: RevCoercion a x -> Coercion x b -> RevCoercion a b
    go acc CNil         = acc
    go acc (CCons p ps) = go (CSnoc acc p) ps
-- | Convert a back-to-front coercion chain into a front-to-back one.
--
-- Steps are peeled off the end and pushed onto the front of an
-- accumulator, so this is linear in the length of the chain. The previous
-- version appended each step to the end of a chain, which is quadratic.
fromRev :: RevCoercion a b -> Coercion a b
fromRev = go CNil
  where
    -- The accumulator holds the steps already converted, from @x@ to @b@;
    -- the second argument holds the remaining steps, from @a@ to @x@.
    go :: Coercion x b -> RevCoercion a x -> Coercion a b
    go acc CLin         = acc
    go acc (CSnoc ps p) = go (CCons p acc) ps
-- | Two coercion chains entering a shared, existentially hidden target @b@.
-- The first chain starts at @c@ and the second starts at @a@.
data UnsafeFrom_CoerceTo (a :: Hakaru) (c :: Hakaru) where
    UFCT :: !(Coercion c b) -> !(Coercion a b) -> UnsafeFrom_CoerceTo a c

deriving instance Show (UnsafeFrom_CoerceTo a b)
-- | The same shape as 'UnsafeFrom_CoerceTo', but with both chains stored
-- back to front, so that their shared final steps are outermost.
data RevUFCT (a :: Hakaru) (c :: Hakaru) where
    RevUFCT :: !(RevCoercion c b) -> !(RevCoercion a b) -> RevUFCT a c
-- | Cancel the longest common suffix of the two chains.
--
-- Both chains are reversed, their common prefix is cancelled with
-- 'simplifyRevUFCT', and the results are reversed back.
simplifyUFCT :: UnsafeFrom_CoerceTo a c -> UnsafeFrom_CoerceTo a c
simplifyUFCT (UFCT lhs rhs) =
    unreverse (simplifyRevUFCT (RevUFCT (toRev lhs) (toRev rhs)))
  where
    unreverse :: RevUFCT p q -> UnsafeFrom_CoerceTo p q
    unreverse (RevUFCT lhs' rhs') = UFCT (fromRev lhs') (fromRev rhs')
-- | Cancel the longest common prefix of two reversed chains, which is the
-- longest common suffix of the chains read front to back.
--
-- While both end with the same primitive step (according to 'jmEq2'),
-- that step is dropped from both. The pair is returned unchanged once
-- either chain is empty or their last steps differ.
simplifyRevUFCT :: RevUFCT a c -> RevUFCT a c
simplifyRevUFCT (RevUFCT (CSnoc xs x) (CSnoc ys y))
    | Just (Refl, Refl) <- jmEq2 x y = simplifyRevUFCT (RevUFCT xs ys)
simplifyRevUFCT done = done
-- | A path from @a@ to @b@ made of coercion chains, each followed either
-- forwards ('Zig') or backwards ('Zag'). 'ZRefl' is the empty path.
data ZigZag (a :: Hakaru) (b :: Hakaru) where
    ZRefl :: ZigZag a a
    Zig   :: !(Coercion a b) -> !(ZigZag b c) -> ZigZag a c
    Zag   :: !(Coercion b a) -> !(ZigZag b c) -> ZigZag a c

deriving instance Show (ZigZag a b)
-- | Simplify a zigzag path.
--
-- The tail is simplified first. The head step is then combined with the
-- head of the simplified tail:
--
-- * two steps in the same direction are fused by composing their chains;
-- * a 'Zig' followed by a 'Zag' meets at a common target, so their common
--   suffix is cancelled with 'simplifyUFCT';
-- * a 'Zag' followed by a 'Zig' leaves from a common source, so their
--   common prefix is cancelled with 'simplifyCTUF'.
--
-- A step whose chain cancels away completely is dropped from the result.
--
-- NOTE(review): when cancellation leaves only one side (for example the
-- @UFCT x CNil@ branch below, which yields @Zig x z@), the surviving step
-- is not fused with the head of @z@. Adjacent same-direction steps may
-- therefore remain; confirm whether callers rely on a fully normalised
-- result.
simplifyZZ :: ZigZag a b -> ZigZag a b
simplifyZZ ZRefl = ZRefl
simplifyZZ (Zig x xs) =
    case simplifyZZ xs of
    ZRefl -> Zig x ZRefl
    -- Forwards then forwards: compose into a single forward chain.
    Zig y z -> Zig (y . x) z
    -- Forwards along x, then backwards along y, through their shared
    -- target: cancel the common suffix of x and y.
    Zag y z ->
        case simplifyUFCT (UFCT x y) of
        UFCT CNil CNil -> z
        UFCT CNil y' -> Zag y' z
        UFCT x' CNil -> Zig x' z
        UFCT x' y' -> Zig x' (Zag y' z)
simplifyZZ (Zag x xs) =
    case simplifyZZ xs of
    ZRefl -> Zag x ZRefl
    -- Backwards then backwards: compose into a single backward chain.
    Zag y z -> Zag (x . y) z
    -- Backwards along x, then forwards along y, out of their shared
    -- source: cancel the common prefix of x and y.
    Zig y z ->
        case simplifyCTUF (CTUF x y) of
        CTUF CNil CNil -> z
        CTUF CNil y' -> Zig y' z
        CTUF x' CNil -> Zag x' z
        CTUF x' y' -> Zag x' (Zig y' z)