{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE QuasiQuotes #-}
module Clash.Normalize.PrimitiveReductions where
import qualified Control.Lens as Lens
import Control.Lens ((.=))
import Data.List (mapAccumR)
import qualified Data.Maybe as Maybe
import TextShow (showt)
import PrelNames
(boolTyConKey, typeNatAddTyFamNameKey, typeNatMulTyFamNameKey,
typeNatSubTyFamNameKey)
import Unique (getKey)
import SrcLoc (wiredInSrcSpan)
import Clash.Core.DataCon (DataCon)
import Clash.Core.Literal (Literal (..))
import Clash.Core.Name
(nameOcc, Name(..), NameSort(User), mkUnsafeSystemName)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term
(CoreContext (..), PrimInfo (..), Term (..), WorkInfo (..), Pat (..),
collectTermIds, mkApps)
import Clash.Core.TermInfo
import Clash.Core.Type (LitTy (..), Type (..),
TypeView (..), coreView1,
mkFunTy, mkTyConApp,
splitFunForallTy, tyView)
import Clash.Core.TyCon
(TyConMap, TyConName, tyConDataCons, tyConName)
import Clash.Core.TysPrim
(integerPrimTy, typeNatKind, liftedTypeKind)
import Clash.Core.Util
(appendToVec, extractElems, extractTElems, mkRTree,
mkUniqInternalId, mkUniqSystemTyVar, mkVec, dataConInstArgTys,
primCo, undefinedTm)
import Clash.Core.Var (Var (..), mkTyVar, mkLocalId)
import Clash.Core.VarEnv
(InScopeSet, extendInScopeSetList)
import {-# SOURCE #-} Clash.Normalize.Strategy
import Clash.Normalize.Types
import Clash.Rewrite.Types
import Clash.Rewrite.Util
import Clash.Unique
import Clash.Util
import qualified Clash.Util.Interpolate as I
-- | Wired-in name for GHC's type-level natural addition, @GHC.TypeNats.+@.
--
-- The unique is GHC's own 'typeNatAddTyFamNameKey', so the name built here
-- carries the same key GHC uses for this type family.
typeNatAdd :: TyConName
typeNatAdd = Name User occ uniq wiredInSrcSpan
  where
    occ  = "GHC.TypeNats.+"
    uniq = getKey typeNatAddTyFamNameKey
-- | Wired-in name for GHC's type-level natural multiplication,
-- @GHC.TypeNats.*@.
--
-- The unique is GHC's own 'typeNatMulTyFamNameKey'.
typeNatMul :: TyConName
typeNatMul = Name User occ uniq wiredInSrcSpan
  where
    occ  = "GHC.TypeNats.*"
    uniq = getKey typeNatMulTyFamNameKey
-- | Wired-in name for GHC's type-level natural subtraction,
-- @GHC.TypeNats.-@.
--
-- The unique is GHC's own 'typeNatSubTyFamNameKey'.
typeNatSub :: TyConName
typeNatSub = Name User occ uniq wiredInSrcSpan
  where
    occ  = "GHC.TypeNats.-"
    uniq = getKey typeNatSubTyFamNameKey
-- | The @Clash.Sized.Vector.head@ primitive, typed for the given @Vec@
-- type constructor (see 'vecHeadTy').
vecHeadPrim
  :: TyConName
  -- ^ Name of the @Vec@ type constructor
  -> Term
vecHeadPrim vecTcNm = Prim headInfo
  where
    headInfo = PrimInfo "Clash.Sized.Vector.head" (vecHeadTy vecTcNm) WorkNever
-- | The @Clash.Sized.Vector.last@ primitive. It has the same type as @head@
-- (see 'vecHeadTy').
vecLastPrim
  :: TyConName
  -- ^ Name of the @Vec@ type constructor
  -> Term
vecLastPrim vecTcNm = Prim lastInfo
  where
    lastInfo = PrimInfo "Clash.Sized.Vector.last" (vecHeadTy vecTcNm) WorkNever
-- | The type @forall n a . Vec (n + 1) a -> a@, shared by the @head@ and
-- @last@ primitives.
vecHeadTy
  :: TyConName
  -- ^ Name of the @Vec@ type constructor
  -> Type
vecHeadTy vecNm = ForAllTy nTV (ForAllTy aTV (mkFunTy argTy elTy))
  where
    nTV   = mkTyVar typeNatKind (mkUnsafeSystemName "n" 1)
    aTV   = mkTyVar liftedTypeKind (mkUnsafeSystemName "a" 0)
    elTy  = VarTy aTV
    -- Vec (n + 1) a
    lenTy = mkTyConApp typeNatAdd [VarTy nTV, LitTy (NumTy 1)]
    argTy = mkTyConApp vecNm [lenTy, elTy]
-- | The @Clash.Sized.Vector.tail@ primitive, typed by 'vecTailTy'.
vecTailPrim
  :: TyConName
  -- ^ Name of the @Vec@ type constructor
  -> Term
vecTailPrim vecTcNm = Prim tailInfo
  where
    tailInfo = PrimInfo "Clash.Sized.Vector.tail" (vecTailTy vecTcNm) WorkNever
-- | The @Clash.Sized.Vector.init@ primitive. It has the same type as @tail@
-- (see 'vecTailTy').
vecInitPrim
  :: TyConName
  -- ^ Name of the @Vec@ type constructor
  -> Term
vecInitPrim vecTcNm = Prim initInfo
  where
    initInfo = PrimInfo "Clash.Sized.Vector.init" (vecTailTy vecTcNm) WorkNever
-- | The type @forall n a . Vec (n + 1) a -> Vec n a@, shared by the @tail@
-- and @init@ primitives.
vecTailTy
  :: TyConName
  -- ^ Name of the @Vec@ type constructor
  -> Type
vecTailTy vecNm = ForAllTy nTV (ForAllTy aTV (mkFunTy argTy resTy))
  where
    nTV   = mkTyVar typeNatKind (mkUnsafeSystemName "n" 0)
    aTV   = mkTyVar liftedTypeKind (mkUnsafeSystemName "a" 1)
    elTy  = VarTy aTV
    -- Vec (n + 1) a
    argTy = mkTyConApp vecNm [mkTyConApp typeNatAdd [VarTy nTV, LitTy (NumTy 1)], elTy]
    -- Vec n a
    resTy = mkTyConApp vecNm [VarTy nTV, elTy]
-- | Build two single-alternative @case@ expressions over a vector term that
-- match the @Cons@ constructor. The first returns the head element. The
-- second returns the remaining vector, whose length is @n - 1@.
extractHeadTail
  :: DataCon
  -- ^ The @Cons@ data constructor of @Vec@
  -> Type
  -- ^ Element type
  -> Integer
  -- ^ Length of the vector
  -> Term
  -- ^ The vector to deconstruct
  -> (Term, Term)
extractHeadTail consCon elTy n vec = (select elTy el, select restTy rest)
  where
    -- Instantiate Cons at (length n, element elTy, tail length n-1) to get the
    -- types of its coercion, element and tail fields.
    instTys = [LitTy (NumTy n), elTy, LitTy (NumTy (n-1))]
    Just [coTy, _elTy, restTy] = dataConInstArgTys consCon instTys

    mTV  = mkTyVar typeNatKind (mkUnsafeSystemName "m" 0)
    co   = mkLocalId coTy (mkUnsafeSystemName "_co_" 1)
    el   = mkLocalId elTy (mkUnsafeSystemName "el" 2)
    rest = mkLocalId restTy (mkUnsafeSystemName "res" 3)

    consPat = DataPat consCon [mTV] [co, el, rest]

    -- A case on the vector that returns one of the fields bound by consPat.
    select ty v = Case vec ty [(consPat, Var v)]
-- | The head-projecting half of 'extractHeadTail'.
extractHead
  :: DataCon
  -- ^ The @Cons@ data constructor of @Vec@
  -> Type
  -- ^ Element type
  -> Integer
  -- ^ Length of the vector
  -> Term
  -- ^ The vector
  -> Term
extractHead consCon elTy vLength vec = headTm
  where
    (headTm, _tailTm) = extractHeadTail consCon elTy vLength vec
-- | The tail-projecting half of 'extractHeadTail'.
extractTail
  :: DataCon
  -- ^ The @Cons@ data constructor of @Vec@
  -> Type
  -- ^ Element type
  -> Integer
  -- ^ Length of the vector
  -> Term
  -- ^ The vector
  -> Term
extractTail consCon elTy vLength vec = tailTm
  where
    (_headTm, tailTm) = extractHeadTail consCon elTy vLength vec
-- | Apply the @Cons@ constructor to build a vector of length @n@ from a head
-- element and a tail vector of length @n - 1@.
--
-- Calls 'error' when @n <= 0@, because a non-empty vector cannot have
-- length zero.
mkVecCons
  :: HasCallStack
  => DataCon
  -- ^ The @Cons@ data constructor of @Vec@
  -> Type
  -- ^ Element type
  -> Integer
  -- ^ Length of the resulting vector
  -> Term
  -- ^ Head element
  -> Term
  -- ^ Tail vector (length @n - 1@)
  -> Term
mkVecCons consCon resTy n h t
  | n > 0     = mkApps (Data consCon) (map Right instTys ++ map Left [primCo consCoTy, h, t])
  | otherwise = error "mkVecCons: n <= 0"
  where
    instTys = [LitTy (NumTy n), resTy, LitTy (NumTy (n-1))]
    -- The first field of Cons is the coercion witnessing n ~ (n-1) + 1.
    Just (consCoTy : _) = dataConInstArgTys consCon instTys
-- | Apply the @Nil@ constructor to build an empty vector with the given
-- element type.
mkVecNil
  :: DataCon
  -- ^ The @Nil@ data constructor of @Vec@
  -> Type
  -- ^ Element type
  -> Term
mkVecNil nilCon resTy = mkApps (Data nilCon) (map Right instTys ++ [Left (primCo nilCoTy)])
  where
    instTys = [LitTy (NumTy 0), resTy]
    -- The first field of Nil is its length coercion.
    Just (nilCoTy : _) = dataConInstArgTys nilCon instTys
-- | Unroll @Clash.Sized.Vector.reverse@ for a vector of statically known
-- length.
--
-- Each element of the argument is bound in a @letrec@ via 'extractElems', and
-- a new vector is built from those element variables in reverse order.
reduceReverse
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The vector to reverse
  -> NormalizeSession Term
reduceReverse inScope0 n elTy vArg = do
  tcm <- Lens.view tcCache
  let ty = termType tcm vArg
  go tcm ty
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon, consCon] <- tyConDataCons vecTc
      = do
        uniqs0 <- Lens.use uniqSupply
        let (uniqs1,(vars,elems)) = second (second concat . unzip)
                                  $ extractElems uniqs0 inScope0 consCon elTy 'V' n vArg
            lbody = mkVec nilCon consCon elTy n (reverse vars)
            lb = Letrec (dropLast elems) lbody
        uniqSupply Lens..= uniqs1
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceReverse: argument does not have a vector type: " ++ showPpr ty

    -- Drop the final binding produced by extractElems, as the sibling
    -- reductions do. NOTE(review): presumably this is the binding for the
    -- empty tail, which nothing references. Unlike 'init', this is total:
    -- when there are no bindings (presumably a zero-length vector) it yields
    -- [] instead of crashing.
    dropLast [] = []
    dropLast xs = init xs
-- | Unroll one step of @Clash.Sized.Vector.zipWith@ for a vector of statically
-- known length @n@.
--
-- * @n == 0@: the result is @Nil@ at the result element type.
-- * otherwise: @f a b :> zipWith f as bs@, where the head and tail of each
--   argument are projected with 'extractHeadTail' and the recursive call
--   uses length @n - 1@.
reduceZipWith
  :: TransformContext
  -> PrimInfo
  -- ^ Info of the @zipWith@ primitive, reused for the recursive call
  -> Integer
  -- ^ Length of the vectors
  -> Type
  -- ^ Element type of the left vector
  -> Type
  -- ^ Element type of the right vector
  -> Type
  -- ^ Element type of the result vector
  -> Term
  -- ^ The zipping function
  -> Term
  -- ^ Left vector
  -> Term
  -- ^ Right vector
  -> NormalizeSession Term
reduceZipWith _ctx zipWithPrimInfo n lhsElTy rhsElTy resElTy fun lhsArg rhsArg = do
  tcm <- Lens.view tcCache
  changed (go tcm (termType tcm lhsArg))
  where
    go tcm (coreView1 tcm -> Just ty) = go tcm ty
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon, consCon] <- tyConDataCons vecTc
      = unroll nilCon consCon
    go _ ty =
      error $ $(curLoc) ++ [I.i|
        reduceZipWith: argument does not have a vector type:
        #{showPpr ty}
      |]

    unroll nilCon consCon
      | n == 0    = mkVecNil nilCon resElTy
      | otherwise = mkVecCons consCon resElTy n headTm recCall
      where
        (lhsHead, lhsTail) = extractHeadTail consCon lhsElTy n lhsArg
        (rhsHead, rhsTail) = extractHeadTail consCon rhsElTy n rhsArg
        headTm  = mkApps fun [Left lhsHead, Left rhsHead]
        recCall = mkApps (Prim zipWithPrimInfo)
          [ Right lhsElTy
          , Right rhsElTy
          , Right resElTy
          , Right (LitTy (NumTy (n - 1)))
          , Left fun
          , Left lhsTail
          , Left rhsTail
          ]
