{- CAO Compiler
Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. -}
{-# LANGUAGE PatternGuards #-}
{- |
Module : $Header$
Description : CAO type representation utilities
Copyright : (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho
License : GPL
Maintainer : Paulo Silva
Stability : experimental
Portability : non-portable
This module defines utility functions for the CAO type representation.
-}
module Language.CAO.Type.Utils
( isAlgebraic
, isBool
, isInt
, isBits
, isVar
, isNil
, isFunType
, isProc
, isTySyn
, isStruct
, isIndet
, isMod
, isModInt
, isModPol
, isVector
, isMatrix
, isIntExt
, isRInt
, isSimpleType
, isContainer
, sfType
, retType
, synType
, fieldType
, funClass
, getModulusBase
, extractBaseType
, extractBottomBaseType
, getStructName
, wVars
, getPoly
, maximumClass
, fromTuple
, toTuple
, innerType
, substTy
, getBVSize
, isDependent
) where
import Control.Arrow (second)
import Data.List ( foldl' )
import Data.Set ( Set )
import qualified Data.Set as Set
import Language.CAO.Common.Outputable
import Language.CAO.Common.Polynomial
import Language.CAO.Index
import Language.CAO.Type
--------------------------------------------------------------------------------
-- * Predicates over type representations
-- | Holds for 'Int', 'RInt' and 'Mod' types, and for matrices whose element
-- type satisfies the predicate. 'Index' wrappers are looked through.
isAlgebraic :: Type id -> Bool
isAlgebraic ty = case ty of
    Index _ _ inner -> isAlgebraic inner
    Matrix _ _ elemTy -> isAlgebraic elemTy
    Int -> True
    RInt -> True
    Mod {} -> True
    _ -> False
-- | Holds for 'Bool', possibly wrapped in any number of 'Index' layers.
isBool :: Type id -> Bool
isBool ty = case ty of
    Index _ _ inner -> isBool inner
    Bool -> True
    _ -> False
-- | Holds for 'Int' and 'RInt', possibly wrapped in 'Index' layers.
isInt :: Type id -> Bool
isInt ty = case ty of
    Index _ _ inner -> isInt inner
    Int -> True
    RInt -> True
    _ -> False
-- | Holds for 'Int' and for bit-string types ('Bits'), possibly wrapped in
-- 'Index' layers.
--
-- NOTE(review): unlike 'isInt', 'RInt' is not accepted here — confirm this
-- is intended.
isIntExt :: Type id -> Bool
isIntExt ty = case ty of
    Index _ _ inner -> isIntExt inner
    Int -> True
    Bits {} -> True
    _ -> False
-- | Holds only for 'RInt', possibly wrapped in 'Index' layers.
isRInt :: Type id -> Bool
isRInt ty = case ty of
    Index _ _ inner -> isRInt inner
    RInt -> True
    _ -> False
-- | Holds for bit-string types ('Bits'). 'Index' wrappers are NOT looked
-- through here.
isBits :: Type id -> Bool
isBits ty = case ty of
    Bits {} -> True
    _ -> False
-- | Holds for every type except function signatures ('FuncSig'), 'SField',
-- 'Indet', tuples ('Tuple') and 'Bullet'.
isVar :: Type id -> Bool
isVar ty = case ty of
    FuncSig {} -> False
    SField {} -> False
    Indet {} -> False
    Tuple {} -> False
    Bullet -> False
    _ -> True
-- | Holds for the "no value" types: 'Bullet' and the empty tuple.
isNil :: Type id -> Bool
isNil ty = case ty of
    Bullet -> True
    Tuple components -> null components
    _ -> False
-- | Holds for function signatures ('FuncSig'), whatever their class.
isFunType :: Type id -> Bool
isFunType ty = case ty of
    FuncSig {} -> True
    _ -> False
-- | Holds for function signatures whose class annotation is 'Proc'.
isProc :: Type id -> Bool
isProc ty = case ty of
    FuncSig _ _ cls -> case cls of
        Proc {} -> True
        _ -> False
    _ -> False
-- | Holds for type synonyms ('TySyn').
isTySyn :: Type id -> Bool
isTySyn ty = case ty of
    TySyn {} -> True
    _ -> False
-- | Holds for struct types ('Struct').
isStruct :: Type id -> Bool
isStruct ty = case ty of
    Struct {} -> True
    _ -> False
-- | Holds for 'Indet' types.
isIndet :: Type id -> Bool
isIndet ty = case ty of
    Indet {} -> True
    _ -> False
-- | Holds for any 'Mod' type, regardless of its components.
isMod :: Type id -> Bool
isMod ty = case ty of
    Mod {} -> True
    _ -> False
-- | Holds for 'Mod' types whose first two components are both 'Nothing'
-- (no base type; see 'isModPol' for the other case).
isModInt :: Type id -> Bool
isModInt ty = case ty of
    Mod base ind _ -> case (base, ind) of
        (Nothing, Nothing) -> True
        _ -> False
    _ -> False
-- | Holds for 'Mod' types that carry a base type (first component is
-- 'Just').
isModPol :: Type id -> Bool
isModPol ty = case ty of
    Mod base _ _ -> maybe False (const True) base
    _ -> False
-- | Holds for vector types ('Vector').
isVector :: Type id -> Bool
isVector ty = case ty of
    Vector {} -> True
    _ -> False
-- | Holds for matrix types ('Matrix').
isMatrix :: Type id -> Bool
isMatrix ty = case ty of
    Matrix {} -> True
    _ -> False
-- | Holds for the scalar types 'Int', 'RInt', 'Bool', 'Bits' and 'Mod'.
-- 'Index' wrappers are NOT looked through.
isSimpleType :: Type id -> Bool
isSimpleType ty = case ty of
    Int -> True
    RInt -> True
    Bool -> True
    Bits {} -> True
    Mod {} -> True
    _ -> False
-- | Holds for the aggregate types: vectors, matrices and structs.
isContainer :: Type id -> Bool
isContainer ty = case ty of
    Vector {} -> True
    Matrix {} -> True
    Struct {} -> True
    _ -> False
-- | Is a data type that may have dependencies: holds for every type except
-- the scalars 'Int', 'RInt' and 'Bool'.
isDependent :: Type a -> Bool
isDependent Int = False
isDependent RInt = False
isDependent Bool = False
isDependent _ = True
--------------------------------------------------------------------------------
-- | Returns the name of a 'Struct' type.
--
-- Any other type is a caller error and raises an 'error' naming this
-- function and showing the offending type.
getStructName :: PP id => Type id -> id
getStructName (Struct v _) = v
getStructName t =
    error $ "getStructName: unexpected type " ++ showPprDebug t
-- | Returns the class annotation ('Class') of a function signature
-- ('FuncSig').
--
-- Any other type raises an 'error' naming this function and showing the
-- offending type.
funClass :: PP id => Type id -> Class id
funClass (FuncSig _ _ c) = c
funClass f =
    error $ "funClass: unexpected type " ++ showPprDebug f
-- | Returns the return type of a function signature ('FuncSig').
--
-- Any other type raises an 'error' naming this function and showing the
-- offending type.
retType :: PP id => Type id -> Type id
retType (FuncSig _ t _) = t
retType f =
    error $ "retType: unexpected type " ++ showPprDebug f
-- | Returns the type that a type synonym ('TySyn') stands for (one level
-- only; nested synonyms are not expanded).
--
-- Any other type raises an 'error' naming this function and showing the
-- offending type.
synType :: PP id => Type id -> Type id
synType (TySyn _ t) = t
synType t =
    error $ "synType: unexpected type " ++ showPprDebug t
-- | Flattens a tuple type into its components; any non-tuple type becomes a
-- singleton list. Inverse of 'toTuple' on lists of length other than one.
fromTuple :: Type id -> [Type id]
fromTuple ty = case ty of
    Tuple components -> components
    single -> [single]
-- | Packs a list of types into a tuple type, except that a singleton list
-- yields its only element unwrapped (and @[]@ yields the empty tuple).
toTuple :: [Type id] -> Type id
toTuple tys = case tys of
    [single] -> single
    _ -> Tuple tys
-- | Returns the polynomial component of a 'Mod' type.
--
-- Any other type raises an 'error' naming this function and showing the
-- offending type.
getPoly :: PP id => Type id -> Pol id
getPoly (Mod _ _ p) = p
getPoly t =
    error $ "getPoly: unexpected type " ++ showPprDebug t
-- | Returns the immediate base type of a 'Mod' type whose first component
-- is 'Just'. Only one level is unwrapped; see 'extractBottomBaseType' for
-- the innermost one.
--
-- Any other type (including a 'Mod' with no base type) raises an 'error'
-- naming this function and showing the offending type.
extractBaseType :: PP id => Type id -> Type id
extractBaseType (Mod (Just t) _ _) = t
extractBaseType t =
    error $ "extractBaseType: unexpected type " ++ showPprDebug t
-- | Follows the base-type chain of a nested 'Mod' type down to the
-- innermost 'Mod' whose first two components are both 'Nothing', and
-- returns that 'Mod' itself.
--
-- Raises an error for non-'Mod' types, and also when the chain reaches a
-- 'Mod' of the form @Mod Nothing (Just _) _@.
extractBottomBaseType :: Type id -> Type id
extractBottomBaseType ty = case ty of
    Mod Nothing Nothing _ -> ty
    Mod (Just base) _ _ -> extractBottomBaseType base
    _ -> error "extractBottomBaseType: not a Mod"
-- | Returns the modulus expression of a (possibly nested) 'Mod' type.
--
-- Nested 'Mod' types are first reduced to their bottom 'Mod' with
-- 'extractBottomBaseType'. That bottom type must carry a polynomial made of
-- a single constant monomial (@Mon (CoefI c) EZero@), and @c@ is returned.
-- A non-constant polynomial, or a non-'Mod' argument, raises an 'error'.
getModulusBase :: Type id -> IExpr id
getModulusBase (Mod Nothing Nothing (Pol [Mon (CoefI c) EZero])) = c
-- A bottom 'Mod' whose polynomial is not a constant. 'extractBottomBaseType'
-- would return it unchanged, so letting it fall through to the recursive
-- equation below would loop forever on the same value.
getModulusBase (Mod Nothing Nothing _) =
    error "getModulusBase: modulus is not a constant"
-- Nested 'Mod': the bottom returned here always matches one of the two
-- equations above (or 'extractBottomBaseType' errors), so this terminates.
getModulusBase m@(Mod _ _ _) = getModulusBase (extractBottomBaseType m)
getModulusBase _ = error "getModulusBase: not a Mod"
-- | Combines the classes of several functions into one.
--
-- * An empty list yields 'Pure'.
-- * If the union of the variable lists of all 'Proc' elements ('wVars') is
--   non-empty, the result is a 'Proc' over that union, in ascending order.
-- * Otherwise the result is the 'maximum' of the given classes.
maximumClass :: Ord id => [Class id] -> Class id
maximumClass cls = case cls of
    [] -> Pure
    _ | Set.null written -> maximum cls
      | otherwise -> Proc (Set.toList written)
  where
    written = wVars cls
-- | Collects, as a set, every variable listed by the 'Proc' elements of the
-- given classes. Non-'Proc' classes contribute nothing.
wVars :: Ord id => [Class id] -> Set id
wVars cls = Set.unions [Set.fromList vs | Proc vs <- cls]
-- | Returns the type carried in the second component of an 'SField' type.
--
-- Any other type raises an 'error' naming this function and showing the
-- offending type.
sfType :: PP id => Type id -> Type id
sfType (SField _ rt) = rt
sfType t =
    error $ "sfType: unexpected type " ++ showPprDebug t
-- | Looks up the type of field @fi@ in a struct type, looking through any
-- number of type synonyms ('TySyn') first.
--
-- Raises an 'error' (naming this function) if the field is not declared in
-- the struct, or if the type is neither a struct nor a synonym for one.
fieldType :: (PP id, Eq id) => id -> Type id -> Type id
fieldType fi (TySyn _ ty) = fieldType fi ty
fieldType fi (Struct n flds)
    | Just ty' <- lookup fi flds = ty'
    | otherwise =
        error $ "fieldType: unknown field " ++ showPprDebug fi
            ++ " of struct " ++ showPprDebug n
fieldType _ ty =
    error $ "fieldType: unexpected type " ++ showPprDebug ty
-- | Returns the types directly contained in a container type:
--
-- * the element type of a 'Vector' or 'Matrix' (as a singleton list);
-- * the field types of a 'Struct', in the order of its field list.
--
-- Any other type raises an 'error' naming this function and showing the
-- offending type.
innerType :: PP a => Type a -> [Type a]
innerType t = case t of
    Vector _ t' -> [t']
    Matrix _ _ t' -> [t']
    Struct _ flds -> map snd flds
    _ -> error $ "innerType: unexpected case for type: " ++ showPpr t
-- | @substTy ty (t, v)@ replaces every occurrence of the type variable @v@
-- inside @ty@ by the type @t@.
--
-- Only nested types are rewritten. Index expressions, polynomials, names
-- and class annotations are copied unchanged.
substTy :: Type a -> (Type a, TyVarId) -> Type a
substTy Int _ = Int
substTy RInt _ = RInt
substTy Bool _ = Bool
substTy Bullet _ = Bullet
substTy (Bits sg e) _ = Bits sg e
substTy t1@(TyVar v1) (t2, v2)
    | v1 == v2 = t2
    | otherwise = t1
substTy (Mod mty mind pol) s = Mod (fmap (`substTy` s) mty) mind pol
substTy (Vector e t) s = Vector e $ substTy t s
substTy (Matrix e1 e2 t) s = Matrix e1 e2 $ substTy t s
substTy (TySyn sn t) s = TySyn sn $ substTy t s
substTy (FuncSig ts t c) s = FuncSig (map (`substTy` s) ts) (substTy t s) c
substTy (Struct sn flds) s = Struct sn $ map (second (`substTy` s)) flds
substTy (SField fn t) s = SField fn $ substTy t s
substTy (Indet t) s = Indet $ substTy t s
substTy (Tuple ts) s = Tuple $ map (`substTy` s) ts
substTy (Index vn mc t) s = Index vn mc $ substTy t s
-- Constructors not handled above. There is no 'PP' constraint here, so the
-- offending type cannot be printed; at least name the failing function.
substTy _ _ = error "substTy: unexpected type constructor"
--------------------------------------------------------------------------------
-- Waste
-- | Returns the size expression stored in a 'Bits' type, or the first
-- component of a 'Vector' type.
--
-- Any other type raises an 'error' naming this function and showing the
-- offending type.
getBVSize :: PP id => Type id -> IExpr id
getBVSize (Bits _ s) = s
getBVSize (Vector s _) = s
getBVSize t =
    error $ "getBVSize: unexpected type " ++ showPprDebug t