{-# LANGUAGE TypeFamilies #-}

-- | Inspecting and modifying t'Pat's, function parameters and
-- pattern elements.
module Futhark.IR.Prop.Pat
  ( -- * Function parameters
    paramIdent,
    paramType,
    paramDeclType,

    -- * Pat elements
    patElemIdent,
    patElemType,
    setPatElemDec,
    patIdents,
    patNames,
    patTypes,
    patSize,

    -- * Pat construction
    basicPat,
  )
where

import Futhark.IR.Prop.Types (DeclTyped (..), Typed (..))
import Futhark.IR.Syntax

-- | The 'Type' of a parameter, as given by the parameter's 'Typed'
-- instance.
paramType :: (Typed dec) => Param dec -> Type
paramType = typeOf

-- | The 'DeclType' of a parameter, as given by the parameter's
-- 'DeclTyped' instance.
paramDeclType :: (DeclTyped dec) => Param dec -> DeclType
paramDeclType = declTypeOf

-- | An 'Ident' corresponding to a parameter: the parameter's name
-- paired with its 'Type'.
paramIdent :: (Typed dec) => Param dec -> Ident
paramIdent param = Ident (paramName param) (typeOf param)

-- | An 'Ident' corresponding to a pattern element: the element's bound
-- name paired with its 'Type'.
patElemIdent :: (Typed dec) => PatElem dec -> Ident
patElemIdent pelem = Ident (patElemName pelem) (typeOf pelem)

-- | The type of the name bound by a t'PatElem', as given by the
-- element's 'Typed' instance.
patElemType :: (Typed dec) => PatElem dec -> Type
patElemType = typeOf

-- | Replace the decoration of a t'PatElem' with a new value, keeping
-- the bound name.  The decoration type may change in the process.
setPatElemDec :: PatElem oldattr -> newattr -> PatElem newattr
setPatElemDec pe x = x <$ pe

-- | Return a list of the 'Ident's bound by the t'Pat', one per pattern
-- element and in the same order as the elements.
patIdents :: (Typed dec) => Pat dec -> [Ident]
patIdents = map patElemIdent . patElems

-- | Return a list of the 'VName's bound by the t'Pat', one per pattern
-- element and in the same order as the elements.
patNames :: Pat dec -> [VName]
patNames = map patElemName . patElems

-- | Return a list of the types bound by the pattern, one per pattern
-- element and in the same order as the elements.
patTypes :: (Typed dec) => Pat dec -> [Type]
-- Project each element's type directly rather than first building an
-- intermediate 'Ident' per element.
patTypes = map patElemType . patElems

-- | Return the number of names bound by the pattern, i.e. the number
-- of its pattern elements.
patSize :: Pat dec -> Int
patSize = length . patElems

-- | Create a pattern using 'Type' as the attribute.
--
-- Each 'Ident' becomes one t'PatElem' that binds the identifier's name
-- and carries the identifier's type as its decoration.  The order of
-- the input list is preserved.
basicPat :: [Ident] -> Pat Type
basicPat values = Pat $ map patElem values
  where
    patElem :: Ident -> PatElem Type
    patElem (Ident name t) = PatElem name t