-- | Unroll one step of @Clash.Sized.Vector.map@ for a vector of statically
-- known length @n@.
--
-- * @n == 0@: the result is @Nil@ at the /result/ element type.
-- * otherwise: @f a :> map f as@, where the head and tail are projected with
--   'extractHeadTail' and the recursive call uses length @n - 1@.
reduceMap
  :: TransformContext
  -> PrimInfo
  -- ^ Info of the @map@ primitive, reused for the recursive call
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type of the argument vector
  -> Type
  -- ^ Element type of the result vector
  -> Term
  -- ^ The mapped function
  -> Term
  -- ^ The argument vector
  -> NormalizeSession Term
reduceMap _ctx mapPrimInfo n argElTy resElTy fun arg = do
  tcm <- Lens.view tcCache
  let ty = termType tcm arg
  changed (go tcm ty)
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon,consCon] <- tyConDataCons vecTc
      = if n == 0 then
          -- map :: (a -> b) -> Vec n a -> Vec n b, so the empty result is a
          -- Vec 0 of the *result* element type, not the argument's.
          mkVecNil nilCon resElTy
        else
          let
            nPredTy = Right (LitTy (NumTy (n - 1)))
            (a, as) = extractHeadTail consCon argElTy n arg
            b = mkApps fun [Left a]
            bs = mkApps (Prim mapPrimInfo) [ Right argElTy
                                           , Right resElTy
                                           , nPredTy
                                           , Left fun
                                           , Left as ]
          in
            mkVecCons consCon resElTy n b bs
    go _ ty =
      error $ $(curLoc) ++ [I.i|
        reduceMap: argument does not have a vector type:
        #{showPpr ty}
      |]
