{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Analysis.Match (
MatchAcc,
(:~:)(..),
matchPreOpenAcc,
matchPreOpenAfun,
matchOpenExp,
matchOpenFun,
matchPrimFun, matchPrimFun',
matchIdx, matchVar, matchVars, matchArrayR, matchArraysR, matchTypeR, matchShapeR,
matchShapeType, matchIntegralType, matchFloatingType, matchNumType, matchScalarType,
matchLeftHandSide, matchALeftHandSide, matchELeftHandSide, matchSingleType, matchTupR
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Analysis.Hash
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Slice
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Primitive.Vec
import qualified Data.Array.Accelerate.Sugar.Shape as Sugar
import Data.Maybe
import Data.Typeable
import Unsafe.Coerce ( unsafeCoerce )
import System.IO.Unsafe ( unsafePerformIO )
import System.Mem.StableName
import Prelude hiding ( exp )
-- | A comparison function for two terms of type @acc@ over the same array
-- environment. A 'Just' result carries a proof that the result types of the
-- two terms coincide.
type MatchAcc acc = forall aenv a b. acc aenv a -> acc aenv b -> Maybe (a :~: b)
{-# INLINEABLE matchPreOpenAcc #-}
-- | Structural equality of two array computations over the same array
-- environment. When the terms match, their result types are proven equal.
--
-- Nested array terms of type @acc@ are compared with the supplied 'MatchAcc'
-- function. Embedded scalar expressions and functions are compared with
-- 'matchOpenExp' and 'matchOpenFun'. 'Use' arrays and foreign functions are
-- compared by heap identity ('StableName's), not by content. Any pair of
-- terms not covered by a case below is reported as 'Nothing'.
matchPreOpenAcc
    :: forall acc aenv s t. HasArraysR acc
    => MatchAcc acc
    -> PreOpenAcc acc aenv s
    -> PreOpenAcc acc aenv t
    -> Maybe (s :~: t)
matchPreOpenAcc matchAcc = match
  where
    matchFun :: OpenFun env' aenv' u -> OpenFun env' aenv' v -> Maybe (u :~: v)
    matchFun = matchOpenFun

    matchExp :: OpenExp env' aenv' u -> OpenExp env' aenv' v -> Maybe (u :~: v)
    matchExp = matchOpenExp

    match :: PreOpenAcc acc aenv s -> PreOpenAcc acc aenv t -> Maybe (s :~: t)
    match (Alet lhs1 x1 a1) (Alet lhs2 x2 a2)
      | Just Refl <- matchALeftHandSide lhs1 lhs2
      , Just Refl <- matchAcc x1 x2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (Avar v1) (Avar v2)
      = matchVar v1 v2

    match (Apair a1 a2) (Apair b1 b2)
      | Just Refl <- matchAcc a1 b1
      , Just Refl <- matchAcc a2 b2
      = Just Refl

    match Anil Anil
      = Just Refl

    match (Apply _ f1 a1) (Apply _ f2 a2)
      | Just Refl <- matchPreOpenAfun matchAcc f1 f2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    -- Foreign implementations are compared by heap identity. 'eqStableName'
    -- is used rather than comparing 'hashStableName' values: the hash is not
    -- guaranteed to be unique, and a collision would wrongly identify two
    -- different foreign functions (and hence produce an unjustified Refl).
    match (Aforeign _ ff1 f1 a1) (Aforeign _ ff2 f2 a2)
      | Just Refl <- matchAcc a1 a2
      , unsafePerformIO $ do
          sn1 <- makeStableName ff1
          sn2 <- makeStableName ff2
          return $! eqStableName sn1 sn2
      , Just Refl <- matchPreOpenAfun matchAcc f1 f2
      = Just Refl

    match (Acond p1 t1 e1) (Acond p2 t2 e2)
      | Just Refl <- matchExp p1 p2
      , Just Refl <- matchAcc t1 t2
      , Just Refl <- matchAcc e1 e2
      = Just Refl

    match (Awhile p1 f1 a1) (Awhile p2 f2 a2)
      | Just Refl <- matchAcc a1 a2
      , Just Refl <- matchPreOpenAfun matchAcc p1 p2
      , Just Refl <- matchPreOpenAfun matchAcc f1 f2
      = Just Refl

    -- Embedded arrays match only when they are the very same array data.
    match (Use repr1 a1) (Use repr2 a2)
      | Just Refl <- matchArray repr1 repr2 a1 a2
      = Just Refl

    match (Unit t1 e1) (Unit t2 e2)
      | Just Refl <- matchTypeR t1 t2
      , Just Refl <- matchExp e1 e2
      = Just Refl

    match (Reshape _ sh1 a1) (Reshape _ sh2 a2)
      | Just Refl <- matchExp sh1 sh2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (Generate _ sh1 f1) (Generate _ sh2 f2)
      | Just Refl <- matchExp sh1 sh2
      , Just Refl <- matchFun f1 f2
      = Just Refl

    match (Transform _ sh1 ix1 f1 a1) (Transform _ sh2 ix2 f2 a2)
      | Just Refl <- matchExp sh1 sh2
      , Just Refl <- matchFun ix1 ix2
      , Just Refl <- matchFun f1 f2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (Replicate si1 ix1 a1) (Replicate si2 ix2 a2)
      | Just Refl <- matchSliceIndex si1 si2
      , Just Refl <- matchExp ix1 ix2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (Slice si1 a1 ix1) (Slice si2 a2 ix2)
      | Just Refl <- matchSliceIndex si1 si2
      , Just Refl <- matchAcc a1 a2
      , Just Refl <- matchExp ix1 ix2
      = Just Refl

    match (Map _ f1 a1) (Map _ f2 a2)
      | Just Refl <- matchFun f1 f2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (ZipWith _ f1 a1 b1) (ZipWith _ f2 a2 b2)
      | Just Refl <- matchFun f1 f2
      , Just Refl <- matchAcc a1 a2
      , Just Refl <- matchAcc b1 b2
      = Just Refl

    -- The optional initial values must both be absent, or both present and
    -- equal.
    match (Fold f1 z1 a1) (Fold f2 z2 a2)
      | Just Refl <- matchFun f1 f2
      , matchMaybe matchExp z1 z2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (FoldSeg _ f1 z1 a1 s1) (FoldSeg _ f2 z2 a2 s2)
      | Just Refl <- matchFun f1 f2
      , matchMaybe matchExp z1 z2
      , Just Refl <- matchAcc a1 a2
      , Just Refl <- matchAcc s1 s2
      = Just Refl

    -- Scans additionally require the same direction.
    match (Scan d1 f1 z1 a1) (Scan d2 f2 z2 a2)
      | d1 == d2
      , Just Refl <- matchFun f1 f2
      , matchMaybe matchExp z1 z2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (Scan' d1 f1 z1 a1) (Scan' d2 f2 z2 a2)
      | d1 == d2
      , Just Refl <- matchFun f1 f2
      , Just Refl <- matchExp z1 z2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (Permute f1 d1 p1 a1) (Permute f2 d2 p2 a2)
      | Just Refl <- matchFun f1 f2
      , Just Refl <- matchAcc d1 d2
      , Just Refl <- matchFun p1 p2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    match (Backpermute _ sh1 ix1 a1) (Backpermute _ sh2 ix2 a2)
      | Just Refl <- matchExp sh1 sh2
      , Just Refl <- matchFun ix1 ix2
      , Just Refl <- matchAcc a1 a2
      = Just Refl

    -- Boundary conditions are compared at the element type taken from the
    -- first term's stencil representation.
    match (Stencil s1 _ f1 b1 a1) (Stencil _ _ f2 b2 a2)
      | Just Refl <- matchFun f1 f2
      , Just Refl <- matchAcc a1 a2
      , matchBoundary (stencilEltR s1) b1 b2
      = Just Refl

    match (Stencil2 s1 s2 _ f1 b1 a1 b2 a2) (Stencil2 _ _ _ f2 b1' a1' b2' a2')
      | Just Refl <- matchFun f1 f2
      , Just Refl <- matchAcc a1 a1'
      , Just Refl <- matchAcc a2 a2'
      , matchBoundary (stencilEltR s1) b1 b1'
      , matchBoundary (stencilEltR s2) b2 b2'
      = Just Refl

    match _ _
      = Nothing
{-# INLINEABLE matchPreOpenAfun #-}
-- | Structural equality of two array-level functions. Binders are compared
-- pairwise with 'matchALeftHandSide'. Once every binder has matched, the
-- bodies are compared with the supplied 'MatchAcc' function. Functions of
-- different arity never match.
matchPreOpenAfun
    :: MatchAcc acc
    -> PreOpenAfun acc aenv s
    -> PreOpenAfun acc aenv t
    -> Maybe (s :~: t)
matchPreOpenAfun cmp (Abody body1) (Abody body2) = cmp body1 body2
matchPreOpenAfun cmp (Alam lhs1 rest1) (Alam lhs2 rest2)
  | Just Refl <- matchALeftHandSide lhs1 lhs2
  , Just Refl <- matchPreOpenAfun cmp rest1 rest2
  = Just Refl
matchPreOpenAfun _ _ _ = Nothing
-- | Equality of two array-level binding patterns. The individual bound
-- arrays are compared by their 'ArrayR' representation.
matchALeftHandSide
    :: forall aenv aenv1 aenv2 t1 t2.
       ALeftHandSide t1 aenv aenv1
    -> ALeftHandSide t2 aenv aenv2
    -> Maybe (ALeftHandSide t1 aenv aenv1 :~: ALeftHandSide t2 aenv aenv2)
matchALeftHandSide lhsA lhsB = matchLeftHandSide matchArrayR lhsA lhsB
-- | Equality of two scalar-level binding patterns. The individual bound
-- values are compared by their 'ScalarType'.
matchELeftHandSide
    :: forall env env1 env2 t1 t2.
       ELeftHandSide t1 env env1
    -> ELeftHandSide t2 env env2
    -> Maybe (ELeftHandSide t1 env env1 :~: ELeftHandSide t2 env env2)
matchELeftHandSide lhsA lhsB = matchLeftHandSide matchScalarType lhsA lhsB
-- | Structural equality of two binding patterns that extend the same
-- environment. The supplied function decides equality of leaf types, both
-- for single binders and inside wildcard tuples. A successful match proves
-- that the bound types and the extended environments coincide.
matchLeftHandSide
    :: forall s env env1 env2 t1 t2.
       (forall x y. s x -> s y -> Maybe (x :~: y))
    -> LeftHandSide s t1 env env1
    -> LeftHandSide s t2 env env2
    -> Maybe (LeftHandSide s t1 env env1 :~: LeftHandSide s t2 env env2)
matchLeftHandSide leafEq (LeftHandSidePair left1 right1) (LeftHandSidePair left2 right2)
  -- The left halves must be matched first: their equality is what makes the
  -- intermediate environments, and hence the right halves, comparable.
  | Just Refl <- matchLeftHandSide leafEq left1 left2
  , Just Refl <- matchLeftHandSide leafEq right1 right2
  = Just Refl
matchLeftHandSide leafEq (LeftHandSideSingle leaf1) (LeftHandSideSingle leaf2)
  | Just Refl <- leafEq leaf1 leaf2
  = Just Refl
matchLeftHandSide leafEq (LeftHandSideWildcard tup1) (LeftHandSideWildcard tup2)
  | Just Refl <- matchTupR leafEq tup1 tup2
  = Just Refl
matchLeftHandSide _ _ _ = Nothing
-- | Equality of two stencil boundary conditions for arrays with element
-- type @t@. Constant boundaries are compared by value with 'matchConst' at
-- the given element representation. Function boundaries are compared
-- structurally.
matchBoundary
    :: TypeR t
    -> Boundary aenv (Array sh t)
    -> Boundary aenv (Array sh t)
    -> Bool
matchBoundary eltR bnd1 bnd2 =
  case (bnd1, bnd2) of
    (Clamp,      Clamp)      -> True
    (Mirror,     Mirror)     -> True
    (Wrap,       Wrap)       -> True
    (Constant x, Constant y) -> matchConst eltR x y
    (Function f, Function g) ->
      case matchOpenFun f g of
        Just Refl -> True
        Nothing   -> False
    _                        -> False
-- | Lift a type-equality test to optional values. Two 'Nothing's match. Two
-- 'Just's match when the supplied test succeeds on their contents. A 'Just'
-- never matches a 'Nothing'.
matchMaybe :: (s1 -> s2 -> Maybe (t1 :~: t2)) -> Maybe s1 -> Maybe s2 -> Bool
matchMaybe eq mx my =
  case (mx, my) of
    (Nothing, Nothing) -> True
    (Just x,  Just y)  ->
      case eq x y of
        Just Refl -> True
        Nothing   -> False
    _                  -> False
-- | Identity of two embedded arrays. They match only when their array
-- representations agree and both refer to the very same array data on the
-- heap, as judged by 'StableName's. Arrays with equal contents but distinct
-- data objects are not identified.
matchArray :: ArrayR (Array sh1 e1)
           -> ArrayR (Array sh2 e2)
           -> Array sh1 e1
           -> Array sh2 e2
           -> Maybe (Array sh1 e1 :~: Array sh2 e2)
matchArray repr1 repr2 (Array _ ad1) (Array _ ad2)
  | Just Refl <- matchArrayR repr1 repr2
  , unsafePerformIO $ do
      sn1 <- makeStableName ad1
      sn2 <- makeStableName ad2
      -- Comparing 'hashStableName' values is not a sound identity test,
      -- because the hash is not guaranteed to be unique. 'eqStableName' is
      -- the documented identity test and accepts names of differing types.
      return $! eqStableName sn1 sn2
  = Just Refl
matchArray _ _ _ _
  = Nothing
-- | Structural equality of two 'TupR' trees. The shapes of the trees must
-- agree, and corresponding leaves are compared with the supplied function.
matchTupR :: (forall u1 u2. s u1 -> s u2 -> Maybe (u1 :~: u2)) -> TupR s t1 -> TupR s t2 -> Maybe (t1 :~: t2)
matchTupR leafEq lhs rhs =
  case (lhs, rhs) of
    (TupRunit, TupRunit) -> Just Refl
    (TupRsingle a, TupRsingle b) -> leafEq a b
    (TupRpair a1 a2, TupRpair b1 b2)
      | Just Refl <- matchTupR leafEq a1 b1
      , Just Refl <- matchTupR leafEq a2 b2
      -> Just Refl
    _ -> Nothing
-- | Equality of two tuples of array representations, compared leaf by leaf.
matchArraysR :: ArraysR s -> ArraysR t -> Maybe (s :~: t)
matchArraysR reprs1 reprs2 = matchTupR matchArrayR reprs1 reprs2
-- | Two array representations are equal when both their shape
-- representation and their element representation are equal.
matchArrayR :: ArrayR s -> ArrayR t -> Maybe (s :~: t)
matchArrayR (ArrayR shrA eltA) (ArrayR shrB eltB)
  | Just Refl <- matchShapeR shrA shrB
  , Just Refl <- matchTypeR eltA eltB
  = Just Refl
  | otherwise
  = Nothing
{-# INLINEABLE matchOpenExp #-}
-- | Structural equality of two scalar expressions over the same scalar and
-- array environments. When the expressions match, their result types are
-- proven equal.
--
-- For applications of the primitives listed in 'commutes', the arguments are
-- first compared after being put into a hash-based canonical order, and then
-- as written. This lets, for example, @x + y@ and @y + x@ be identified.
-- Foreign functions are compared by heap identity ('StableName's). Any pair
-- of expressions not covered by a case below is reported as 'Nothing'.
matchOpenExp
    :: forall env aenv s t.
       OpenExp env aenv s
    -> OpenExp env aenv t
    -> Maybe (s :~: t)
matchOpenExp (Let lhs1 x1 e1) (Let lhs2 x2 e2)
  | Just Refl <- matchELeftHandSide lhs1 lhs2
  , Just Refl <- matchOpenExp x1 x2
  , Just Refl <- matchOpenExp e1 e2
  = Just Refl

matchOpenExp (Evar v1) (Evar v2)
  = matchVar v1 v2

-- Foreign implementations are compared by heap identity. 'eqStableName' is
-- used instead of comparing 'hashStableName' values: the hash is not
-- guaranteed to be unique, and a collision would wrongly identify two
-- different foreign functions (and hence produce an unjustified Refl).
matchOpenExp (Foreign _ ff1 f1 e1) (Foreign _ ff2 f2 e2)
  | Just Refl <- matchOpenExp e1 e2
  , unsafePerformIO $ do
      sn1 <- makeStableName ff1
      sn2 <- makeStableName ff2
      return $! eqStableName sn1 sn2
  , Just Refl <- matchOpenFun f1 f2
  = Just Refl

-- Constants match when their types agree and 'matchConst' judges their
-- values equal.
matchOpenExp (Const t1 c1) (Const t2 c2)
  | Just Refl <- matchScalarType t1 t2
  , matchConst (TupRsingle t1) c1 c2
  = Just Refl

matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2

matchOpenExp (Coerce _ t1 e1) (Coerce _ t2 e2)
  | Just Refl <- matchScalarType t1 t2
  , Just Refl <- matchOpenExp e1 e2
  = Just Refl

matchOpenExp (Pair a1 b1) (Pair a2 b2)
  | Just Refl <- matchOpenExp a1 a2
  , Just Refl <- matchOpenExp b1 b2
  = Just Refl

matchOpenExp Nil Nil
  = Just Refl

matchOpenExp (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2)
  | Just Refl <- matchOpenExp ix1 ix2
  , Just Refl <- matchOpenExp sh1 sh2
  , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2
  = Just Refl

matchOpenExp (IndexFull sliceIndex1 ix1 sl1) (IndexFull sliceIndex2 ix2 sl2)
  | Just Refl <- matchOpenExp ix1 ix2
  , Just Refl <- matchOpenExp sl1 sl2
  , Just Refl <- matchSliceIndex sliceIndex1 sliceIndex2
  = Just Refl

matchOpenExp (ToIndex _ sh1 i1) (ToIndex _ sh2 i2)
  | Just Refl <- matchOpenExp sh1 sh2
  , Just Refl <- matchOpenExp i1 i2
  = Just Refl

matchOpenExp (FromIndex _ sh1 i1) (FromIndex _ sh2 i2)
  | Just Refl <- matchOpenExp i1 i2
  , Just Refl <- matchOpenExp sh1 sh2
  = Just Refl

matchOpenExp (Cond p1 t1 e1) (Cond p2 t2 e2)
  | Just Refl <- matchOpenExp p1 p2
  , Just Refl <- matchOpenExp t1 t2
  , Just Refl <- matchOpenExp e1 e2
  = Just Refl

matchOpenExp (While p1 f1 x1) (While p2 f2 x2)
  | Just Refl <- matchOpenExp x1 x2
  , Just Refl <- matchOpenFun p1 p2
  , Just Refl <- matchOpenFun f1 f2
  = Just Refl

matchOpenExp (PrimConst c1) (PrimConst c2)
  = matchPrimConst c1 c2

-- First try with the arguments of commutative primitives put into canonical
-- order. If that fails, or the primitive is not commutative, compare the
-- arguments as written.
matchOpenExp (PrimApp f1 x1) (PrimApp f2 x2)
  | Just x1' <- commutes f1 x1
  , Just x2' <- commutes f2 x2
  , Just Refl <- matchOpenExp x1' x2'
  , Just Refl <- matchPrimFun f1 f2
  = Just Refl

  | Just Refl <- matchOpenExp x1 x2
  , Just Refl <- matchPrimFun f1 f2
  = Just Refl

matchOpenExp (Index a1 x1) (Index a2 x2)
  | Just Refl <- matchVar a1 a2
  , Just Refl <- matchOpenExp x1 x2
  = Just Refl

matchOpenExp (LinearIndex a1 x1) (LinearIndex a2 x2)
  | Just Refl <- matchVar a1 a2
  , Just Refl <- matchOpenExp x1 x2
  = Just Refl

matchOpenExp (Shape a1) (Shape a2)
  | Just Refl <- matchVar a1 a2
  = Just Refl

matchOpenExp (ShapeSize _ sh1) (ShapeSize _ sh2)
  | Just Refl <- matchOpenExp sh1 sh2
  = Just Refl

matchOpenExp _ _
  = Nothing
{-# INLINEABLE matchOpenFun #-}
-- | Structural equality of two scalar functions. Binders are compared
-- pairwise with 'matchELeftHandSide', then the bodies with 'matchOpenExp'.
-- Functions of different arity never match.
matchOpenFun
    :: OpenFun env aenv s
    -> OpenFun env aenv t
    -> Maybe (s :~: t)
matchOpenFun (Body body1) (Body body2) = matchOpenExp body1 body2
matchOpenFun (Lam lhs1 rest1) (Lam lhs2 rest2)
  | Just Refl <- matchELeftHandSide lhs1 lhs2
  , Just Refl <- matchOpenFun rest1 rest2
  = Just Refl
matchOpenFun _ _ = Nothing
-- | Value equality of two constants of the same (tuple) representation.
-- Unit values always match, pairs match component-wise, and scalar leaves are
-- compared with 'evalEq'.
matchConst :: TypeR a -> a -> a -> Bool
matchConst tp x y =
  case tp of
    TupRunit ->
      case (x, y) of
        ((), ()) -> True
    TupRsingle st ->
      evalEq st (x, y)
    TupRpair tpL tpR ->
      case (x, y) of
        ((xl, xr), (yl, yr)) -> matchConst tpL xl yl && matchConst tpR xr yr
-- | Value equality at a scalar type, dispatching on the type witness.
evalEq :: ScalarType a -> (a, a) -> Bool
evalEq st =
  case st of
    SingleScalarType single -> evalEqSingle single
    VectorScalarType vector -> evalEqVector vector
-- | Value equality at a single (non-vector) scalar type.
evalEqSingle :: SingleType a -> (a, a) -> Bool
evalEqSingle st =
  case st of
    NumSingleType numT -> evalEqNum numT
-- | Value equality at a vector type, using the 'Eq' instance of the vector
-- representation. The 'VectorType' constructor is matched so that whatever
-- it brings into scope is available to that instance.
evalEqVector :: VectorType a -> (a, a) -> Bool
evalEqVector vt =
  case vt of
    VectorType{} -> uncurry (==)
-- | Value equality at a numeric type, as used to decide whether two
-- constants may be identified.
--
-- Integral values are compared with '(==)'. Floating-point values must
-- additionally agree in the sign of zero. Under IEEE 754, @0.0 == -0.0@
-- holds, but the two are observably different (e.g. @recip@ yields
-- @+Infinity@ versus @-Infinity@), so treating them as equal would let two
-- distinct constants be identified. NaN still never equals NaN here. That is
-- conservative: such constants are simply never identified.
evalEqNum :: NumType a -> (a, a) -> Bool
evalEqNum (IntegralNumType t) | IntegralDict <- integralDict t = uncurry (==)
evalEqNum (FloatingNumType t) | FloatingDict <- floatingDict t =
  \(x, y) -> x == y && isNegativeZero x == isNegativeZero y
{-# INLINEABLE matchIdx #-}
-- | Compare two de Bruijn indices into the same environment. Equal indices
-- select the same binding, so their types are proven equal.
matchIdx :: Idx env s -> Idx env t -> Maybe (s :~: t)
matchIdx i j =
  case (i, j) of
    (ZeroIdx,    ZeroIdx)    -> Just Refl
    (SuccIdx i', SuccIdx j') -> matchIdx i' j'
    _                        -> Nothing
{-# INLINEABLE matchVar #-}
-- | Two variables match exactly when their indices match. The type witness
-- stored alongside each index is not consulted.
matchVar :: Var s env t1 -> Var s env t2 -> Maybe (t1 :~: t2)
matchVar v1 v2 =
  case (v1, v2) of
    (Var _ ix1, Var _ ix2) -> matchIdx ix1 ix2
{-# INLINEABLE matchVars #-}
-- | Pointwise equality of two tuples of variables. The tuple shapes must
-- agree and corresponding variables must match via 'matchVar'.
matchVars :: Vars s env t1 -> Vars s env t2 -> Maybe (t1 :~: t2)
matchVars vs ws =
  case (vs, ws) of
    (TupRunit, TupRunit) -> Just Refl
    (TupRsingle v, TupRsingle w)
      | Just Refl <- matchVar v w
      -> Just Refl
    (TupRpair v1 v2, TupRpair w1 w2)
      | Just Refl <- matchVars v1 w1
      , Just Refl <- matchVars v2 w2
      -> Just Refl
    _ -> Nothing
-- | Structural equality of two slice specifications: the same sequence of
-- 'SliceAll' and 'SliceFixed' steps ending in 'SliceNil'. Success proves all
-- four type indices equal.
matchSliceIndex :: SliceIndex slix1 sl1 co1 sh1 -> SliceIndex slix2 sl2 co2 sh2 -> Maybe (SliceIndex slix1 sl1 co1 sh1 :~: SliceIndex slix2 sl2 co2 sh2)
matchSliceIndex si1 si2 =
  case (si1, si2) of
    (SliceNil, SliceNil) -> Just Refl
    (SliceAll rest1, SliceAll rest2)
      | Just Refl <- matchSliceIndex rest1 rest2
      -> Just Refl
    (SliceFixed rest1, SliceFixed rest2)
      | Just Refl <- matchSliceIndex rest1 rest2
      -> Just Refl
    _ -> Nothing
-- | Equality of two primitive constants: the same constant at the same type.
matchPrimConst :: PrimConst s -> PrimConst t -> Maybe (s :~: t)
matchPrimConst c1 c2 =
  case (c1, c2) of
    (PrimMinBound b1, PrimMinBound b2) -> matchBoundedType b1 b2
    (PrimMaxBound b1, PrimMaxBound b2) -> matchBoundedType b1 b2
    (PrimPi fl1,      PrimPi fl2)      -> matchFloatingType fl1 fl2
    _                                  -> Nothing
{-# INLINEABLE matchPrimFun #-}
-- | Compare two primitive functions that take arguments of the same type.
-- On a match, their result types are proven equal.
--
-- For most primitives, the result type is determined by the argument type
-- (the 'Just Refl' equations below rely on this), so identical constructors
-- suffice and the embedded type witnesses are ignored. The rounding
-- conversions, 'PrimFromIntegral' and 'PrimToFloating' carry a separate
-- witness for the result type, which is compared explicitly.
matchPrimFun :: PrimFun (a -> s) -> PrimFun (a -> t) -> Maybe (s :~: t)
matchPrimFun (PrimAdd _) (PrimAdd _) = Just Refl
matchPrimFun (PrimSub _) (PrimSub _) = Just Refl
matchPrimFun (PrimMul _) (PrimMul _) = Just Refl
matchPrimFun (PrimNeg _) (PrimNeg _) = Just Refl
matchPrimFun (PrimAbs _) (PrimAbs _) = Just Refl
matchPrimFun (PrimSig _) (PrimSig _) = Just Refl
matchPrimFun (PrimQuot _) (PrimQuot _) = Just Refl
matchPrimFun (PrimRem _) (PrimRem _) = Just Refl
matchPrimFun (PrimQuotRem _) (PrimQuotRem _) = Just Refl
matchPrimFun (PrimIDiv _) (PrimIDiv _) = Just Refl
matchPrimFun (PrimMod _) (PrimMod _) = Just Refl
matchPrimFun (PrimDivMod _) (PrimDivMod _) = Just Refl
matchPrimFun (PrimBAnd _) (PrimBAnd _) = Just Refl
matchPrimFun (PrimBOr _) (PrimBOr _) = Just Refl
matchPrimFun (PrimBXor _) (PrimBXor _) = Just Refl
matchPrimFun (PrimBNot _) (PrimBNot _) = Just Refl
matchPrimFun (PrimBShiftL _) (PrimBShiftL _) = Just Refl
matchPrimFun (PrimBShiftR _) (PrimBShiftR _) = Just Refl
matchPrimFun (PrimBRotateL _) (PrimBRotateL _) = Just Refl
matchPrimFun (PrimBRotateR _) (PrimBRotateR _) = Just Refl
matchPrimFun (PrimPopCount _) (PrimPopCount _) = Just Refl
matchPrimFun (PrimCountLeadingZeros _) (PrimCountLeadingZeros _) = Just Refl
matchPrimFun (PrimCountTrailingZeros _) (PrimCountTrailingZeros _) = Just Refl
matchPrimFun (PrimFDiv _) (PrimFDiv _) = Just Refl
matchPrimFun (PrimRecip _) (PrimRecip _) = Just Refl
matchPrimFun (PrimSin _) (PrimSin _) = Just Refl
matchPrimFun (PrimCos _) (PrimCos _) = Just Refl
matchPrimFun (PrimTan _) (PrimTan _) = Just Refl
matchPrimFun (PrimAsin _) (PrimAsin _) = Just Refl
matchPrimFun (PrimAcos _) (PrimAcos _) = Just Refl
matchPrimFun (PrimAtan _) (PrimAtan _) = Just Refl
matchPrimFun (PrimSinh _) (PrimSinh _) = Just Refl
matchPrimFun (PrimCosh _) (PrimCosh _) = Just Refl
matchPrimFun (PrimTanh _) (PrimTanh _) = Just Refl
matchPrimFun (PrimAsinh _) (PrimAsinh _) = Just Refl
matchPrimFun (PrimAcosh _) (PrimAcosh _) = Just Refl
matchPrimFun (PrimAtanh _) (PrimAtanh _) = Just Refl
matchPrimFun (PrimExpFloating _) (PrimExpFloating _) = Just Refl
matchPrimFun (PrimSqrt _) (PrimSqrt _) = Just Refl
matchPrimFun (PrimLog _) (PrimLog _) = Just Refl
matchPrimFun (PrimFPow _) (PrimFPow _) = Just Refl
matchPrimFun (PrimLogBase _) (PrimLogBase _) = Just Refl
matchPrimFun (PrimAtan2 _) (PrimAtan2 _) = Just Refl
-- The result type of these conversions is fixed by their second (integral)
-- witness, independently of the argument type, so that witness is compared.
matchPrimFun (PrimTruncate _ s) (PrimTruncate _ t) = matchIntegralType s t
matchPrimFun (PrimRound _ s) (PrimRound _ t) = matchIntegralType s t
matchPrimFun (PrimFloor _ s) (PrimFloor _ t) = matchIntegralType s t
matchPrimFun (PrimCeiling _ s) (PrimCeiling _ t) = matchIntegralType s t
matchPrimFun (PrimIsNaN _) (PrimIsNaN _) = Just Refl
matchPrimFun (PrimIsInfinite _) (PrimIsInfinite _) = Just Refl
matchPrimFun (PrimLt _) (PrimLt _) = Just Refl
matchPrimFun (PrimGt _) (PrimGt _) = Just Refl
matchPrimFun (PrimLtEq _) (PrimLtEq _) = Just Refl
matchPrimFun (PrimGtEq _) (PrimGtEq _) = Just Refl
matchPrimFun (PrimEq _) (PrimEq _) = Just Refl
matchPrimFun (PrimNEq _) (PrimNEq _) = Just Refl
matchPrimFun (PrimMax _) (PrimMax _) = Just Refl
matchPrimFun (PrimMin _) (PrimMin _) = Just Refl
-- Numeric conversions: compare the witness for the target type.
matchPrimFun (PrimFromIntegral _ s) (PrimFromIntegral _ t) = matchNumType s t
matchPrimFun (PrimToFloating _ s) (PrimToFloating _ t) = matchFloatingType s t
matchPrimFun PrimLAnd PrimLAnd = Just Refl
matchPrimFun PrimLOr PrimLOr = Just Refl
matchPrimFun PrimLNot PrimLNot = Just Refl
-- Different primitives never match.
matchPrimFun _ _
  = Nothing
{-# INLINEABLE matchPrimFun' #-}
-- | Compare two primitive functions that produce results of the same type.
-- On a match, their argument types are proven equal.
--
-- Where the argument type is determined by the result type (the
-- 'Just Refl' equations below rely on this), identical constructors suffice.
-- Where it is not, the witness for the argument type is compared
-- explicitly. This applies to bit counting, rounding, NaN/infinity tests,
-- numeric conversions and comparisons.
matchPrimFun' :: PrimFun (s -> a) -> PrimFun (t -> a) -> Maybe (s :~: t)
matchPrimFun' (PrimAdd _) (PrimAdd _) = Just Refl
matchPrimFun' (PrimSub _) (PrimSub _) = Just Refl
matchPrimFun' (PrimMul _) (PrimMul _) = Just Refl
matchPrimFun' (PrimNeg _) (PrimNeg _) = Just Refl
matchPrimFun' (PrimAbs _) (PrimAbs _) = Just Refl
matchPrimFun' (PrimSig _) (PrimSig _) = Just Refl
matchPrimFun' (PrimQuot _) (PrimQuot _) = Just Refl
matchPrimFun' (PrimRem _) (PrimRem _) = Just Refl
matchPrimFun' (PrimQuotRem _) (PrimQuotRem _) = Just Refl
matchPrimFun' (PrimIDiv _) (PrimIDiv _) = Just Refl
matchPrimFun' (PrimMod _) (PrimMod _) = Just Refl
matchPrimFun' (PrimDivMod _) (PrimDivMod _) = Just Refl
matchPrimFun' (PrimBAnd _) (PrimBAnd _) = Just Refl
matchPrimFun' (PrimBOr _) (PrimBOr _) = Just Refl
matchPrimFun' (PrimBXor _) (PrimBXor _) = Just Refl
matchPrimFun' (PrimBNot _) (PrimBNot _) = Just Refl
matchPrimFun' (PrimBShiftL _) (PrimBShiftL _) = Just Refl
matchPrimFun' (PrimBShiftR _) (PrimBShiftR _) = Just Refl
matchPrimFun' (PrimBRotateL _) (PrimBRotateL _) = Just Refl
matchPrimFun' (PrimBRotateR _) (PrimBRotateR _) = Just Refl
-- The argument of the bit-counting primitives is not fixed by the result
-- type, so the integral witness is compared.
matchPrimFun' (PrimPopCount s) (PrimPopCount t) = matchIntegralType s t
matchPrimFun' (PrimCountLeadingZeros s) (PrimCountLeadingZeros t) = matchIntegralType s t
matchPrimFun' (PrimCountTrailingZeros s) (PrimCountTrailingZeros t) = matchIntegralType s t
matchPrimFun' (PrimFDiv _) (PrimFDiv _) = Just Refl
matchPrimFun' (PrimRecip _) (PrimRecip _) = Just Refl
matchPrimFun' (PrimSin _) (PrimSin _) = Just Refl
matchPrimFun' (PrimCos _) (PrimCos _) = Just Refl
matchPrimFun' (PrimTan _) (PrimTan _) = Just Refl
matchPrimFun' (PrimAsin _) (PrimAsin _) = Just Refl
matchPrimFun' (PrimAcos _) (PrimAcos _) = Just Refl
matchPrimFun' (PrimAtan _) (PrimAtan _) = Just Refl
matchPrimFun' (PrimSinh _) (PrimSinh _) = Just Refl
matchPrimFun' (PrimCosh _) (PrimCosh _) = Just Refl
matchPrimFun' (PrimTanh _) (PrimTanh _) = Just Refl
matchPrimFun' (PrimAsinh _) (PrimAsinh _) = Just Refl
matchPrimFun' (PrimAcosh _) (PrimAcosh _) = Just Refl
matchPrimFun' (PrimAtanh _) (PrimAtanh _) = Just Refl
matchPrimFun' (PrimExpFloating _) (PrimExpFloating _) = Just Refl
matchPrimFun' (PrimSqrt _) (PrimSqrt _) = Just Refl
matchPrimFun' (PrimLog _) (PrimLog _) = Just Refl
matchPrimFun' (PrimFPow _) (PrimFPow _) = Just Refl
matchPrimFun' (PrimLogBase _) (PrimLogBase _) = Just Refl
matchPrimFun' (PrimAtan2 _) (PrimAtan2 _) = Just Refl
-- Rounding and floating-point predicates: compare the floating-point
-- witness for the argument type.
matchPrimFun' (PrimTruncate s _) (PrimTruncate t _) = matchFloatingType s t
matchPrimFun' (PrimRound s _) (PrimRound t _) = matchFloatingType s t
matchPrimFun' (PrimFloor s _) (PrimFloor t _) = matchFloatingType s t
matchPrimFun' (PrimCeiling s _) (PrimCeiling t _) = matchFloatingType s t
matchPrimFun' (PrimIsNaN s) (PrimIsNaN t) = matchFloatingType s t
matchPrimFun' (PrimIsInfinite s) (PrimIsInfinite t) = matchFloatingType s t
matchPrimFun' (PrimMax _) (PrimMax _) = Just Refl
matchPrimFun' (PrimMin _) (PrimMin _) = Just Refl
-- Numeric conversions: compare the witness for the source type.
matchPrimFun' (PrimFromIntegral s _) (PrimFromIntegral t _) = matchIntegralType s t
matchPrimFun' (PrimToFloating s _) (PrimToFloating t _) = matchNumType s t
matchPrimFun' PrimLAnd PrimLAnd = Just Refl
matchPrimFun' PrimLOr PrimLOr = Just Refl
matchPrimFun' PrimLNot PrimLNot = Just Refl
-- Comparisons share one result type for every operand type, so the operand
-- witness must be compared.
matchPrimFun' (PrimLt s) (PrimLt t)
  | Just Refl <- matchSingleType s t
  = Just Refl
matchPrimFun' (PrimGt s) (PrimGt t)
  | Just Refl <- matchSingleType s t
  = Just Refl
matchPrimFun' (PrimLtEq s) (PrimLtEq t)
  | Just Refl <- matchSingleType s t
  = Just Refl
matchPrimFun' (PrimGtEq s) (PrimGtEq t)
  | Just Refl <- matchSingleType s t
  = Just Refl
matchPrimFun' (PrimEq s) (PrimEq t)
  | Just Refl <- matchSingleType s t
  = Just Refl
matchPrimFun' (PrimNEq s) (PrimNEq t)
  | Just Refl <- matchSingleType s t
  = Just Refl
-- Different primitives never match.
matchPrimFun' _ _
  = Nothing
{-# INLINEABLE matchTypeR #-}
-- | Equality of two (tuple) element representations, compared leaf by leaf
-- with 'matchScalarType'.
matchTypeR :: TypeR s -> TypeR t -> Maybe (s :~: t)
matchTypeR tp1 tp2 = matchTupR matchScalarType tp1 tp2
{-# INLINEABLE matchShapeType #-}
-- | Decide whether two surface shape types, given only by their
-- 'Sugar.Shape' instances, are equal. This is done by comparing their shape
-- representations.
--
-- Equality of the representations is then promoted to equality of the
-- surface types. With ACCELERATE_INTERNAL_CHECKS defined, this is confirmed
-- at runtime with 'gcast'. Otherwise it is simply asserted with
-- 'unsafeCoerce'.
-- NOTE(review): the unchecked path assumes that equal shape representations
-- imply equal surface shape types for all Shape instances; confirm.
matchShapeType :: forall s t. (Sugar.Shape s, Sugar.Shape t) => Maybe (s :~: t)
matchShapeType
  | Just Refl <- matchShapeR (Sugar.shapeR @s) (Sugar.shapeR @t)
#ifdef ACCELERATE_INTERNAL_CHECKS
  = gcast Refl
#else
  = Just (unsafeCoerce Refl)
#endif
  | otherwise
  = Nothing
{-# INLINEABLE matchShapeR #-}
-- | Equality of two shape representations, i.e. same rank.
matchShapeR :: forall s t. ShapeR s -> ShapeR t -> Maybe (s :~: t)
matchShapeR shr1 shr2 =
  case (shr1, shr2) of
    (ShapeRz, ShapeRz) -> Just Refl
    (ShapeRsnoc inner1, ShapeRsnoc inner2)
      | Just Refl <- matchShapeR inner1 inner2
      -> Just Refl
    _ -> Nothing
{-# INLINEABLE matchScalarType #-}
-- | Equality of two scalar type witnesses: both single or both vector, with
-- equal underlying types.
matchScalarType :: ScalarType s -> ScalarType t -> Maybe (s :~: t)
matchScalarType st1 st2 =
  case (st1, st2) of
    (SingleScalarType a, SingleScalarType b) -> matchSingleType a b
    (VectorScalarType a, VectorScalarType b) -> matchVectorType a b
    _                                        -> Nothing
{-# INLINEABLE matchSingleType #-}
-- | Equality of two single (non-vector) scalar type witnesses.
matchSingleType :: SingleType s -> SingleType t -> Maybe (s :~: t)
matchSingleType st1 st2 =
  case (st1, st2) of
    (NumSingleType n1, NumSingleType n2) -> matchNumType n1 n2
{-# INLINEABLE matchVectorType #-}
-- | Equality of two vector types: the same number of lanes and the same
-- element type.
--
-- The lane counts are compared as the 'Int' values stored in the two
-- 'VectorType' constructors. When they agree, equality of the type-level
-- lengths @n@ and @m@ is asserted with 'unsafeCoerce' rather than derived
-- from type-level evidence.
-- NOTE(review): soundness relies on the stored Int always agreeing with the
-- type-level length; confirm at the construction sites of VectorType.
matchVectorType :: forall m n s t. VectorType (Vec n s) -> VectorType (Vec m t) -> Maybe (Vec n s :~: Vec m t)
matchVectorType (VectorType n s) (VectorType m t)
  | Just Refl <- if n == m
                   then Just (unsafeCoerce Refl :: n :~: m)
                   else Nothing
  , Just Refl <- matchSingleType s t
  = Just Refl
matchVectorType _ _
  = Nothing
{-# INLINEABLE matchNumType #-}
-- | Equality of two numeric type witnesses: both integral or both
-- floating, with equal underlying types.
matchNumType :: NumType s -> NumType t -> Maybe (s :~: t)
matchNumType nt1 nt2 =
  case (nt1, nt2) of
    (IntegralNumType i1, IntegralNumType i2) -> matchIntegralType i1 i2
    (FloatingNumType f1, FloatingNumType f2) -> matchFloatingType f1 f2
    _                                        -> Nothing
{-# INLINEABLE matchBoundedType #-}
-- | Equality of two bounded type witnesses, via their integral witnesses.
matchBoundedType :: BoundedType s -> BoundedType t -> Maybe (s :~: t)
matchBoundedType bt1 bt2 =
  case (bt1, bt2) of
    (IntegralBoundedType i1, IntegralBoundedType i2) -> matchIntegralType i1 i2
{-# INLINEABLE matchIntegralType #-}
-- | Equality of two integral type witnesses: they match only when they name
-- the same integral type.
matchIntegralType :: IntegralType s -> IntegralType t -> Maybe (s :~: t)
matchIntegralType it1 it2 =
  case (it1, it2) of
    (TypeInt,    TypeInt)    -> Just Refl
    (TypeInt8,   TypeInt8)   -> Just Refl
    (TypeInt16,  TypeInt16)  -> Just Refl
    (TypeInt32,  TypeInt32)  -> Just Refl
    (TypeInt64,  TypeInt64)  -> Just Refl
    (TypeWord,   TypeWord)   -> Just Refl
    (TypeWord8,  TypeWord8)  -> Just Refl
    (TypeWord16, TypeWord16) -> Just Refl
    (TypeWord32, TypeWord32) -> Just Refl
    (TypeWord64, TypeWord64) -> Just Refl
    _                        -> Nothing
{-# INLINEABLE matchFloatingType #-}
-- | Equality of two floating-point type witnesses: they match only when they
-- name the same floating-point type.
matchFloatingType :: FloatingType s -> FloatingType t -> Maybe (s :~: t)
matchFloatingType ft1 ft2 =
  case (ft1, ft2) of
    (TypeHalf,   TypeHalf)   -> Just Refl
    (TypeFloat,  TypeFloat)  -> Just Refl
    (TypeDouble, TypeDouble) -> Just Refl
    _                        -> Nothing
-- | For a primitive that this module treats as commutative, return its
-- argument with the two components put into a canonical order. The
-- components are swapped when the hash of the first exceeds the hash of the
-- second. An argument that is not syntactically a 'Pair' is returned
-- unchanged. Returns 'Nothing' for every other primitive.
--
-- NOTE(review): PrimMax and PrimMin are listed as commutative; for
-- floating-point operands with NaN or signed zeros this may not hold.
-- Confirm against the backend semantics.
commutes
    :: forall env aenv a r.
       PrimFun (a -> r)
    -> OpenExp env aenv a
    -> Maybe (OpenExp env aenv a)
commutes f x = case f of
  -- Each alternative applies swizzle itself: matching the constructor is
  -- what refines the argument type a to a pair, as swizzle requires.
  PrimAdd{} -> Just (swizzle x)
  PrimMul{} -> Just (swizzle x)
  PrimBAnd{} -> Just (swizzle x)
  PrimBOr{} -> Just (swizzle x)
  PrimBXor{} -> Just (swizzle x)
  PrimEq{} -> Just (swizzle x)
  PrimNEq{} -> Just (swizzle x)
  PrimMax{} -> Just (swizzle x)
  PrimMin{} -> Just (swizzle x)
  PrimLAnd -> Just (swizzle x)
  PrimLOr -> Just (swizzle x)
  _ -> Nothing
  where
    -- Order the two components of a pair by their expression hashes. Equal
    -- hashes leave the order as written.
    swizzle :: OpenExp env aenv (a',a') -> OpenExp env aenv (a',a')
    swizzle exp
      | (a `Pair` b) <- exp
      , hashOpenExp a > hashOpenExp b = b `Pair` a
      | otherwise = exp