{-
CAO Compiler
Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE PatternGuards #-}

{- |
Module      :  $Header$
Description :  CAO type representation utilities
Copyright   :  (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho
License     :  GPL

Maintainer  :  Paulo Silva
Stability   :  experimental
Portability :  non-portable

This module defines utility functions for the CAO type representation:
predicates over 'Type' values, partial accessors for specific type
constructors, and type-variable substitution.
-}

module Language.CAO.Type.Utils
  ( isAlgebraic
  , isBool
  , isInt
  , isBits
  , isVar
  , isNil
  , isFunType
  , isProc
  , isTySyn
  , isStruct
  , isIndet
  , isMod
  , isModInt
  , isModPol
  , isVector
  , isMatrix
  , isIntExt
  , isRInt
  , isSimpleType
  , isContainer
  , sfType
  , retType
  , synType
  , fieldType
  , funClass
  , getModulusBase
  , extractBaseType
  , extractBottomBaseType
  , getStructName
  , wVars
  , getPoly
  , maximumClass
  , fromTuple
  , toTuple
  , innerType
  , substTy
  , getBVSize
  , isDependent
  ) where

import Control.Arrow (second)
import Data.List ( foldl' )
import Data.Set ( Set )
import qualified Data.Set as Set

import Language.CAO.Common.Outputable
import Language.CAO.Common.Polynomial
import Language.CAO.Index
import Language.CAO.Type

--------------------------------------------------------------------------------
-- * Predicates over type representations

-- | Algebraic types: 'Int', 'RInt', any 'Mod' type, and matrices whose
-- element type is algebraic. 'Index' wrappers are looked through.
isAlgebraic :: Type id -> Bool
isAlgebraic (Index _ _ t) = isAlgebraic t
isAlgebraic Int = True
isAlgebraic RInt = True
isAlgebraic (Mod _ _ _) = True
isAlgebraic (Matrix _ _ t) = isAlgebraic t
isAlgebraic _ = False

-- | 'Bool', possibly wrapped in 'Index'.
isBool :: Type id -> Bool
isBool Bool = True
isBool (Index _ _ n) = isBool n
isBool _ = False

-- | 'Int' or 'RInt', possibly wrapped in 'Index'.
isInt :: Type id -> Bool
isInt Int = True
isInt RInt = True
isInt (Index _ _ n) = isInt n
isInt _ = False

-- | 'Int' or a 'Bits' type, looking through 'Index' wrappers.
-- Note that, unlike 'isInt', this does not accept 'RInt'.
isIntExt :: Type id -> Bool
isIntExt (Index _ _ t) = isIntExt t
isIntExt Int = True
isIntExt t = isBits t

-- | 'RInt', possibly wrapped in 'Index'.
isRInt :: Type id -> Bool
isRInt (Index _ _ t) = isRInt t
isRInt RInt = True
isRInt _ = False

-- | A 'Bits' type (no 'Index' unwrapping).
isBits :: Type id -> Bool
isBits (Bits _ _) = True
isBits _ = False

-- | Every type except function signatures, struct-field types,
-- indeterminates, tuples and 'Bullet'.
isVar :: Type id -> Bool
isVar (FuncSig _ _ _) = False
isVar (SField _ _) = False
isVar (Indet _) = False
isVar (Tuple _) = False
isVar Bullet = False
isVar _ = True

-- | 'Bullet' or the empty tuple.
isNil :: Type id -> Bool
isNil Bullet = True
isNil (Tuple []) = True
isNil _ = False

-- | A function signature of any class.
isFunType :: Type id -> Bool
isFunType (FuncSig _ _ _) = True
isFunType _ = False

-- | A function signature whose class is 'Proc'.
isProc :: Type id -> Bool
isProc (FuncSig _ _ (Proc _)) = True
isProc _ = False

-- | A type synonym.
isTySyn :: Type id -> Bool
isTySyn (TySyn _ _) = True
isTySyn _ = False

-- | A struct type.
isStruct :: Type id -> Bool
isStruct (Struct _ _) = True
isStruct _ = False

-- | An indeterminate type.
isIndet :: Type id -> Bool
isIndet (Indet _) = True
isIndet _ = False

-- | Any modular type.
isMod :: Type id -> Bool
isMod (Mod _ _ _) = True
isMod _ = False

-- | A modular type with neither a base type nor a second 'Mod' component
-- (i.e. @Mod Nothing Nothing _@).
isModInt :: Type id -> Bool
isModInt (Mod Nothing Nothing _) = True
isModInt _ = False

-- | A modular type built over a base type (@Mod (Just _) _ _@).
isModPol :: Type id -> Bool
isModPol (Mod (Just _) _ _) = True
isModPol _ = False

-- | A vector type.
isVector :: Type id -> Bool
isVector (Vector _ _) = True
isVector _ = False

-- | A matrix type.
isMatrix :: Type id -> Bool
isMatrix (Matrix _ _ _) = True
isMatrix _ = False

-- | 'Int', 'RInt', 'Bool', 'Bits' or 'Mod' (no 'Index' unwrapping).
isSimpleType :: Type id -> Bool
isSimpleType Int = True
isSimpleType RInt = True
isSimpleType Bool = True
isSimpleType (Bits _ _) = True
isSimpleType (Mod _ _ _) = True
isSimpleType _ = False

-- | Vectors, matrices and structs.
isContainer :: Type id -> Bool
isContainer (Vector {}) = True
isContainer (Matrix {}) = True
isContainer (Struct {}) = True
isContainer _ = False

-- | Is a data type that may have dependencies: everything except the
-- plain 'Int', 'RInt' and 'Bool' types.
isDependent :: Type a -> Bool
isDependent t = case t of
  Int -> False
  RInt -> False
  Bool -> False
  _ -> True

--------------------------------------------------------------------------------
-- * Partial accessors
--
-- Each of these expects one specific constructor and calls 'error', naming
-- the function, on any other input.

-- | Name of a struct type.
getStructName :: PP id => Type id -> id
getStructName (Struct v _) = v
getStructName t = error $ "getStructName: unexpected type " ++ showPprDebug t

-- | Class of a function signature.
funClass :: PP id => Type id -> Class id
funClass (FuncSig _ _ c) = c
funClass f = error $ "funClass: unexpected type " ++ showPprDebug f

-- | Return type of a function signature.
retType :: PP id => Type id -> Type id
retType (FuncSig _ t _) = t
retType f = error $ "retType: unexpected type " ++ showPprDebug f

-- | Right-hand side of a type synonym.
synType :: PP id => Type id -> Type id
synType (TySyn _ t) = t
synType t = error $ "synType: unexpected type " ++ showPprDebug t

-- | Components of a tuple; any non-tuple type is a one-element list.
fromTuple :: Type id -> [Type id]
fromTuple (Tuple t) = t
fromTuple t = [t]

-- | Build a tuple type; a singleton list yields its only element unwrapped.
toTuple :: [Type id] -> Type id
toTuple [t] = t
toTuple t = Tuple t

-- | Polynomial component of a modular type.
getPoly :: PP id => Type id -> Pol id
getPoly (Mod _ _ p) = p
getPoly t = error $ "getPoly: unexpected type " ++ showPprDebug t

-- | Immediate base type of a modular type built over another type.
extractBaseType :: PP id => Type id -> Type id
extractBaseType (Mod (Just t) _ _) = t
extractBaseType t = error $ "extractBaseType: unexpected type " ++ showPprDebug t

-- | Follow the chain of 'Mod' base types down to the innermost
-- @Mod Nothing Nothing _@ and return it.
extractBottomBaseType :: Type id -> Type id
extractBottomBaseType m@(Mod Nothing Nothing _) = m
extractBottomBaseType (Mod (Just t) _ _) = extractBottomBaseType t
extractBottomBaseType _ = error "extractBottomBaseType: not a Mod"

-- | Integer modulus at the bottom of a (possibly nested) modular type.
--
-- The bottom type (see 'extractBottomBaseType') must have a polynomial
-- consisting of a single constant monomial; its coefficient is returned.
getModulusBase :: Type id -> IExpr id
getModulusBase m@(Mod {}) =
  -- Inspect the bottom type exactly once. 'extractBottomBaseType' is the
  -- identity on a bottom 'Mod', so recursing on its result would never
  -- terminate when the polynomial is not a constant.
  case extractBottomBaseType m of
    Mod Nothing Nothing (Pol [Mon (CoefI c) EZero]) -> c
    _ -> error "getModulusBase: modulus is not a constant integer"
getModulusBase _ = error "getModulusBase: not a Mod"

-- | Combine the classes of several functions.
--
-- An empty list gives 'Pure'. If the union of the variable lists of all
-- 'Proc' classes (see 'wVars') is non-empty, the result is 'Proc' over
-- that union (in 'Set.toList' order). Otherwise it is the 'maximum' class.
maximumClass :: Ord id => [Class id] -> Class id
maximumClass [] = Pure
maximumClass cls
  | lst <- wVars cls, not (Set.null lst) = Proc $ Set.toList lst
  | otherwise = maximum cls

-- | Union of the variable lists carried by all 'Proc' classes; other
-- classes contribute nothing.
wVars :: Ord id => [Class id] -> Set id
wVars = foldl' goVs Set.empty
  where
    goVs acc (Proc wvs) = Set.union (Set.fromList wvs) acc
    goVs acc _ = acc

-- | Type stored in a struct-field type.
sfType :: PP id => Type id -> Type id
sfType (SField _ rt) = rt
sfType t = error $ "sfType: unexpected type " ++ showPprDebug t

-- | Type of the named field of a struct, looking through type synonyms.
fieldType :: (PP id, Eq id) => id -> Type id -> Type id
fieldType fi (TySyn _ ty) = fieldType fi ty
fieldType fi (Struct n flds)
  | Just ty' <- lookup fi flds = ty'
  | otherwise = error $ "fieldType: unknown field " ++ showPprDebug fi
                     ++ " of struct " ++ showPprDebug n
fieldType _ ty = error $ "fieldType: unexpected type " ++ showPprDebug ty

-- | Element types of a container: the element type of a vector or matrix,
-- or the field types of a struct (in declaration order).
innerType :: PP a => Type a -> [Type a]
innerType t = case t of
  Vector _ t' -> [t']
  Matrix _ _ t' -> [t']
  Struct _ flds -> map snd flds
  _ -> error $ "innerType: unexpected case for type: " ++ showPpr t

-- | @substTy ty (t, v)@ replaces every occurrence of type variable @v@ in
-- @ty@ by @t@, recursing into component types. Index expressions, sizes,
-- polynomials and names are left untouched.
substTy :: Type a -> (Type a, TyVarId) -> Type a
substTy Int _ = Int
substTy RInt _ = RInt
substTy Bool _ = Bool
substTy Bullet _ = Bullet
substTy (Bits sg e) _ = Bits sg e
substTy t1@(TyVar v1) (t2, v2)
  | v1 == v2 = t2
  | otherwise = t1
substTy (Mod mty mind pol) s = Mod (fmap (`substTy` s) mty) mind pol
substTy (Vector e t) s = Vector e $ substTy t s
substTy (Matrix e1 e2 t) s = Matrix e1 e2 $ substTy t s
substTy (TySyn sn t) s = TySyn sn $ substTy t s
substTy (FuncSig ts t c) s =
  FuncSig (map (`substTy` s) ts) (substTy t s) c
substTy (Struct sn flds) s = Struct sn $ map (second (`substTy` s)) flds
substTy (SField fn t) s = SField fn $ substTy t s
substTy (Indet t) s = Indet $ substTy t s
substTy (Tuple ts) s = Tuple $ map (`substTy` s) ts
substTy (Index vn mc t) s = Index vn mc $ substTy t s
substTy _ _ = error "substTy: unexpected type constructor"

--------------------------------------------------------------------------------
-- Waste

-- | Size expression of a bits type or the length expression of a vector.
getBVSize :: PP id => Type id -> IExpr id
getBVSize (Bits _ s) = s
getBVSize (Vector s _) = s
getBVSize t = error $ "getBVSize: unexpected type " ++ showPprDebug t