-- | Fully unroll @Clash.Sized.Vector.imap@ for a vector of statically known
-- length @n@.
--
-- Each element is bound in a @letrec@ via 'extractElems'. The result vector
-- applies the (constant-propagated) function to @fromInteger# k@ and the k-th
-- element, for @k = 0 .. n-1@. The index type constructor is read from the
-- first argument type of @fun@.
reduceImap
  :: TransformContext
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type of the argument vector
  -> Type
  -- ^ Element type of the result vector
  -> Term
  -- ^ The function, taking the index as its first argument
  -> Term
  -- ^ The argument vector
  -> NormalizeSession Term
reduceImap (TransformContext is0 ctx) n argElTy resElTy fun arg = do
  tcm <- Lens.view tcCache
  let ty = termType tcm arg
  go tcm ty
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon,consCon] <- tyConDataCons vecTc
      = do
        uniqs0 <- Lens.use uniqSupply
        fun1 <- constantPropagation (TransformContext is0 (AppArg Nothing:ctx)) fun
        let is1 = extendInScopeSetList is0 (collectTermIds fun1)
            (uniqs1,nTv) = mkUniqSystemTyVar (uniqs0,is1) ("n",typeNatKind)
            (uniqs2,(vars,elems)) = second (second concat . unzip)
                                  $ uncurry extractElems uniqs1 consCon argElTy 'I' n arg
            (Right idxTy:_,_) = splitFunForallTy (termType tcm fun)
            (TyConApp idxTcNm _) = tyView idxTy
            -- fromInteger# :: forall n . Integer -> Integer -> Index n
            idxFromIntegerTy = ForAllTy nTv
                                 (foldr mkFunTy
                                        (mkTyConApp idxTcNm [VarTy nTv])
                                        [integerPrimTy,integerPrimTy])
            idxFromInteger = Prim (PrimInfo "Clash.Sized.Internal.Index.fromInteger#" idxFromIntegerTy WorkNever)
            -- fromInteger# @n n k, for every index k of the vector.
            idxs = map (App (App (TyApp idxFromInteger (LitTy (NumTy n)))
                                 (Literal (IntegerLiteral n)))
                        . Literal . IntegerLiteral) [0..(n-1)]
            funApps = zipWith (\i v -> App (App fun1 i) v) idxs vars
            lbody = mkVec nilCon consCon resElTy n funApps
            lb = Letrec (dropLast elems) lbody
        uniqSupply Lens..= uniqs2
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceImap: argument does not have a vector type: " ++ showPpr ty

    -- Drop the final binding produced by extractElems, as the sibling
    -- reductions do (NOTE(review): presumably the empty-tail binding). Unlike
    -- 'init', this is total when there are no bindings at all.
    dropLast [] = []
    dropLast xs = init xs
-- | Unroll @iterateI@ for a result vector of statically known length @n@.
--
-- The generated term has the shape:
--
-- > let el1 = f a
-- >     el2 = f el1
-- >     ...
-- > in  a :> el1 :> el2 :> ...
--
-- where @f@ is the constant-propagated function and exactly @n@ elements
-- are taken.
--
-- NOTE(review): unlike the sibling reductions, this returns its result with
-- 'pure' rather than 'changed'. Presumably the caller marks the rewrite;
-- confirm.
reduceIterateI
  :: TransformContext
  -> Integer
  -- ^ Length of the result vector
  -> Type
  -- ^ Element type
  -> Type
  -- ^ Type of the result vector (must be a @Vec@ type constructor application)
  -> Term
  -- ^ The iterated function
  -> Term
  -- ^ The initial element
  -> RewriteMonad NormalizeState Term
reduceIterateI (TransformContext is0 ctx) n aTy vTy f0 a = do
  tcm <- Lens.view tcCache
  f1 <- constantPropagation (TransformContext is0 (AppArg Nothing:ctx)) f0
  -- Generate fresh ids el1 .. el(n-1) for the element bindings.
  uniqs0 <- Lens.use uniqSupply
  let
    is1 = extendInScopeSetList is0 (collectTermIds f1)
    ((uniqs1, _is2), elementIds) =
      mapAccumR
        mkUniqInternalId
        (uniqs0, is1)
        (zip (map (("el" <>) . showt) [1..n-1]) (repeat aTy))
  uniqSupply .= uniqs1
  let
    -- Partial: assumes vTy is a Vec with exactly [Nil, Cons] constructors.
    TyConApp vecTcNm _ = tyView vTy
    Just vecTc = lookupUniqMap vecTcNm tcm
    [nilCon, consCon] = tyConDataCons vecTc
    -- el_k = f1 el_(k-1), with el_0 being the initial element a. zip with
    -- elementIds (n-1 ids) keeps only the first n-1 of these applications.
    elems = map (App f1) (a:map Var elementIds)
    vec = mkVec nilCon consCon aTy n (take (fromInteger n) (a:map Var elementIds))
  pure (Letrec (zip elementIds elems) vec)
