-- | A minor cleanup pass that runs after defunctorisation and applies
-- any type abbreviations. After this, the program consists entirely
-- of value bindings.
module Futhark.Internalise.ApplyTypeAbbrs (transformProg) where

import Control.Monad.Identity
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Language.Futhark
import Language.Futhark.Semantic (TypeBinding (..))
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Types

-- | Table of the type abbreviations seen so far: maps the name bound
-- by a type declaration to the substitution that replaces uses of
-- that name.
type Types = M.Map VName (Subst StructRetType)

-- Collect the type abbreviations introduced by a list of
-- declarations, starting from the table @types@.
--
-- Declarations are visited front to back.  The right-hand side of
-- each type binding is first expanded with the abbreviations gathered
-- so far, and the result is then recorded under the binding's name,
-- so later bindings see earlier ones fully expanded.  Every other
-- kind of declaration is skipped.
getTypes :: Types -> [Dec] -> Types
getTypes types decs = case decs of
  [] -> types
  TypeDec tb : rest ->
    getTypes (M.insert name (substFromAbbr abbr) types) rest
    where
      -- Lazy pattern binding: the type binding is only taken apart
      -- when one of its components is demanded.
      TypeBind name l tparams _ (Info (RetType dims t)) _ _ = tb
      expanded = applySubst (`M.lookup` types) t
      abbr = TypeAbbr l tparams (RetType dims expanded)
  _ : rest -> getTypes types rest

-- Perform a given substitution on the types in a pattern.
--
-- The function @f@ is applied to the type annotation carried by every
-- 'Id', 'Wildcard', 'PatLit' and 'PatConstr' node.  The traversal
-- descends into all nested patterns: tuple elements, record fields,
-- parenthesised and attributed patterns, and the argument patterns of
-- a constructor pattern.  A 'PatAscription' is replaced by its
-- (transformed) inner pattern; the ascription itself is dropped.
substPat :: (t -> t) -> Pat t -> Pat t
substPat f = go
  where
    go pat = case pat of
      TuplePat pats loc -> TuplePat (map go pats) loc
      RecordPat fs loc -> RecordPat (map onField fs) loc
      PatParens p loc -> PatParens (go p) loc
      PatAttr attr p loc -> PatAttr attr (go p) loc
      Id vn (Info tp) loc -> Id vn (Info (f tp)) loc
      Wildcard (Info tp) loc -> Wildcard (Info (f tp)) loc
      PatAscription p _ _ -> go p
      PatLit e (Info tp) loc -> PatLit e (Info (f tp)) loc
      -- The argument patterns carry their own type annotations, so
      -- they must be rewritten as well, just like the elements of a
      -- tuple or record pattern.
      PatConstr n (Info tp) ps loc ->
        PatConstr n (Info (f tp)) (map go ps) loc

    -- Rewrite the pattern of a record field, keeping its name.
    onField (n, p) = (n, go p)

-- Expand, inside a single type, every name that has an entry in the
-- abbreviation table @types@.
removeTypeVariablesInType :: Types -> StructType -> StructType
removeTypeVariablesInType types t = applySubst lookupAbbr t
  where
    lookupAbbr v = M.lookup v types

-- Expand type abbreviations in the types recorded for an entry point:
-- the type of each parameter and the return type.  Parameter names
-- and the optional type expressions attached to each entry type are
-- passed through unchanged.
substEntry :: Types -> EntryPoint -> EntryPoint
substEntry types (EntryPoint params ret) = EntryPoint params' ret'
  where
    expandType (EntryType t te) =
      EntryType (removeTypeVariablesInType types t) te
    expandParam (EntryParam v t) = EntryParam v (expandType t)

    params' = map expandParam params
    ret' = expandType ret

-- Remove all type variables and type abbreviations from a value binding.
--
-- The substitution given by @types@ is applied to the return type,
-- to the types in the parameter patterns, to the entry point
-- information (if any), and to the types reached when mapping over
-- the body.  All other fields of the binding are left as they are.
removeTypeVariables :: Types -> ValBind -> ValBind
removeTypeVariables types valbind =
  valbind
    { valBindRetType = applySubst lookupAbbr <$> valBindRetType valbind,
      valBindParams = map (substPat (applySubst lookupAbbr)) (valBindParams valbind),
      valBindEntryPoint = fmap (substEntry types) <$> valBindEntryPoint valbind,
      valBindBody = runIdentity (astMap mapper (valBindBody valbind))
    }
  where
    -- The abbreviation table viewed as a lookup function.
    lookupAbbr v = M.lookup v types

    -- Names are left alone; types are rewritten; expressions are
    -- traversed with this same mapper, so the rewrite is applied at
    -- every nesting depth.
    mapper =
      ASTMapper
        { mapOnExp = astMap mapper,
          mapOnName = pure,
          mapOnStructType = pure . applySubst lookupAbbr,
          mapOnParamType = pure . applySubst lookupAbbr,
          mapOnResRetType = pure . applySubst lookupAbbr
        }

-- | Apply type abbreviations from a list of top-level declarations.
-- A module-free input program is expected, i.e. one made up of value
-- declarations and type declarations.  The type declarations are
-- used to build the table of abbreviations; the result contains only
-- the value bindings, in their original order, with the
-- abbreviations expanded.  Any other declaration is dropped.
transformProg :: (Monad m) => [Dec] -> m [ValBind]
transformProg decs = pure (mapMaybe valueBinding decs)
  where
    -- Abbreviations gathered from the whole program up front.
    abbrs = getTypes mempty decs

    valueBinding dec = case dec of
      ValDec vb -> Just (removeTypeVariables abbrs vb)
      _ -> Nothing