{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Core.Term
( Term (..)
, mkAbstraction
, mkTyLams
, mkLams
, mkApps
, mkTyApps
, mkTmApps
, mkTicks
, TmName
, idToVar
, varToId
, LetBinding
, Pat (..)
, patIds
, patVars
, Alt
, TickInfo (..)
, stripTicks
, partitionTicks
, NameMod (..)
, PrimInfo (..)
, WorkInfo (..)
, CoreContext (..)
, Context
, isLambdaBodyCtx
, isTickCtx
, walkTerm
, collectArgs
, collectArgsTicks
, collectTicks
, collectTermIds
, collectBndrs
, primArg
) where
import Data.Coerce (coerce)
import Data.Either (lefts, rights)
import Data.Foldable (foldl')
import Data.List (nub, partition)
import Data.Maybe (catMaybes)
import GHC.Generics

import Control.DeepSeq
import Data.Binary (Binary)
import qualified Data.DList as DList
import Data.Hashable (Hashable (hashWithSalt))
import Data.Text (Text)
import SrcLoc (SrcSpan)

import Clash.Core.DataCon (DataCon)
import Clash.Core.Literal (Literal)
import Clash.Core.Name (Name (..))
import {-# SOURCE #-} Clash.Core.Subst ()
import {-# SOURCE #-} Clash.Core.Type
import Clash.Core.Var (Var(Id), Id)
import Clash.Util (curLoc)
-- | Terms of the core language.
--
-- NOTE: the constructor order is load-bearing: the 'Binary' and 'Hashable'
-- instances are derived through 'Generic' (DeriveAnyClass), so reordering
-- constructors changes their encoding.
data Term
  = Var     !Id                -- ^ Reference to a term-level variable
  | Data    !DataCon           -- ^ Data constructor
  | Literal !Literal           -- ^ Literal value
  | Prim    !PrimInfo          -- ^ Primitive, described by its 'PrimInfo'
  | Lam     !Id Term           -- ^ Term abstraction: binder and body
  | TyLam   !TyVar Term        -- ^ Type abstraction: type variable and body
  | App     !Term !Term        -- ^ Application of a term to a term argument
  | TyApp   !Term !Type        -- ^ Application of a term to a type argument
  | Letrec  [LetBinding] Term  -- ^ Let-bindings and the body they scope over
                               -- (presumably mutually recursive, as the name
                               -- suggests — confirm)
  | Case    !Term !Type [Alt]  -- ^ Scrutinee, a type annotation (presumably
                               -- the type of the alternatives — confirm) and
                               -- the alternatives
  | Cast    !Term !Type !Type  -- ^ A term with two type annotations
                               -- (presumably source and target type of the
                               -- cast — confirm)
  | Tick    !TickInfo !Term    -- ^ A term carrying a 'TickInfo' annotation
  deriving (Show,Generic,NFData,Hashable,Binary)
-- | Annotations that can be attached to a term with the 'Tick' constructor.
--
-- NOTE: constructor order is load-bearing for the Generic-derived 'Binary'
-- and 'Hashable' instances.
data TickInfo
  = SrcSpan !SrcSpan
  -- ^ Source location (a GHC 'SrcSpan') of the annotated term
  | NameMod !NameMod !Type
  -- ^ A name modification, see 'NameMod'. NOTE(review): the role of the
  -- 'Type' argument is not evident from this module — confirm.
  | DeDup
  -- ^ NOTE(review): meaning not evident from this module; the name suggests
  -- a de-duplication hint — confirm.
  | NoDeDup
  -- ^ NOTE(review): presumably the opposite hint of 'DeDup' — confirm.
  deriving (Eq,Show,Generic,NFData,Hashable,Binary)
-- | The kind of name modification carried by a 'NameMod' tick.
--
-- NOTE(review): the constructor names suggest prefixing, suffixing (plus a
-- variant 'SuffixNameP') and replacing a name; confirm against the code that
-- interprets these ticks.
data NameMod
  = PrefixName
  | SuffixName
  | SuffixNameP
  | SetName
  deriving (Eq,Show,Generic,NFData,Hashable,Binary)
-- | Description of a primitive, as carried by the 'Prim' constructor.
data PrimInfo = PrimInfo
  { primName     :: !Text      -- ^ Name of the primitive
  , primType     :: !Type      -- ^ Type of the primitive
  , primWorkInfo :: !WorkInfo  -- ^ Work classification, see 'WorkInfo'
  } deriving (Show,Generic,NFData,Hashable,Binary)
-- | Work classification stored in 'primWorkInfo' for every primitive.
--
-- NOTE(review): the meaning of each constructor is not evident from this
-- module; the names suggest how much work an application of the primitive
-- performs (constant / never / depends on the arguments / always) — confirm.
data WorkInfo
  = WorkConstant
  | WorkNever
  | WorkVariable
  | WorkAlways
  deriving (Show,Generic,NFData,Hashable,Binary)
-- | Name of a term-level entity.
type TmName = Name Term
-- | A single let-binding: the bound identifier paired with its right-hand
-- side.
type LetBinding = (Id, Term)
-- | Patterns of the alternatives of a 'Case'.
--
-- NOTE: constructor order is load-bearing: it determines the derived 'Ord'
-- and the Generic-derived 'Binary' / 'Hashable' encodings.
data Pat
  = DataPat !DataCon [TyVar] [Id]
  -- ^ Constructor pattern: the data constructor, followed by the type
  -- variables and the term variables the pattern binds
  | LitPat !Literal
  -- ^ Literal pattern
  | DefaultPat
  -- ^ Default pattern; binds no variables
  deriving (Eq,Ord,Show,Generic,NFData,Hashable,Binary)
-- | A case alternative: a pattern paired with its right-hand side.
type Alt = (Pat,Term)
-- | Split the variables bound by a pattern into its type variables and its
-- term variables. Only 'DataPat' binds anything; literal and default
-- patterns yield two empty lists.
patIds :: Pat -> ([TyVar],[Id])
patIds pat = case pat of
  DataPat _ tyVars termVars -> (tyVars, termVars)
  LitPat _                  -> ([], [])
  DefaultPat                -> ([], [])
-- | All variables bound by a pattern, type variables first and term
-- variables after, converted to a common 'Var' type with 'coerce'.
-- Non-constructor patterns bind nothing.
patVars :: Pat -> [Var a]
patVars = \case
  DataPat _ tyVars termVars -> coerce tyVars <> coerce termVars
  _                         -> []
-- | Wrap a body in a sequence of binders. The first binder of the list
-- becomes the outermost abstraction; 'Left' binders produce 'Lam' and
-- 'Right' binders produce 'TyLam'.
mkAbstraction :: Term -> [Either Id TyVar] -> Term
mkAbstraction body binders = foldr bind body binders
  where
    bind (Left i)   inner = Lam i inner
    bind (Right tv) inner = TyLam tv inner
-- | Abstract a body over type variables; the first type variable of the list
-- becomes the outermost 'TyLam'.
mkTyLams :: Term -> [TyVar] -> Term
mkTyLams body tyVars = foldr TyLam body tyVars
-- | Abstract a body over term variables; the first identifier of the list
-- becomes the outermost 'Lam'.
mkLams :: Term -> [Id] -> Term
mkLams body ids = foldr Lam body ids
-- | Apply a term to a sequence of arguments from left to right: 'Left'
-- arguments produce 'App' nodes and 'Right' arguments produce 'TyApp'
-- nodes. A strict left fold avoids building a chain of thunks.
mkApps :: Term -> [Either Term Type] -> Term
mkApps = foldl' applyOne
  where
    applyOne fun (Left arg) = App fun arg
    applyOne fun (Right ty) = TyApp fun ty
-- | Apply a term to term arguments from left to right, so the first argument
-- ends up in the innermost 'App'.
mkTmApps :: Term -> [Term] -> Term
mkTmApps fun args = foldl' App fun args
-- | Apply a term to type arguments from left to right, so the first type
-- ends up in the innermost 'TyApp'.
mkTyApps :: Term -> [Type] -> Term
mkTyApps fun tys = foldl' TyApp fun tys
-- | Wrap a term in 'Tick' annotations. Duplicate annotations are dropped
-- first (keeping the first occurrence), and the first remaining annotation
-- ends up innermost. The inverse of 'collectTicks' for duplicate-free lists.
-- Note that 'nub' is quadratic; tick lists are expected to be short.
mkTicks :: Term -> [TickInfo] -> Term
mkTicks tm = foldl' (flip Tick) tm . nub
-- | One step of a navigation path from the top of a 'Term' down to one of
-- its sub-terms; each constructor names the position that was entered.
data CoreContext
  = AppFun
  -- ^ Function position of an 'App'
  | AppArg (Maybe (Text, Int, Int))
  -- ^ Argument position of an 'App'. The payload has the shape returned by
  -- 'primArg' (primitive name, number of type arguments, number of term
  -- arguments); presumably it is 'Just' only when the applied function is a
  -- primitive — confirm at the construction sites. The payload is ignored by
  -- both the 'Eq' and the 'Hashable' instance.
  | TyAppC
  -- ^ The term child of a 'TyApp'
  | LetBinding Id [Id]
  -- ^ Right-hand side of a let-binding. NOTE(review): the 'Id' presumably
  -- names the binding being entered and the list the binders of the
  -- enclosing 'Letrec' — confirm.
  | LetBody [Id]
  -- ^ Body of a 'Letrec'; the list presumably holds its binders — confirm
  | LamBody Id
  -- ^ Body of a 'Lam', together with the lambda's binder
  | TyLamBody TyVar
  -- ^ Body of a 'TyLam', together with its type variable
  | CaseAlt Pat
  -- ^ Right-hand side of a 'Case' alternative, together with its pattern
  | CaseScrut
  -- ^ Scrutinee of a 'Case'
  | CastBody
  -- ^ The term inside a 'Cast'
  | TickC TickInfo
  -- ^ The term inside a 'Tick', together with its annotation
  deriving (Show, Generic, NFData, Binary)

-- | Written by hand rather than derived so that it agrees with the
-- hand-written 'Eq' instance for 'CoreContext', which treats every 'AppArg'
-- as equal regardless of its payload. The 'Hashable' contract requires
-- @a == b@ to imply @hash a == hash b@. A derived instance hashes the
-- 'AppArg' payload and breaks that, so equal contexts could land in
-- different buckets of a hash map.
--
-- Each constructor is mixed in as a distinct tag, followed by exactly the
-- fields that the equality check compares.
instance Hashable CoreContext where
  hashWithSalt salt ctx = case ctx of
      AppFun          -> tag 0
      AppArg _        -> tag 1  -- payload deliberately not hashed
      TyAppC          -> tag 2
      LetBinding i is -> tag 3 `hashWithSalt` i `hashWithSalt` is
      LetBody is      -> tag 4 `hashWithSalt` is
      LamBody i       -> tag 5 `hashWithSalt` i
      TyLamBody tv    -> tag 6 `hashWithSalt` tv
      CaseAlt p       -> tag 7 `hashWithSalt` p
      CaseScrut       -> tag 8
      CastBody        -> tag 9
      TickC t         -> tag 10 `hashWithSalt` t
    where
      -- Mix the constructor tag into the incoming salt.
      tag :: Int -> Int
      tag = hashWithSalt salt
-- | A sequence of 'CoreContext' steps describing a path to a sub-term.
-- NOTE(review): the order of the steps (outermost-first or innermost-first)
-- is not fixed by this module — confirm at the construction sites.
type Context = [CoreContext]
-- | Structural equality with one exception: any two 'AppArg' values are
-- equal, whatever their payloads. That exception is why the instance is
-- written by hand instead of derived.
instance Eq CoreContext where
  AppFun          == AppFun            = True
  AppArg _        == AppArg _          = True
  TyAppC          == TyAppC            = True
  LetBinding i is == LetBinding i' is' = i == i' && is == is'
  LetBody is      == LetBody is'       = is == is'
  LamBody i       == LamBody i'        = i == i'
  TyLamBody tv    == TyLamBody tv'     = tv == tv'
  CaseAlt p       == CaseAlt p'        = p == p'
  CaseScrut       == CaseScrut         = True
  CastBody        == CastBody          = True
  TickC t         == TickC t'          = t == t'
  _               == _                 = False
-- | Whether a context step enters the body of a 'Lam'.
isLambdaBodyCtx :: CoreContext -> Bool
isLambdaBodyCtx = \case
  LamBody _ -> True
  _         -> False
-- | Whether a context step enters the term inside a 'Tick'.
isTickCtx :: CoreContext -> Bool
isTickCtx = \case
  TickC _ -> True
  _       -> False
-- | Remove the outermost run of 'Tick' annotations from a term. Ticks nested
-- deeper inside the term are left untouched.
stripTicks :: Term -> Term
stripTicks term = case term of
  Tick _ inner -> stripTicks inner
  _            -> term
-- | Split an application spine into its head and its arguments. Arguments
-- are returned in application order; 'Left' holds term arguments and
-- 'Right' holds type arguments. 'Tick' annotations anywhere on the spine
-- are skipped and discarded (see 'collectArgsTicks' to keep them).
collectArgs :: Term -> (Term, [Either Term Type])
collectArgs = walk []
  where
    walk acc term = case term of
      App fun arg  -> walk (Left arg : acc) fun
      TyApp fun ty -> walk (Right ty : acc) fun
      Tick _ inner -> walk acc inner
      _            -> (term, acc)
-- | Remove the outermost run of 'Tick' annotations from a term and return
-- them with the remaining term. The annotations are listed innermost first,
-- which is the order 'mkTicks' expects to rebuild the term.
collectTicks :: Term -> (Term, [TickInfo])
collectTicks = peel []
  where
    peel acc term = case term of
      Tick tick inner -> peel (tick : acc) inner
      _               -> (term, acc)
-- | Like 'collectArgs', but also returns the 'Tick' annotations found on the
-- application spine. Arguments come in application order; annotations are
-- listed innermost first.
collectArgsTicks :: Term -> (Term, [Either Term Type], [TickInfo])
collectArgsTicks = walk [] []
  where
    walk argAcc tickAcc term = case term of
      App fun arg     -> walk (Left arg : argAcc) tickAcc fun
      TyApp fun ty    -> walk (Right ty : argAcc) tickAcc fun
      Tick tick inner -> walk argAcc (tick : tickAcc) inner
      _               -> (term, argAcc, tickAcc)
-- | Strip the leading 'Lam' / 'TyLam' abstractions from a term. Returns the
-- binders outermost first ('Left' for term binders, 'Right' for type
-- variables) together with the body that remains. Ticks stop the descent.
collectBndrs :: Term -> ([Either Id TyVar], Term)
collectBndrs = peel []
  where
    peel acc term = case term of
      Lam i body    -> peel (Left i : acc) body
      TyLam tv body -> peel (Right tv : acc) body
      _             -> (reverse acc, term)
-- | If the head of an application spine (as found by 'collectArgs', which
-- skips ticks) is a primitive, return the primitive's name together with
-- the number of type arguments and the number of term arguments applied to
-- it.
primArg
  :: Term
  -- ^ Term to inspect, typically an application
  -> Maybe (Text, Int, Int)
  -- ^ @Just (name, #type args, #term args)@ when the head is a 'Prim',
  -- 'Nothing' otherwise
primArg term =
  case collectArgs term of
    (Prim p, args) ->
      let tyArgCount = length (rights args)
          tmArgCount = length (lefts args)
      in  Just (primName p, tyArgCount, tmArgCount)
    _ ->
      Nothing
-- | Split annotations into source-location ticks ('SrcSpan') and all others,
-- preserving the relative order within each group.
partitionTicks
  :: [TickInfo]
  -> ([TickInfo], [TickInfo])
partitionTicks = partition isSrcSpan
  where
    isSrcSpan (SrcSpan _) = True
    isSrcSpan _           = False
-- | Pre-order traversal of a term: apply @f@ to every node (a node before
-- its sub-terms) and keep the 'Just' results in visiting order.
--
-- Sub-terms are visited left to right. For 'Letrec' the body is visited
-- before the bindings; for 'Case' the scrutinee before the alternatives.
-- Types are never entered.
walkTerm :: forall a . (Term -> Maybe a) -> Term -> [a]
walkTerm f = DList.toList . visit
  where
    -- Emit the node's own result (if any), then those of its children.
    visit :: Term -> DList.DList a
    visit t = maybe id DList.cons (f t) (descend t)

    descend :: Term -> DList.DList a
    descend t = case t of
      Lam _ body        -> visit body
      TyLam _ body      -> visit body
      App fun arg       -> visit fun <> visit arg
      TyApp fun _       -> visit fun
      Letrec bndrs body -> visit body <> foldMap (visit . snd) bndrs
      Case scrut _ alts -> visit scrut <> foldMap (visit . snd) alts
      Cast body _ _     -> visit body
      Tick _ body       -> visit body
      Var _             -> mempty
      Data _            -> mempty
      Literal _         -> mempty
      Prim _            -> mempty
-- | Every term-level identifier occurring anywhere in a term, in 'walkTerm'
-- order. This includes variable references, lambda binders, let binders and
-- the term variables bound by constructor patterns. Duplicates are kept.
collectTermIds :: Term -> [Id]
collectTermIds = concat . walkTerm (Just . idsAt)
  where
    -- Identifiers contributed by a single node (children are handled by
    -- 'walkTerm').
    idsAt :: Term -> [Id]
    idsAt = \case
      Var i          -> [i]
      Lam i _        -> [i]
      Letrec bndrs _ -> fst <$> bndrs
      Case _ _ alts  -> concatMap (patTermIds . fst) alts
      Data {}        -> []
      Literal {}     -> []
      Prim {}        -> []
      TyLam {}       -> []
      App {}         -> []
      TyApp {}       -> []
      Cast {}        -> []
      Tick {}        -> []

    patTermIds :: Pat -> [Id]
    patTermIds = \case
      DataPat _ _ ids -> ids
      LitPat {}       -> []
      DefaultPat      -> []
-- | Turn an identifier into a variable-reference term. Calls 'error' when
-- the given 'Var' was not built with the 'Id' constructor (the message
-- reports it as a type variable).
idToVar :: Id -> Term
idToVar i = case i of
  Id {} -> Var i
  _     -> error $ $(curLoc) ++ "idToVar: tyVar: " ++ show i
-- | Extract the identifier from a 'Var' term; calls 'error' for any other
-- kind of term.
varToId :: Term -> Id
varToId = \case
  Var i -> i
  e     -> error $ $(curLoc) ++ "varToId: not a var: " ++ show e