-- | Fully unroll @traverse#@ over a vector of statically known length @n@.
--
-- The @Applicative@ dictionary is deconstructed into its @Functor@ dictionary,
-- @pure@ and @<*>@ fields. The @Functor@ dictionary is then deconstructed
-- into its @fmap@ field. Each vector element is bound via 'extractElems', and
-- the effectful result vector is assembled by 'mkTravVec' from those bound
-- dictionary members and the applications of the (constant-propagated)
-- function to each element.
reduceTraverse
  :: TransformContext
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type of the argument vector
  -> Type
  -- ^ The applicative functor type
  -> Type
  -- ^ Element type of the result vector
  -> Term
  -- ^ The @Applicative@ dictionary
  -> Term
  -- ^ The traversing function
  -> Term
  -- ^ The argument vector
  -> NormalizeSession Term
reduceTraverse (TransformContext is0 ctx) n aTy fTy bTy dict fun arg = do
  tcm <- Lens.view tcCache
  let (TyConApp apDictTcNm _) = tyView (termType tcm dict)
      ty = termType tcm arg
  go tcm apDictTcNm ty
  where
    go tcm apDictTcNm (coreView1 tcm -> Just ty') = go tcm apDictTcNm ty'
    go tcm apDictTcNm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon,consCon] <- tyConDataCons vecTc
      = do
        uniqs0 <- Lens.use uniqSupply
        fun1 <- constantPropagation (TransformContext is0 (AppArg Nothing:ctx)) fun
        let is1 = extendInScopeSetList is0 (collectTermIds fun1)
            -- Fresh ids for the five fields of the Applicative dictionary.
            (Just apDictTc) = lookupUniqMap apDictTcNm tcm
            [apDictCon] = tyConDataCons apDictTc
            (Just apDictIdTys) = dataConInstArgTys apDictCon [fTy]
            (uniqs1,apDictIds@[functorDictId,pureId,apId,_,_]) =
              mapAccumR mkUniqInternalId (uniqs0,is1)
                (zip ["functorDict","pure","ap","apConstL","apConstR"]
                     apDictIdTys)
            -- Fresh ids for the two fields of the Functor dictionary, whose
            -- type is the type of the first Applicative field.
            (TyConApp funcDictTcNm _) = tyView (head apDictIdTys)
            (Just funcDictTc) = lookupUniqMap funcDictTcNm tcm
            [funcDictCon] = tyConDataCons funcDictTc
            (Just funcDictIdTys) = dataConInstArgTys funcDictCon [fTy]
            (uniqs2,funcDicIds@[fmapId,_]) =
              mapAccumR mkUniqInternalId uniqs1
                (zip ["fmap","fmapConst"] funcDictIdTys)
            apPat = DataPat apDictCon [] apDictIds
            fnPat = DataPat funcDictCon [] funcDicIds
            -- Field selectors as single-alternative case expressions.
            pureTm = Case dict (varType pureId) [(apPat, Var pureId)]
            apTm = Case dict (varType apId) [(apPat, Var apId)]
            funcTm = Case dict (varType functorDictId) [(apPat, Var functorDictId)]
            fmapTm = Case (Var functorDictId) (varType fmapId) [(fnPat, Var fmapId)]
            (uniqs3,(vars,elems)) = second (second concat . unzip)
                                  $ uncurry extractElems uniqs2 consCon aTy 'T' n arg
            funApps = map (fun1 `App`) vars
            lbody = mkTravVec vecTcNm nilCon consCon (Var pureId) (Var apId)
                              (Var fmapId) bTy n funApps
            lb = Letrec ([ (functorDictId, funcTm)
                         , (pureId, pureTm)
                         , (apId, apTm)
                         , (fmapId, fmapTm)
                         ] ++ dropLast elems) lbody
        uniqSupply Lens..= uniqs3
        changed lb
    go _ _ ty = error $ $(curLoc) ++ "reduceTraverse: argument does not have a vector type: " ++ showPpr ty

    -- Drop the final binding produced by extractElems, as the sibling
    -- reductions do (NOTE(review): presumably the empty-tail binding). Unlike
    -- 'init', this is total when there are no element bindings.
    dropLast [] = []
    dropLast xs = init xs
-- | Assemble the effectful vector for an unrolled @traverse#@.
--
-- For a list of effectful element terms @[x1, .., xn]@ this builds
--
-- > ap (fmap (Cons co) x1) (ap (fmap (Cons co) x2) ( ... (pure Nil)))
--
-- i.e. @Cons <$> x1 <*> (Cons <$> x2 <*> ... <*> pure Nil)@, with explicit
-- type applications at each vector length.
mkTravVec :: TyConName
          -- ^ Name of the @Vec@ type constructor
          -> DataCon
          -- ^ @Nil@ constructor
          -> DataCon
          -- ^ @Cons@ constructor
          -> Term
          -- ^ @pure@
          -> Term
          -- ^ @<*>@
          -> Term
          -- ^ @fmap@
          -> Type
          -- ^ Element type of the result vector
          -> Integer
          -- ^ Length of the result vector
          -> [Term]
          -- ^ The effectful elements
          -> Term
mkTravVec vecTc nilCon consCon pureTm apTm fmapTm bTy = go
  where
    go :: Integer -> [Term] -> Term
    -- No elements left: pure @(Vec 0 b) Nil
    go _ [] = mkApps pureTm [Right (mkTyConApp vecTc [LitTy (NumTy 0),bTy])
                            ,Left (mkApps (Data nilCon)
                                          [Right (LitTy (NumTy 0))
                                          ,Right bTy
                                          ,Left (primCo nilCoTy)])]
    -- ap @(Vec (n-1) b) @(Vec n b) (fmap (Cons @n @b @(n-1) co) x) rest
    go n (x:xs) = mkApps apTm
      [Right (mkTyConApp vecTc [LitTy (NumTy (n-1)),bTy])
      ,Right (mkTyConApp vecTc [LitTy (NumTy n),bTy])
      ,Left (mkApps fmapTm [Right bTy
                           ,Right (mkFunTy (mkTyConApp vecTc [LitTy (NumTy (n-1)),bTy])
                                           (mkTyConApp vecTc [LitTy (NumTy n),bTy]))
                           ,Left (mkApps (Data consCon)
                                         [Right (LitTy (NumTy n))
                                         ,Right bTy
                                         ,Right (LitTy (NumTy (n-1)))
                                         ,Left (primCo (consCoTy n))
                                         ])
                           ,Left x])
      ,Left (go (n-1) xs)]

    -- Type of the coercion field of Nil / Cons at the given lengths.
    nilCoTy = head (Maybe.fromJust (dataConInstArgTys nilCon [(LitTy (NumTy 0))
                                                             ,bTy]))
    consCoTy n = head (Maybe.fromJust (dataConInstArgTys consCon
                                                         [(LitTy (NumTy n))
                                                         ,bTy
                                                         ,(LitTy (NumTy (n-1)))]))
