{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.AST.Environment
where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Error
-- | A runtime environment of values. The type index records the bindings as
-- nested pairs, innermost last: pushing an 'Int' and then a 'Bool' onto
-- 'Empty' yields a @Val (((), Int), Bool)@.
data Val env where
  -- | The environment with no bindings.
  Empty :: Val ()
  -- | Add one binding on top of an existing environment; it becomes the
  -- entry selected by 'ZeroIdx' in 'prj'.
  Push  :: forall rest a. Val rest -> a -> Val (rest, a)
-- | Bind a value according to a left-hand side, extending the environment.
--
-- * A wildcard binds nothing and returns the environment unchanged.
-- * A single binder pushes the whole value.
-- * A pair destructures the value and binds the left component first, then
--   the right component on top of the result.
push :: Val env -> (LeftHandSide s t env env', t) -> Val env'
push env (lhs, val) =
  case lhs of
    LeftHandSideWildcard _ -> env
    LeftHandSideSingle _   -> Push env val
    LeftHandSidePair l r   ->
      -- a strict match on the pair, exactly like the original tuple pattern
      case val of
        (x, y) -> push (push env (l, x)) (r, y)
-- | Fetch the value stored at a de Bruijn index. 'ZeroIdx' names the most
-- recently pushed value, and each 'SuccIdx' steps one binding outward.
prj :: Idx env t -> Val env -> t
prj ix vals =
  case ix of
    ZeroIdx ->
      case vals of
        Push _ top -> top
    SuccIdx ix' ->
      case vals of
        Push rest _ -> prj ix' rest
-- | A weakening from @env@ to @env'@: a function that sends every index
-- valid in @env@, at any element type, to an index valid in @env'@.
-- Apply it to an index with '>:>'.
newtype env :> env' = Weaken
  { (>:>) :: forall a. Idx env a -> Idx env' a
  }
-- | The identity weakening: every index is returned unchanged.
weakenId :: env :> env
weakenId = Weaken (\ix -> ix)
-- | Extend the target of a weakening by one binding. Each index produced by
-- @k@ is shifted with 'SuccIdx' past a new innermost entry of type @t@.
weakenSucc' :: env :> env' -> env :> (env', t)
weakenSucc' k = Weaken (\ix -> SuccIdx (k >:> ix))
-- | Turn a weakening out of @(env, t)@ into one out of @env@. Each index is
-- first shifted past the @t@ binding with 'SuccIdx', then mapped by @k@.
weakenSucc :: (env, t) :> env' -> env :> env'
weakenSucc k = Weaken (\ix -> k >:> SuccIdx ix)
-- | The empty environment has no indices, so it weakens vacuously into any
-- environment. The empty case alternative is total for @Idx ()@.
weakenEmpty :: () :> env'
weakenEmpty = Weaken (\ix -> case ix of {})
-- | Push a weakening under one extra binding of type @t@ on both sides.
-- Index 0 (the new binding) is preserved. Any other index is unwrapped,
-- weakened with @k@, and shifted back past the binding.
sink :: forall env env' t. env :> env' -> (env, t) :> (env', t)
sink k = Weaken lifted
  where
    -- The explicit signature keeps the GADT match in checking mode.
    lifted :: Idx (env, t) t' -> Idx (env', t) t'
    lifted = \case
      ZeroIdx     -> ZeroIdx
      SuccIdx ix' -> SuccIdx (k >:> ix')
infixr 9 .>
(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
(.>) (Weaken f) (Weaken g) = Weaken (f . g)
-- | Push a weakening under two left-hand sides that must have the same
-- shape. Wildcards leave it unchanged, and single binders 'sink' it once.
-- Pairs handle the left component first, then the right one. If the
-- shapes differ, 'internalError' is raised.
sinkWithLHS
  :: HasCallStack
  => LeftHandSide s t env1 env1'
  -> LeftHandSide s t env2 env2'
  -> env1 :> env2
  -> env1' :> env2'
sinkWithLHS lhsA lhsB k =
  case (lhsA, lhsB) of
    (LeftHandSideWildcard _, LeftHandSideWildcard _) -> k
    (LeftHandSideSingle _,   LeftHandSideSingle _)   -> sink k
    (LeftHandSidePair a1 b1, LeftHandSidePair a2 b2) ->
      sinkWithLHS b1 b2 (sinkWithLHS a1 a2 k)
    _ -> internalError "left hand sides do not match"
-- | The weakening that skips every binding introduced by a left-hand side.
-- Each single binder contributes one 'weakenSucc' step, and wildcards
-- contribute nothing.
weakenWithLHS :: forall s t env env'. LeftHandSide s t env env' -> env :> env'
weakenWithLHS lhs = peel lhs weakenId
  where
    -- Walk the LHS from the outside in. @acc@ already accounts for every
    -- binding to the right of the current position. @s@ and @env'@ are the
    -- scoped variables from the outer signature.
    peel :: LeftHandSide s arrs env1 env2 -> env2 :> env' -> env1 :> env'
    peel (LeftHandSideWildcard _) acc = acc
    peel (LeftHandSideSingle _)   acc = weakenSucc acc
    peel (LeftHandSidePair l r)   acc = peel l (peel r acc)