{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Normalize.PrimitiveReductions where
import qualified Control.Lens as Lens
import qualified Data.HashMap.Lazy as HashMap
import qualified Data.Maybe as Maybe
import Unbound.Generics.LocallyNameless (bind, embed, rec, rebind)
import Clash.Core.DataCon (DataCon, dataConInstArgTys)
import Clash.Core.Literal (Literal (..))
import Clash.Core.Name
import Clash.Core.Pretty (showDoc)
import Clash.Core.Term (Term (..), Pat (..))
import Clash.Core.Type (LitTy (..), Type (..),
TypeView (..), coreView,
mkFunTy, mkTyConApp,
splitFunForallTy, tyView,
undefinedTy)
import Clash.Core.TyCon (TyConName, tyConDataCons)
import Clash.Core.TysPrim (integerPrimTy, typeNatKind)
import Clash.Core.Util (appendToVec, extractElems,
extractTElems, idToVar,
mkApps, mkRTree, mkVec,
termType)
import Clash.Core.Var (Var (..))
import Clash.Normalize.Types
import Clash.Rewrite.Types
import Clash.Rewrite.Util
import Clash.Util
-- | Reduce @zipWith f xs ys@ on vectors of statically known length @n@ by
-- unrolling it. Both argument vectors are taken apart into their @n@
-- elements (via 'extractElems'), and @f@ is applied pairwise. The results
-- are collected into a new vector of length @n@ with element type @resElTy@.
--
-- For @n == 0@ no elements are extracted. The empty result vector is then
-- returned directly, rather than applying the partial 'init' to an empty
-- binding list.
reduceZipWith :: Integer  -- ^ Length @n@ of the vectors
              -> Type     -- ^ Element type of the left argument
              -> Type     -- ^ Element type of the right argument
              -> Type     -- ^ Element type of the result
              -> Term     -- ^ The zipping function
              -> Term     -- ^ Left argument vector
              -> Term     -- ^ Right argument vector
              -> NormalizeSession Term
reduceZipWith n lhsElTy rhsElTy resElTy fun lhsArg rhsArg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm lhsArg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let (varsL,elemsL) = second concat . unzip
                           $ extractElems consCon lhsElTy 'L' n lhsArg
            (varsR,elemsR) = second concat . unzip
                           $ extractElems consCon rhsElTy 'R' n rhsArg
            funApps        = zipWith (\l r -> mkApps fun [Left l,Left r]) varsL varsR
            lbody          = mkVec nilCon consCon resElTy n funApps
            -- The final binding of each extraction (presumably the remainder
            -- after the last element) is not referenced by the body.
            binds          = dropLast elemsL ++ dropLast elemsR
            lb | null binds = lbody
               | otherwise  = Letrec (bind (rec binds) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceZipWith: argument does not have a vector type: " ++ showDoc ty

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Reduce @map f xs@ on a vector of statically known length @n@ by
-- unrolling it. The elements of @xs@ are bound individually (via
-- 'extractElems'), and @f@ is applied to each of them. The results are
-- collected into a new vector with element type @resElTy@.
--
-- For @n == 0@ there are no bindings, so the empty result vector is returned
-- without a surrounding @letrec@, rather than applying the partial 'init' to
-- an empty list.
reduceMap :: Integer  -- ^ Length @n@ of the vector
          -> Type     -- ^ Element type of the argument
          -> Type     -- ^ Element type of the result
          -> Term     -- ^ The function to map
          -> Term     -- ^ The argument vector
          -> NormalizeSession Term
reduceMap n argElTy resElTy fun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon argElTy 'A' n arg
            funApps      = map (fun `App`) vars
            lbody        = mkVec nilCon consCon resElTy n funApps
            -- The final extracted binding is not referenced by the body.
            binds        = dropLast elems
            lb | null binds = lbody
               | otherwise  = Letrec (bind (rec binds) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceMap: argument does not have a vector type: " ++ showDoc ty

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Reduce @imap f xs@ on a vector of statically known length @n@ by
-- unrolling it. Each element is passed to @f@ together with its position.
-- The position is built as a call to
-- @Clash.Sized.Internal.Index.fromInteger#@ instantiated at @n@. The results
-- are collected into a new vector with element type @resElTy@.
--
-- The index type constructor is taken from the type of the first value
-- argument of @fun@.
--
-- For @n == 0@ the empty result vector is returned without a @letrec@,
-- rather than applying the partial 'init' to an empty binding list.
reduceImap :: Integer  -- ^ Length @n@ of the vector
           -> Type     -- ^ Element type of the argument
           -> Type     -- ^ Element type of the result
           -> Term     -- ^ The function to map (index first, element second)
           -> Term     -- ^ The argument vector
           -> NormalizeSession Term
reduceImap n argElTy resElTy fun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = do
        let (vars,elems) = second concat . unzip
                         $ extractElems consCon argElTy 'I' n arg
        (Right idxTy:_,_) <- splitFunForallTy <$> termType tcm fun
        let (TyConApp idxTcNm _) = tyView idxTy
            nTv              = string2InternalName "n"
            -- forall n . Integer -> Integer -> Index n
            idxFromIntegerTy = ForAllTy (bind (TyVar nTv (embed typeNatKind))
                                              (foldr mkFunTy
                                                     (mkTyConApp idxTcNm
                                                                 [VarTy typeNatKind nTv])
                                                     [integerPrimTy,integerPrimTy]))
            idxFromInteger   = Prim "Clash.Sized.Internal.Index.fromInteger#"
                                    idxFromIntegerTy
            -- fromInteger# @n n i, for every position i in [0 .. n-1]
            idxs             = map (App (App (TyApp idxFromInteger (LitTy (NumTy n)))
                                             (Literal (IntegerLiteral n)))
                                    . Literal . IntegerLiteral) [0..(n-1)]
            funApps          = zipWith (\i v -> App (App fun i) v) idxs vars
            lbody            = mkVec nilCon consCon resElTy n funApps
            -- The final extracted binding is not referenced by the body.
            binds            = dropLast elems
            lb | null binds = lbody
               | otherwise  = Letrec (bind (rec binds) lbody)
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceImap: argument does not have a vector type: " ++ showDoc ty

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Reduce @traverse f xs@ on a vector of statically known length @n@ by
-- unrolling it.
--
-- The Applicative dictionary @dict@ is taken apart with @case@ expressions
-- to obtain its @functorDict@, @pure@ and @ap@ fields. The Functor
-- dictionary is then taken apart to obtain @fmap@. These are let-bound
-- together with the extracted vector elements. Finally, 'mkTravVec' chains
-- @fmap@/@ap@ over the applications of @f@ to each element.
--
-- For @n == 0@ no element bindings are added. Previously the partial 'init'
-- was applied to the empty element binding list at that point.
reduceTraverse :: Integer  -- ^ Length @n@ of the vector
               -> Type     -- ^ Element type of the argument vector
               -> Type     -- ^ The applicative functor @f@
               -> Type     -- ^ Element type of the result vector
               -> Term     -- ^ The Applicative dictionary
               -> Term     -- ^ The function to traverse with
               -> Term     -- ^ The argument vector
               -> NormalizeSession Term
reduceTraverse n aTy fTy bTy dict fun arg = do
    tcm <- Lens.view tcCache
    (TyConApp apDictTcNm _) <- tyView <$> termType tcm dict
    ty <- termType tcm arg
    go tcm apDictTcNm ty
  where
    go tcm apDictTcNm (coreView tcm -> Just ty') = go tcm apDictTcNm ty'
    go tcm apDictTcNm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let (Just apDictTc)    = HashMap.lookup (nameOcc apDictTcNm) tcm
            [apDictCon]        = tyConDataCons apDictTc
            (Just apDictIdTys) = dataConInstArgTys apDictCon [fTy]
            apDictIds          = zipWith Id (map string2InternalName
                                                 ["functorDict"
                                                 ,"pure"
                                                 ,"ap"
                                                 ,"apConstL"
                                                 ,"apConstR"])
                                            (map embed apDictIdTys)

            (TyConApp funcDictTcNm _) = tyView (head apDictIdTys)
            (Just funcDictTc)    = HashMap.lookup (nameOcc funcDictTcNm) tcm
            [funcDictCon]        = tyConDataCons funcDictTc
            (Just funcDictIdTys) = dataConInstArgTys funcDictCon [fTy]
            funcDicIds           = zipWith Id (map string2InternalName ["fmap","fmapConst"])
                                              (map embed funcDictIdTys)

            apPat  = DataPat (embed apDictCon) (rebind [] apDictIds)
            fnPat  = DataPat (embed funcDictCon) (rebind [] funcDicIds)

            -- Field selectors, expressed as case expressions on the dictionaries
            pureTy = apDictIdTys!!1
            pureTm = Case dict pureTy [bind apPat (Var pureTy (string2InternalName "pure"))]

            apTy   = apDictIdTys!!2
            apTm   = Case dict apTy [bind apPat (Var apTy (string2InternalName "ap"))]

            funcTy = (head apDictIdTys)
            funcTm = Case dict funcTy
                          [bind apPat (Var funcTy (string2InternalName "functorDict"))]

            -- Scrutinises the let-bound "functorDict" variable defined below
            fmapTy = (head funcDictIdTys)
            fmapTm = Case (Var funcTy (string2InternalName "functorDict")) fmapTy
                          [bind fnPat (Var fmapTy (string2InternalName "fmap"))]

            (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'T' n arg

            funApps = map (fun `App`) vars

            lbody   = mkTravVec vecTcNm nilCon consCon (idToVar (apDictIds!!1))
                                                       (idToVar (apDictIds!!2))
                                                       (idToVar (funcDicIds!!0))
                                                       bTy n funApps

            -- The final extracted binding is not referenced by the body.
            lb      = Letrec (bind (rec ([((apDictIds!!0),embed funcTm)
                                         ,((apDictIds!!1),embed pureTm)
                                         ,((apDictIds!!2),embed apTm)
                                         ,((funcDicIds!!0),embed fmapTm)
                                         ] ++ dropLast elems)) lbody)
        in  changed lb
    go _ _ ty = error $ $(curLoc) ++ "reduceTraverse: argument does not have a vector type: " ++ showDoc ty

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Build the applicative result vector of an unrolled @traverse@.
--
-- For the effectful elements @[x_1, ..., x_n]@, the generated term is
--
-- > ap (fmap (Cons n) x_1) (ap (fmap (Cons (n-1)) x_2) ( ... (pure Nil)))
--
-- Each @Cons k@ is the Cons data constructor applied only to its type
-- arguments (@k@, the element type, @k-1@) and a @"_CO_"@ placeholder.
-- It therefore still awaits its head and tail.
--
-- NOTE(review): the Integer argument is used only in the type annotations.
-- Recursion stops when the element list runs out, so the caller must pass
-- exactly @n@ elements for the generated types to line up.
mkTravVec :: TyConName  -- ^ Vector type constructor
          -> DataCon    -- ^ Nil data constructor
          -> DataCon    -- ^ Cons data constructor
          -> Term       -- ^ The applicative's @pure@
          -> Term       -- ^ The applicative's @ap@
          -> Term       -- ^ The functor's @fmap@
          -> Type       -- ^ Element type @b@ of the result vector
          -> Integer    -- ^ Vector length @n@
          -> [Term]     -- ^ The effectful elements
          -> Term
mkTravVec vecTc nilCon consCon pureTm apTm fmapTm bTy = go
  where
    go :: Integer -> [Term] -> Term
    -- Base case: @pure@ at type Vec 0 b, applied to the empty vector.
    go _ [] = mkApps pureTm [Right (mkTyConApp vecTc [LitTy (NumTy 0),bTy])
                            ,Left (mkApps (Data nilCon)
                                          [Right (LitTy (NumTy 0))
                                          ,Right bTy
                                          ,Left (Prim "_CO_" nilCoTy)])]
    -- Step: ap (fmap (Cons n b (n-1) co) x) (go (n-1) xs).
    -- @ap@ is instantiated at Vec (n-1) b and Vec n b.
    -- @fmap@ is instantiated at b and (Vec (n-1) b -> Vec n b).
    go n (x:xs) = mkApps apTm
      [Right (mkTyConApp vecTc [LitTy (NumTy (n-1)),bTy])
      ,Right (mkTyConApp vecTc [LitTy (NumTy n),bTy])
      ,Left (mkApps fmapTm [Right bTy
                           ,Right (mkFunTy (mkTyConApp vecTc [LitTy (NumTy (n-1)),bTy])
                                           (mkTyConApp vecTc [LitTy (NumTy n),bTy]))
                           ,Left (mkApps (Data consCon)
                                         [Right (LitTy (NumTy n))
                                         ,Right bTy
                                         ,Right (LitTy (NumTy (n-1)))
                                         ,Left (Prim "_CO_" (consCoTy n))
                                         ])
                           ,Left x])
      ,Left (go (n-1) xs)]

    -- Type of the first field of Nil instantiated at (0, bTy). This is
    -- presumably the coercion field that the "_CO_" placeholder fills.
    nilCoTy = head (Maybe.fromJust (dataConInstArgTys nilCon [(LitTy (NumTy 0))
                                                              ,bTy]))

    -- Type of the first field of Cons instantiated at (n, bTy, n-1), used
    -- the same way as nilCoTy.
    consCoTy n = head (Maybe.fromJust (dataConInstArgTys consCon
                                                         [(LitTy (NumTy n))
                                                         ,bTy
                                                         ,(LitTy (NumTy (n-1)))]))
-- | Reduce @foldr f z xs@ on a vector of statically known length @n@ by
-- unrolling it into @f x_0 (f x_1 ( ... (f x_{n-1} z)))@. The elements
-- @x_i@ are bound individually via 'extractElems'.
--
-- For @n == 0@ the result is just @z@ (@start@), with no surrounding
-- @letrec@, rather than applying the partial 'init' to an empty binding
-- list.
reduceFoldr :: Integer  -- ^ Length @n@ of the vector
            -> Type     -- ^ Element type of the vector
            -> Term     -- ^ The folding function
            -> Term     -- ^ The starting value
            -> Term     -- ^ The argument vector
            -> NormalizeSession Term
reduceFoldr n aTy fun start arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon] <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'G' n arg
            lbody        = foldr (\l r -> mkApps fun [Left l,Left r]) start vars
            -- The final extracted binding is not referenced by the body.
            binds        = dropLast elems
            lb | null binds = lbody
               | otherwise  = Letrec (bind (rec binds) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceFoldr: argument does not have a vector type: " ++ showDoc ty

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Reduce @fold f xs@ on a vector of statically known length @n@ by
-- unrolling it into a balanced tree of applications of @f@. The element
-- list is split in half (the first half gets the smaller share) and both
-- halves are folded recursively.
--
-- A fold over an empty vector has no result. It now raises a descriptive
-- error. Previously it built an infinitely nested application term, because
-- @splitAt 0 []@ never shrinks the input.
reduceFold :: Integer  -- ^ Length @n@ of the vector
           -> Type     -- ^ Element type of the vector
           -> Term     -- ^ The (binary) folding function
           -> Term     -- ^ The argument vector
           -> NormalizeSession Term
reduceFold n aTy fun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon] <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'F' n arg
            lbody        = foldV vars
            -- The final extracted binding is not referenced by the body.
            binds        = dropLast elems
            lb | null binds = lbody
               | otherwise  = Letrec (bind (rec binds) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceFold: argument does not have a vector type: " ++ showDoc ty

    foldV []  = error $ $(curLoc) ++ "reduceFold: cannot fold an empty vector"
    foldV [t] = t
    foldV ts  = let (l,r) = splitAt (length ts `div` 2) ts
                in  mkApps fun [Left (foldV l), Left (foldV r)]

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Reduce the dependently-typed @dfold@ on a vector of statically known
-- length @n@ by unrolling it.
--
-- For elements @[x_0, ..., x_{n-1}]@ the result is
-- @fun \@(n-1) (SNat (n-1)) x_0 (fun \@(n-2) (SNat (n-2)) x_1 ( ... start))@.
-- The SNat data constructor is found from the type of the second argument
-- of @fun@.
--
-- For @n == 0@ the result is just @start@, with no surrounding @letrec@,
-- rather than applying the partial 'init' to an empty binding list.
reduceDFold :: Integer  -- ^ Length @n@ of the vector
            -> Type     -- ^ Element type of the vector
            -> Term     -- ^ The folding function
            -> Term     -- ^ The starting value
            -> Term     -- ^ The argument vector
            -> NormalizeSession Term
reduceDFold n aTy fun start arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon] <- tyConDataCons vecTc
      = do
        let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'D' n arg
        (_ltv:Right snTy:_,_) <- splitFunForallTy <$> termType tcm fun
        let (TyConApp snatTcNm _) = tyView snTy
            (Just snatTc)         = HashMap.lookup (nameOcc snatTcNm) tcm
            [snatDc]              = tyConDataCons snatTc
            lbody                 = doFold (buildSNat snatDc) (n-1) vars
            -- The final extracted binding is not referenced by the body.
            binds                 = dropLast elems
            lb | null binds = lbody
               | otherwise  = Letrec (bind (rec binds) lbody)
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceDFold: argument does not have a vector type: " ++ showDoc ty

    doFold _    _ []     = start
    doFold snDc k (x:xs) = mkApps fun
                                 [Right (LitTy (NumTy k))
                                 ,Left (snDc k)
                                 ,Left x
                                 ,Left (doFold snDc (k-1) xs)
                                 ]

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Reduce @head xs@ on a vector of known length. Only the first element of
-- the argument is let-bound (via 'extractElems'), and that variable is the
-- result.
reduceHead :: Integer
           -> Type
           -> Term
           -> NormalizeSession Term
reduceHead n aTy vArg = do
    tcm   <- Lens.view tcCache
    argTy <- termType tcm vArg
    expand tcm argTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [_,consCon] <- tyConDataCons vecTc
          -> let pieces    = extractElems consCon aTy 'H' n vArg
                 firstVar  = head (map fst pieces)
                 firstBind = head (concatMap snd pieces)
             in  changed (Letrec (bind (rec [firstBind]) firstVar))
        _ -> error $ $(curLoc) ++ "reduceHead: argument does not have a vector type: " ++ showDoc t
-- | Reduce @tail xs@ on a vector of known length. The result is the second
-- binding produced by 'extractElems' (presumably the remainder of the vector
-- after its first element), bound in a @letrec@ and returned.
reduceTail :: Integer
           -> Type
           -> Term
           -> NormalizeSession Term
reduceTail n aTy vArg = do
    tcm   <- Lens.view tcCache
    argTy <- termType tcm vArg
    expand tcm argTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [_,consCon] <- tyConDataCons vecTc
          -> let bindings = concatMap snd (extractElems consCon aTy 'L' n vArg)
                 tailBind = bindings !! 1
             in  changed (Letrec (bind (rec [tailBind]) (idToVar (fst tailBind))))
        _ -> error $ $(curLoc) ++ "reduceTail: argument does not have a vector type: " ++ showDoc t
-- | Reduce @last xs@ on a vector of known length @n@.
--
-- For @n == 0@ the result is the @Clash.Transformations.undefined@
-- primitive at type @aTy@. Otherwise, all element bindings except the final
-- one are let-bound, and the last element's variable is returned.
reduceLast :: Integer
           -> Type
           -> Term
           -> NormalizeSession Term
reduceLast n aTy vArg = do
    tcm   <- Lens.view tcCache
    argTy <- termType tcm vArg
    expand tcm argTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [_,consCon] <- tyConDataCons vecTc
          -> changed (lastOf consCon)
        _ -> error $ $(curLoc) ++ "reduceLast: argument does not have a vector type: " ++ showDoc t

    lastOf consCon
      | n == 0    = mkApps (Prim "Clash.Transformations.undefined" undefinedTy) [Right aTy]
      | otherwise = let groups = map snd (extractElems consCon aTy 'L' n vArg)
                        lastId = fst (head (last groups))
                    in  Letrec (bind (rec (init (concat groups))) (idToVar lastId))
-- | Reduce @init xs@ on a vector of known length @n@.
--
-- * @n == 0@: the @Clash.Transformations.undefined@ primitive at @aTy@.
-- * @n == 1@: the empty vector.
-- * otherwise: the first @n-1@ elements are let-bound and collected into a
--   new vector of length @n-1@.
reduceInit :: Integer
           -> Type
           -> Term
           -> NormalizeSession Term
reduceInit n aTy vArg = do
    tcm   <- Lens.view tcCache
    argTy <- termType tcm vArg
    expand tcm argTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [nilCon,consCon] <- tyConDataCons vecTc
          -> changed (initOf nilCon consCon)
        _ -> error $ $(curLoc) ++ "reduceInit: argument does not have a vector type: " ++ showDoc t

    initOf nilCon consCon =
      let groups    = map snd (extractElems consCon aTy 'L' n vArg)
          kept      = init groups
          prefixVec = mkVec nilCon consCon aTy (n-1) [idToVar (fst (head g)) | g <- kept]
      in  case n of
            0 -> mkApps (Prim "Clash.Transformations.undefined" undefinedTy) [Right aTy]
            1 -> mkVec nilCon consCon aTy 0 []
            _ -> Letrec (bind (rec (init (concat kept))) prefixVec)
-- | Reduce @xs ++ ys@, where @xs@ has statically known length @n@ and @ys@
-- has length @m@. The elements of @xs@ are let-bound individually (via
-- 'extractElems') and consed onto @ys@ using 'appendToVec'.
--
-- For @n == 0@ no element bindings exist, and the result is whatever
-- 'appendToVec' produces for an empty prefix. No @letrec@ is built, and the
-- partial 'init' is no longer applied to an empty list.
reduceAppend :: Integer  -- ^ Length @n@ of the left vector
             -> Integer  -- ^ Length @m@ of the right vector
             -> Type     -- ^ Element type
             -> Term     -- ^ Left vector
             -> Term     -- ^ Right vector
             -> NormalizeSession Term
reduceAppend n m aTy lArg rArg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm lArg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon] <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'C' n lArg
            lbody        = appendToVec consCon aTy rArg (n+m) vars
            -- The final extracted binding is not referenced by the body.
            binds        = dropLast elems
            lb | null binds = lbody
               | otherwise  = Letrec (bind (rec binds) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceAppend: argument does not have a vector type: " ++ showDoc ty

    -- Total variant of 'init': drops the last element; @[]@ stays @[]@.
    dropLast xs = zipWith const xs (drop 1 xs)
-- | Reduce @unconcat@. Only the case where the inner length (second
-- argument) is 0 is supported: the result is @n@ copies of the empty vector.
-- Any other inner length raises an "unimplemented" error.
reduceUnconcat :: Integer
               -> Integer
               -> Type
               -> Term
               -> NormalizeSession Term
reduceUnconcat n m aTy arg
  | m /= 0    = error $ $(curLoc) ++ "reduceUnconcat: unimplemented"
  | otherwise = do
      tcm   <- Lens.view tcCache
      argTy <- termType tcm arg
      expand tcm argTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [nilCon,consCon] <- tyConDataCons vecTc
          -> let emptyVec   = mkVec nilCon consCon aTy 0 []
                 emptyVecTy = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy]
             in  changed (mkVec nilCon consCon emptyVecTy n (replicate (fromInteger n) emptyVec))
        _ -> error $ $(curLoc) ++ "reduceUnconcat: argument does not have a vector type: " ++ showDoc t
-- | Reduce @transpose@. Only the case where the second length argument is 0
-- is supported: the result is @n@ copies of the empty vector. Any other
-- value raises an "unimplemented" error.
reduceTranspose :: Integer
                -> Integer
                -> Type
                -> Term
                -> NormalizeSession Term
reduceTranspose n m aTy arg
  | m /= 0    = error $ $(curLoc) ++ "reduceTranspose: unimplemented"
  | otherwise = do
      tcm   <- Lens.view tcCache
      argTy <- termType tcm arg
      expand tcm argTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [nilCon,consCon] <- tyConDataCons vecTc
          -> let emptyVec   = mkVec nilCon consCon aTy 0 []
                 emptyVecTy = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy]
             in  changed (mkVec nilCon consCon emptyVecTy n (replicate (fromInteger n) emptyVec))
        _ -> error $ $(curLoc) ++ "reduceTranspose: argument does not have a vector type: " ++ showDoc t
-- | Reduce @replicate@. Builds a vector of @n@ copies of @arg@ with element
-- type @aTy@. The vector constructors are looked up from @eTy@, the type
-- that is inspected here.
reduceReplicate :: Integer
                -> Type
                -> Type
                -> Term
                -> NormalizeSession Term
reduceReplicate n aTy eTy arg = Lens.view tcCache >>= \tcm -> expand tcm eTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [nilCon,consCon] <- tyConDataCons vecTc
          -> changed (mkVec nilCon consCon aTy n (replicate (fromInteger n) arg))
        _ -> error $ $(curLoc) ++ "reduceReplicate: argument does not have a vector type: " ++ showDoc t
-- | Reduce @dtfold@ on a vector of @2^n@ elements.
--
-- Each element is wrapped with @lrFun@. Pairs of halves are combined with
-- @brFun@, which receives the type-level depth @k@ and a matching SNat
-- (starting from @k = n-1@ at the root). The SNat data constructor is found
-- from the type of the second argument of @brFun@.
reduceDTFold :: Integer
             -> Type
             -> Term
             -> Term
             -> Term
             -> NormalizeSession Term
reduceDTFold n aTy lrFun brFun arg = do
    tcm   <- Lens.view tcCache
    argTy <- termType tcm arg
    expand tcm argTy
  where
    -- Keep applying 'coreView' until the vector type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp vecTcNm _
          | Just vecTc <- HashMap.lookup (nameOcc vecTcNm) tcm
          , [_,consCon] <- tyConDataCons vecTc
          -> do
            let pieces   = extractElems consCon aTy 'T' (2^n) arg
                elemVars = map fst pieces
                elemBnds = concatMap snd pieces
            (_ltv:Right snTy:_,_) <- splitFunForallTy <$> termType tcm brFun
            let (TyConApp snatTcNm _) = tyView snTy
                (Just snatTc)         = HashMap.lookup (nameOcc snatTcNm) tcm
                [snatDc]              = tyConDataCons snatTc
                body                  = combine (buildSNat snatDc) (n-1) elemVars
            changed (Letrec (bind (rec (init elemBnds)) body))
        _ -> error $ $(curLoc) ++ "reduceDTFold: argument does not have a vector type: " ++ showDoc t

    -- Build the balanced application tree; at depth k the list splits at 2^k.
    combine :: (Integer -> Term) -> Integer -> [Term] -> Term
    combine _ _ [leaf] = mkApps lrFun [Left leaf]
    combine mkSNat depth leaves =
        mkApps brFun [Right (LitTy (NumTy depth))
                     ,Left (mkSNat depth)
                     ,Left (combine mkSNat (depth-1) lower)
                     ,Left (combine mkSNat (depth-1) upper)
                     ]
      where
        (lower,upper) = splitAt (2^depth) leaves
-- | Reduce @tfold@ on a tree of statically known depth @n@.
--
-- The leaves are let-bound via 'extractTElems'. Each leaf is wrapped with
-- @lrFun@, and halves are combined with @brFun@, which receives the depth
-- @k@ and a matching SNat (starting from @k = n-1@). The SNat data
-- constructor is found from the type of the second argument of @brFun@.
reduceTFold :: Integer
            -> Type
            -> Term
            -> Term
            -> Term
            -> NormalizeSession Term
reduceTFold n aTy lrFun brFun arg = do
    tcm   <- Lens.view tcCache
    argTy <- termType tcm arg
    expand tcm argTy
  where
    -- Keep applying 'coreView' until the tree type constructor is exposed.
    expand tcm t = case coreView tcm t of
      Just t' -> expand tcm t'
      Nothing -> case tyView t of
        TyConApp treeTcNm _
          | Just treeTc <- HashMap.lookup (nameOcc treeTcNm) tcm
          , [lrCon,brCon] <- tyConDataCons treeTc
          -> do
            let (leafVars,leafBnds) = extractTElems lrCon brCon aTy 'T' n arg
            (_ltv:Right snTy:_,_) <- splitFunForallTy <$> termType tcm brFun
            let (TyConApp snatTcNm _) = tyView snTy
                (Just snatTc)         = HashMap.lookup (nameOcc snatTcNm) tcm
                [snatDc]              = tyConDataCons snatTc
                body                  = combine (buildSNat snatDc) (n-1) leafVars
            changed (Letrec (bind (rec leafBnds) body))
        _ -> error $ $(curLoc) ++ "reduceTFold: argument does not have a tree type: " ++ showDoc t

    -- Build the application tree, splitting the leaves in half at each level.
    combine :: (Integer -> Term) -> Integer -> [Term] -> Term
    combine _ _ [leaf] = mkApps lrFun [Left leaf]
    combine mkSNat depth leaves =
        mkApps brFun [Right (LitTy (NumTy depth))
                     ,Left (mkSNat depth)
                     ,Left (combine mkSNat (depth-1) lower)
                     ,Left (combine mkSNat (depth-1) upper)
                     ]
      where
        (lower,upper) = splitAt (length leaves `div` 2) leaves
-- | Reduce @treplicate@. Using 'mkRTree', builds a tree for depth argument
-- @n@ whose @2^n@ leaves are all @arg@. The leaf and node constructors are
-- looked up from @eTy@, which must be a tree type.
reduceTReplicate :: Integer  -- ^ Depth @n@ (the tree gets @2^n@ leaves)
                 -> Type     -- ^ Element (leaf) type
                 -> Type     -- ^ Type of the resulting tree
                 -> Term     -- ^ The element to replicate
                 -> NormalizeSession Term
reduceTReplicate n aTy eTy arg = do
    tcm <- Lens.view tcCache
    go tcm eTy
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp treeTcNm _)
      | (Just treeTc) <- HashMap.lookup (nameOcc treeTcNm) tcm
      , [lrCon,brCon] <- tyConDataCons treeTc
      = let retTree = mkRTree lrCon brCon aTy n (replicate (2^n) arg)
        in  changed retTree
    -- The inspected type is expected to be a tree type (as in reduceTFold),
    -- so the error reports it as such rather than as a vector type.
    go _ ty = error $ $(curLoc) ++ "reduceTReplicate: argument does not have a tree type: " ++ showDoc ty
-- | Construct an SNat value for the natural number @i@. The SNat data
-- constructor is applied to the type-level literal @i@ and to a term-level
-- literal holding the same number. The kind of term literal depends on the
-- GHC version being compiled against.
buildSNat :: DataCon -> Integer -> Term
buildSNat snatDc i = mkApps (Data snatDc) [Right (LitTy (NumTy i)), Left valueLit]
  where
#if MIN_VERSION_ghc(8,2,0)
    valueLit = Literal (NaturalLiteral i)
#else
    valueLit = Literal (IntegerLiteral i)
#endif