-- | Unroll one step of @Clash.Sized.Vector.foldr@ for a vector of statically
-- known length @n@.
--
-- * @n == 0@: the result is the start value.
-- * otherwise: @f a (foldr f start as)@, where the head and tail are projected
--   with 'extractHeadTail' and the recursive call uses length @n - 1@.
reduceFoldr
  :: TransformContext
  -> PrimInfo
  -- ^ Info of the @foldr@ primitive, reused for the recursive call
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The folding function
  -> Term
  -- ^ Start value
  -> Term
  -- ^ The vector
  -> NormalizeSession Term
reduceFoldr _ _ 0 _ _ start _ = changed start
reduceFoldr _ctx foldrPrimInfo n aTy fun start arg = do
  tcm <- Lens.view tcCache
  changed (go tcm (termType tcm arg))
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , Just vecTc <- lookupUniqMap vecTcNm tcm
      , [_nilCon, consCon] <- tyConDataCons vecTc
      = let (hd, tl) = extractHeadTail consCon aTy n arg
        in  mkApps fun [Left hd, Left (foldRest tcm tl)]
    go _ ty =
      error $ $(curLoc) ++ [I.i|
        reduceFoldr: argument does not have a vector type:
        #{showPpr ty}
      |]

    -- foldr @a @b @(n-1) fun start rest
    foldRest tcm rest = mkApps (Prim foldrPrimInfo)
      [ Right aTy
      , Right (termType tcm start)
      , Right (LitTy (NumTy (n - 1)))
      , Left fun
      , Left start
      , Left rest
      ]
-- | Fully unroll @Clash.Sized.Vector.fold@ into a balanced binary tree of
-- applications of the (constant-propagated) combining function.
--
-- Each element is bound in a @letrec@ via 'extractElems'. The element list is
-- split in half recursively until single elements remain.
reduceFold
  :: TransformContext
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The combining function
  -> Term
  -- ^ The vector
  -> NormalizeSession Term
reduceFold (TransformContext is0 ctx) n aTy fun arg = do
  tcm <- Lens.view tcCache
  go tcm (termType tcm arg)
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_,consCon] <- tyConDataCons vecTc
      = do
        -- The supply is read before constant propagation, and overwritten
        -- afterwards with what extractElems returns.
        supply0 <- Lens.use uniqSupply
        fun1 <- constantPropagation (TransformContext is0 (AppArg Nothing:ctx)) fun
        let inScope1 = extendInScopeSetList is0 (collectTermIds fun1)
            (supply1, (elVars, binds)) =
              second (second concat . unzip)
                (extractElems supply0 inScope1 consCon aTy 'F' n arg)
        uniqSupply Lens..= supply1
        changed (Letrec (init binds) (balancedFold fun1 elVars))
    go _ ty = error $ $(curLoc) ++ "reduceFold: argument does not have a vector type: " ++ showPpr ty

    -- Combine the terms pairwise as a balanced tree; expects a non-empty list.
    balancedFold _ [x] = x
    balancedFold f xs  = mkApps f [Left (balancedFold f lhs), Left (balancedFold f rhs)]
      where
        (lhs, rhs) = splitAt (length xs `div` 2) xs
-- | Fully unroll @dfold@ over a vector of statically known length @n@.
--
-- With elements @[x0, .., x(n-1)]@ the result is
--
-- > fun @(n-1) (SNat (n-1)) x0 (fun @(n-2) (SNat (n-2)) x1 ( ... start))
--
-- The @SNat@ data constructor is read from the second argument type of @fun@.
-- For @n == 0@ the result is the start value.
reduceDFold
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The (type-indexed) folding function
  -> Term
  -- ^ Start value
  -> Term
  -- ^ The vector
  -> NormalizeSession Term
reduceDFold _ 0 _ _ start _ = changed start
reduceDFold is0 n aTy fun start arg = do
  tcm <- Lens.view tcCache
  let ty = termType tcm arg
  go tcm ty
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_,consCon] <- tyConDataCons vecTc
      = do
        uniqs0 <- Lens.use uniqSupply
        let is1 = extendInScopeSetList is0 (collectTermIds fun)
            (uniqs1,(vars,elems)) = second (second concat . unzip)
                                  $ extractElems uniqs0 is1 consCon aTy 'D' n arg
            -- fun's second (value) argument is an SNat; grab its constructor.
            (_ltv:Right snTy:_,_) = splitFunForallTy (termType tcm fun)
            (TyConApp snatTcNm _) = tyView snTy
            (Just snatTc) = lookupUniqMap snatTcNm tcm
            [snatDc] = tyConDataCons snatTc
            lbody = doFold (buildSNat snatDc) (n-1) vars
            lb = Letrec (init elems) lbody
        uniqSupply Lens..= uniqs1
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceDFold: argument does not have a vector type: " ++ showPpr ty
    -- Right-nested application of fun, with index k counting down from n-1.
    doFold _ _ [] = start
    doFold snDc k (x:xs) = mkApps fun
                             [Right (LitTy (NumTy k))
                             ,Left (snDc k)
                             ,Left x
                             ,Left (doFold snDc (k-1) xs)
                             ]
-- | Reduce @Clash.Sized.Vector.head@ on a vector of statically known length
-- to a @letrec@ that binds the first element and returns it.
reduceHead
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Number of elements to extract from the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The vector
  -> NormalizeSession Term
reduceHead inScope n aTy vArg = do
  tcm <- Lens.view tcCache
  go tcm (termType tcm vArg)
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_,consCon] <- tyConDataCons vecTc
      = do
        supply0 <- Lens.use uniqSupply
        let (supply1, (elVars, binds)) =
              second (second concat . unzip)
                (extractElems supply0 inScope consCon aTy 'H' n vArg)
        uniqSupply Lens..= supply1
        changed (Letrec [head binds] (head elVars))
    go _ ty = error $ $(curLoc) ++ "reduceHead: argument does not have a vector type: " ++ showPpr ty
-- | Reduce @Clash.Sized.Vector.tail@ on a vector of statically known length
-- to a @letrec@ holding only the second binding produced by 'extractElems',
-- and return that binding's variable.
reduceTail
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Number of elements to extract from the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The vector
  -> NormalizeSession Term
reduceTail inScope n aTy vArg = do
  tcm <- Lens.view tcCache
  go tcm (termType tcm vArg)
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_,consCon] <- tyConDataCons vecTc
      = do
        supply0 <- Lens.use uniqSupply
        let (supply1, (_, binds)) =
              second (second concat . unzip)
                (extractElems supply0 inScope consCon aTy 'L' n vArg)
            tailBind@(tailId, _) = binds !! 1
        uniqSupply Lens..= supply1
        changed (Letrec [tailBind] (Var tailId))
    go _ ty = error $ $(curLoc) ++ "reduceTail: argument does not have a vector type: " ++ showPpr ty
-- | Reduce @Clash.Sized.Vector.last@ on a vector of statically known length.
--
-- Every element is bound via 'extractElems'. The result is the variable of
-- the last element's first binding, under a @letrec@ of all bindings except
-- the final one. For @n == 0@ the result is 'undefinedTm'.
reduceLast
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Number of elements to extract from the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The vector
  -> NormalizeSession Term
reduceLast inScope n aTy vArg = do
  tcm <- Lens.view tcCache
  go tcm (termType tcm vArg)
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_,consCon] <- tyConDataCons vecTc
      = do
        supply0 <- Lens.use uniqSupply
        let (supply1, (_, bindGroups)) =
              second unzip (extractElems supply0 inScope consCon aTy 'L' n vArg)
        uniqSupply Lens..= supply1
        changed $ case n of
          0 -> undefinedTm aTy
          _ -> let (lastId, _) = head (last bindGroups)
               in  Letrec (init (concat bindGroups)) (Var lastId)
    go _ ty = error $ $(curLoc) ++ "reduceLast: argument does not have a vector type: " ++ showPpr ty
-- | Unroll one step of @Clash.Sized.Vector.init@, where the argument has
-- length @n + 1@ and the result has length @n@.
--
-- * @n == 0@: the result is @Nil@.
-- * otherwise: @a :> init as@, where the head and tail of the argument are
--   projected with 'extractHeadTail' (at length @n + 1@) and the recursive
--   call uses @n - 1@.
reduceInit
  :: InScopeSet
  -> PrimInfo
  -- ^ Info of the @init@ primitive, reused for the recursive call
  -> Integer
  -- ^ Length of the result vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The argument vector (length @n + 1@)
  -> NormalizeSession Term
reduceInit _inScope initPrimInfo n aTy vArg = do
  tcm <- Lens.view tcCache
  changed (go tcm (termType tcm vArg))
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon, consCon] <- tyConDataCons vecTc
      = if n == 0
          then mkVecNil nilCon aTy
          else
            let (hd, rest) = extractHeadTail consCon aTy (n+1) vArg
                initRest   = mkApps (Prim initPrimInfo)
                               [Right (LitTy (NumTy (n - 1))), Right aTy, Left rest]
            in  mkVecCons consCon aTy n hd initRest
    go _ ty =
      error $ $(curLoc) ++ [I.i|
        reduceInit: argument does not have a vector type:
        #{showPpr ty}
      |]
-- | Unroll @Clash.Sized.Vector.++@ when the left vector has statically known
-- length @n@.
--
-- The elements of the left vector are bound via 'extractElems' and consed,
-- one by one, onto the right vector with 'appendToVec', producing a vector of
-- length @n + m@.
reduceAppend
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Length of the left vector
  -> Integer
  -- ^ Length of the right vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ Left vector
  -> Term
  -- ^ Right vector
  -> NormalizeSession Term
reduceAppend inScope n m aTy lArg rArg = do
  tcm <- Lens.view tcCache
  let ty = termType tcm lArg
  go tcm ty
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_,consCon] <- tyConDataCons vecTc
      = do
        uniqs0 <- Lens.use uniqSupply
        let (uniqs1,(vars,elems)) = second (second concat . unzip)
                                  $ extractElems uniqs0 inScope consCon aTy 'C' n lArg
            lbody = appendToVec consCon aTy rArg (n+m) vars
            lb = Letrec (dropLast elems) lbody
        uniqSupply Lens..= uniqs1
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceAppend: argument does not have a vector type: " ++ showPpr ty

    -- Drop the final binding produced by extractElems, as the sibling
    -- reductions do (NOTE(review): presumably the empty-tail binding). Unlike
    -- 'init', this is total when the left vector yields no bindings.
    dropLast [] = []
    dropLast xs = init xs
-- | Reduce @unconcat@ for the only supported case, an inner length of 0: the
-- result is a vector of @n@ empty vectors. Any other inner length is
-- reported as unimplemented.
reduceUnconcat :: Integer
               -- ^ Number of inner vectors
               -> Integer
               -- ^ Length of each inner vector
               -> Type
               -- ^ Element type
               -> Term
               -- ^ The argument vector
               -> NormalizeSession Term
reduceUnconcat n 0 aTy arg = do
  tcm <- Lens.view tcCache
  go tcm (termType tcm arg)
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon,consCon] <- tyConDataCons vecTc
      = changed (vecOfNils vecTcNm nilCon consCon)
    go _ ty = error $ $(curLoc) ++ "reduceUnconcat: argument does not have a vector type: " ++ showPpr ty

    -- n copies of the empty vector, each of type Vec 0 aTy.
    vecOfNils vecTcNm nilCon consCon =
      let emptyVec = mkVec nilCon consCon aTy 0 []
          emptyTy  = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy]
      in  mkVec nilCon consCon emptyTy n (replicate (fromInteger n) emptyVec)
reduceUnconcat _ _ _ _ = error $ $(curLoc) ++ "reduceUnconcat: unimplemented"
-- | Reduce @transpose@ for the only supported case, where the second length
-- argument is 0: the result is a vector of @n@ empty vectors. Any other case
-- is reported as unimplemented.
reduceTranspose :: Integer
                -- ^ Length of the result vector
                -> Integer
                -- ^ Length of each result inner vector (only 0 supported)
                -> Type
                -- ^ Element type
                -> Term
                -- ^ The argument vector
                -> NormalizeSession Term
reduceTranspose n 0 aTy arg = do
  tcm <- Lens.view tcCache
  go tcm (termType tcm arg)
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon,consCon] <- tyConDataCons vecTc
      = changed (vecOfNils vecTcNm nilCon consCon)
    go _ ty = error $ $(curLoc) ++ "reduceTranspose: argument does not have a vector type: " ++ showPpr ty

    -- n copies of the empty vector, each of type Vec 0 aTy.
    vecOfNils vecTcNm nilCon consCon =
      let emptyVec = mkVec nilCon consCon aTy 0 []
          emptyTy  = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy]
      in  mkVec nilCon consCon emptyTy n (replicate (fromInteger n) emptyVec)
reduceTranspose _ _ _ _ = error $ $(curLoc) ++ "reduceTranspose: unimplemented"
-- | Reduce @Clash.Sized.Vector.replicate@ to an explicit vector holding @n@
-- copies of the argument term.
reduceReplicate :: Integer
                -- ^ Length of the result vector
                -> Type
                -- ^ Element type
                -> Type
                -- ^ Type of the result vector
                -> Term
                -- ^ The replicated element
                -> NormalizeSession Term
reduceReplicate n aTy eTy arg = do
  tcm <- Lens.view tcCache
  go tcm eTy
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon,consCon] <- tyConDataCons vecTc
      = changed (mkVec nilCon consCon aTy n copies)
    go _ ty = error $ $(curLoc) ++ "reduceReplicate: argument does not have a vector type: " ++ showPpr ty

    copies = replicate (fromInteger n) arg
-- | Unroll @replace_int@ on a vector of statically known length @n@.
--
-- Every element is bound via 'extractElems'. Element @k@ of the result is
--
-- > case eqInt i k of { True -> newA; _ -> el_k }
--
-- so the element at the runtime index @i@ is replaced by @newA@.
reduceReplace_int
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type
  -> Type
  -- ^ Type of the vector
  -> Term
  -- ^ The vector
  -> Term
  -- ^ The index (an Int-like single-constructor value)
  -> Term
  -- ^ The new element
  -> NormalizeSession Term
reduceReplace_int is0 n aTy vTy v i newA = do
  tcm <- Lens.view tcCache
  go tcm vTy
  where
    -- case eqInt i (I# elIndex) of { True -> newA; _ -> oldA }
    replace_intElement
      :: TyConMap
      -> DataCon
      -> Type
      -> Term
      -> Integer
      -> Term
    replace_intElement tcm iDc iTy oldA elIndex = case0
      where
        (Just boolTc) = lookupUniqMap (getKey boolTyConKey) tcm
        [_,trueDc] = tyConDataCons boolTc
        eqInt = eqIntPrim iTy (mkTyConApp (tyConName boolTc) [])
        case0 = Case (mkApps eqInt [Left i
                                   ,Left (mkApps (Data iDc)
                                                 [Left (Literal (IntLiteral elIndex))])
                                   ])
                     aTy
                     [(DefaultPat, oldA)
                     ,(DataPat trueDc [] [], newA)
                     ]

    eqIntPrim
      :: Type
      -> Type
      -> Term
    eqIntPrim intTy boolTy =
      Prim (PrimInfo "Clash.Transformations.eqInt" (mkFunTy intTy (mkFunTy intTy boolTy)) WorkVariable)

    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [nilCon,consCon] <- tyConDataCons vecTc
      = do
        uniqs0 <- Lens.use uniqSupply
        let iTy = termType tcm i
            (TyConApp iTcNm _) = tyView iTy
            (Just iTc) = lookupUniqMap iTcNm tcm
            [iDc] = tyConDataCons iTc
            (uniqs1,(vars,elems)) = second (second concat . unzip)
                                  $ extractElems uniqs0 is0 consCon aTy 'I' n v
            replacedEls = zipWith (replace_intElement tcm iDc iTy) vars [0..]
            lbody = mkVec nilCon consCon aTy n replacedEls
            lb = Letrec (dropLast elems) lbody
        uniqSupply Lens..= uniqs1
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceReplace_int: argument does not have "
                                ++ "a vector type: " ++ showPpr ty

    -- Drop the final binding produced by extractElems, as the sibling
    -- reductions do (NOTE(review): presumably the empty-tail binding). Unlike
    -- 'init', this is total when the vector yields no bindings.
    dropLast [] = []
    dropLast xs = init xs
-- | Unroll @index_int@ on a vector of statically known length @n@.
--
-- Every element is bound via 'extractElems', and the result is a chain of
--
-- > case eqInt i k of { True -> el_k; _ -> <next> }
--
-- for @k = 0 .. n-1@. The chain ends in 'undefinedTm' when no index matches.
reduceIndex_int
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Length of the vector
  -> Type
  -- ^ Element type
  -> Term
  -- ^ The vector
  -> Term
  -- ^ The index (an Int-like single-constructor value)
  -> NormalizeSession Term
reduceIndex_int is0 n aTy v i = do
  tcm <- Lens.view tcCache
  let vTy = termType tcm v
  go tcm vTy
  where
    -- case eqInt i (I# elIndex) of { True -> cur; _ -> next }
    index_intElement
      :: TyConMap
      -> DataCon
      -> Type
      -> (Term, Integer)
      -> Term
      -> Term
    index_intElement tcm iDc iTy (cur,elIndex) next = case0
      where
        (Just boolTc) = lookupUniqMap (getKey boolTyConKey) tcm
        [_,trueDc] = tyConDataCons boolTc
        eqInt = eqIntPrim iTy (mkTyConApp (tyConName boolTc) [])
        case0 = Case (mkApps eqInt [Left i
                                   ,Left (mkApps (Data iDc)
                                                 [Left (Literal (IntLiteral elIndex))])
                                   ])
                     aTy
                     [(DefaultPat, next)
                     ,(DataPat trueDc [] [], cur)
                     ]

    eqIntPrim
      :: Type
      -> Type
      -> Term
    eqIntPrim intTy boolTy =
      Prim (PrimInfo "Clash.Transformations.eqInt" (mkFunTy intTy (mkFunTy intTy boolTy)) WorkVariable)

    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_nilCon,consCon] <- tyConDataCons vecTc
      = do
        uniqs0 <- Lens.use uniqSupply
        let iTy = termType tcm i
            (TyConApp iTcNm _) = tyView iTy
            (Just iTc) = lookupUniqMap iTcNm tcm
            [iDc] = tyConDataCons iTc
            (uniqs1,(vars,elems)) = second (second concat . unzip)
                                  $ extractElems uniqs0 is0 consCon aTy 'I' n v
            indexed = foldr (index_intElement tcm iDc iTy)
                            (undefinedTm aTy)
                            (zip vars [0..])
            lb = Letrec (dropLast elems) indexed
        uniqSupply Lens..= uniqs1
        changed lb
    -- The message now names this function (it previously said
    -- "indexReplace_int").
    go _ ty = error $ $(curLoc) ++ "reduceIndex_int: argument does not have "
                                ++ "a vector type: " ++ showPpr ty

    -- Drop the final binding produced by extractElems, as the sibling
    -- reductions do (NOTE(review): presumably the empty-tail binding). Unlike
    -- 'init', this is total when the vector yields no bindings.
    dropLast [] = []
    dropLast xs = init xs
-- | Fully unroll @dtfold@ over a vector of @2^n@ elements.
--
-- The elements are bound via 'extractElems' and combined as a perfect binary
-- tree. Leaves apply @lrFun@ to a single element. At depth index @k@ (counting
-- down from @n-1@), a node applies @brFun \@k (SNat k)@ to its left and right
-- subtrees, each covering @2^k@ elements. The @SNat@ data constructor is read
-- from the second argument type of @brFun@.
reduceDTFold
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Depth @n@; the vector has @2^n@ elements
  -> Type
  -- ^ Element type
  -> Term
  -- ^ Leaf function
  -> Term
  -- ^ Branch (combining) function
  -> Term
  -- ^ The vector
  -> NormalizeSession Term
reduceDTFold inScope n aTy lrFun brFun arg = do
  tcm <- Lens.view tcCache
  let ty = termType tcm arg
  go tcm ty
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- lookupUniqMap vecTcNm tcm
      , nameOcc vecTcNm == "Clash.Sized.Vector.Vec"
      , [_,consCon] <- tyConDataCons vecTc
      = do uniqs0 <- Lens.use uniqSupply
           let (uniqs1,(vars,elems)) = second (second concat . unzip)
                                     $ extractElems uniqs0 inScope consCon aTy
                                         'T' (2^n) arg
               -- brFun's second (value) argument is an SNat; grab its
               -- constructor.
               (_ltv:Right snTy:_,_) = splitFunForallTy (termType tcm brFun)
               (TyConApp snatTcNm _) = tyView snTy
               (Just snatTc) = lookupUniqMap snatTcNm tcm
               [snatDc] = tyConDataCons snatTc
               lbody = doFold (buildSNat snatDc) (n-1) vars
               lb = Letrec (init elems) lbody
           uniqSupply Lens..= uniqs1
           changed lb
    go _ ty = error $ $(curLoc) ++ "reduceDTFold: argument does not have a vector type: " ++ showPpr ty
    -- Split the element list into halves of 2^k elements each; a single
    -- element becomes a leaf.
    doFold :: (Integer -> Term) -> Integer -> [Term] -> Term
    doFold _ _ [x] = mkApps lrFun [Left x]
    doFold snDc k xs =
      let (xsL,xsR) = splitAt (2^k) xs
          k' = k-1
          eL = doFold snDc k' xsL
          eR = doFold snDc k' xsR
      in mkApps brFun [Right (LitTy (NumTy k))
                      ,Left (snDc k)
                      ,Left eL
                      ,Left eR
                      ]
-- | Fully unroll @tfold@ over an @RTree@ of depth @n@.
--
-- The leaves are bound via 'extractTElems' and recombined as a binary tree.
-- Single leaves are mapped with @lrFun@. At depth index @k@ (counting down
-- from @n-1@), a node applies @brFun \@k (SNat k)@ to the results for its two
-- halves. The @SNat@ data constructor is read from the second argument type
-- of @brFun@.
reduceTFold
  :: InScopeSet
  -- ^ In-scope set used when generating fresh binders
  -> Integer
  -- ^ Depth of the tree
  -> Type
  -- ^ Element type
  -> Term
  -- ^ Leaf function
  -> Term
  -- ^ Branch (combining) function
  -> Term
  -- ^ The tree
  -> NormalizeSession Term
reduceTFold inScope n aTy lrFun brFun arg = do
  tcm <- Lens.view tcCache
  let ty = termType tcm arg
  go tcm ty
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp treeTcNm _)
      | (Just treeTc) <- lookupUniqMap treeTcNm tcm
      , nameOcc treeTcNm == "Clash.Sized.RTree.RTree"
      , [lrCon,brCon] <- tyConDataCons treeTc
      = do uniqs0 <- Lens.use uniqSupply
           let (uniqs1,(vars,elems)) = extractTElems uniqs0 inScope lrCon brCon aTy 'T' n arg
               -- brFun's second (value) argument is an SNat; grab its
               -- constructor.
               (_ltv:Right snTy:_,_) = splitFunForallTy (termType tcm brFun)
               (TyConApp snatTcNm _) = tyView snTy
               (Just snatTc) = lookupUniqMap snatTcNm tcm
               [snatDc] = tyConDataCons snatTc
               lbody = doFold (buildSNat snatDc) (n-1) vars
               -- Unlike the vector reductions, all bindings are kept here.
               lb = Letrec elems lbody
           uniqSupply Lens..= uniqs1
           changed lb
    go _ ty = error $ $(curLoc) ++ "reduceTFold: argument does not have a tree type: " ++ showPpr ty
    -- Split the leaf list in half at each level; a single leaf becomes lrFun x.
    doFold _ _ [x] = mkApps lrFun [Left x]
    doFold snDc k xs =
      let (xsL,xsR) = splitAt (length xs `div` 2) xs
          k' = k-1
          eL = doFold snDc k' xsL
          eR = doFold snDc k' xsR
      in mkApps brFun [Right (LitTy (NumTy k))
                      ,Left (snDc k)
                      ,Left eL
                      ,Left eR
                      ]
-- | Reduce @RTree@ replicate to an explicit tree of depth @n@ whose @2^n@
-- leaves are all the argument term.
reduceTReplicate :: Integer
                 -- ^ Depth of the tree
                 -> Type
                 -- ^ Element type
                 -> Type
                 -- ^ Type of the result tree
                 -> Term
                 -- ^ The replicated element
                 -> NormalizeSession Term
reduceTReplicate n aTy eTy arg = do
  tcm <- Lens.view tcCache
  go tcm eTy
  where
    go tcm (coreView1 tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp treeTcNm _)
      | Just treeTc <- lookupUniqMap treeTcNm tcm
      , nameOcc treeTcNm == "Clash.Sized.RTree.RTree"
      , [lrCon,brCon] <- tyConDataCons treeTc
      = changed (mkRTree lrCon brCon aTy n leaves)
    go _ ty = error $ $(curLoc) ++ "reduceTReplicate: argument does not have a RTree type: " ++ showPpr ty

    leaves = replicate (2^n) arg
-- | Apply the given @SNat@ data constructor to the type-level literal @i@ and
-- the natural literal @i@.
buildSNat :: DataCon -> Integer -> Term
buildSNat snatDc i = mkApps (Data snatDc) [Right natTy, Left natLit]
  where
    natTy  = LitTy (NumTy i)
    natLit = Literal (NaturalLiteral i)