{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
module Ivory.Opts.Overflow
( overflowFold, addBase, subBase, mulBase, divBase, (<+>), ext
) where
import Ivory.Opts.AssertFold
import qualified Ivory.Language.Array as I
import qualified Ivory.Language.Syntax.AST as I
import qualified Ivory.Language.Syntax.Type as I
import qualified Ivory.Language.Syntax.Names as I
import qualified Ivory.Language.Type as T
import Ivory.Language
import Prelude hiding (max,min)
import Data.Word
import Data.Int
-- | Overflow-checking pass over a single procedure.
--
-- The per-expression check is 'arithAssert', lifted with 'expFoldDefault'
-- and driven by 'procFold' under the tag @"ovf"@.  Presumably 'procFold'
-- applies the check to each typed expression and returns the procedure with
-- the emitted statements inserted -- see "Ivory.Opts.AssertFold" to confirm.
overflowFold :: I.Proc -> I.Proc
overflowFold proc = procFold "ovf" checker proc
  where
    checker = expFoldDefault arithAssert
-- | An inclusive @(lower, upper)@ pair of limits for values of type @a@.
type Bounds a = (a, a)
-- | Emit overflow checks for one expression of type @ty@.
--
-- Literals are handed to 'litAssert' and operator applications to
-- 'arithAssert''.  Every other expression form emits nothing.
arithAssert :: I.Type -> I.Expr -> FolderStmt ()
arithAssert ty (I.ExpLit lit)        = litAssert ty lit
arithAssert ty (I.ExpOp op operands) = arithAssert' ty op operands
arithAssert _  _                     = return ()
-- | Range-check a literal against the type it is used at.
--
-- Only integer literals at a fixed-width word, int or index type are checked.
-- The outcome is decided here, at compile time: the emitted
-- 'I.CompilerAssert' holds the constant @true@ when the literal lies within
-- the type's inclusive bounds, and the constant @false@ otherwise.  Any other
-- literal, or any other type, emits nothing.
litAssert :: I.Type -> I.Literal -> FolderStmt ()
litAssert ty (I.LitInteger i) =
  case ty of
    I.TyWord I.Word8  -> within (minMax :: Bounds Word8)
    I.TyWord I.Word16 -> within (minMax :: Bounds Word16)
    I.TyWord I.Word32 -> within (minMax :: Bounds Word32)
    I.TyWord I.Word64 -> within (minMax :: Bounds Word64)
    I.TyInt  I.Int8   -> within (minMax :: Bounds Int8)
    I.TyInt  I.Int16  -> within (minMax :: Bounds Int16)
    I.TyInt  I.Int32  -> within (minMax :: Bounds Int32)
    I.TyInt  I.Int64  -> within (minMax :: Bounds Int64)
    -- NOTE(review): the upper bound n is inclusive here, although values of
    -- an index type are normally < n.  Confirm whether a literal equal to n
    -- is deliberately accepted (e.g. as a one-past-the-end loop bound).
    I.TyIndex n       -> within (0 :: Integer, n)
    _                 -> return ()
  where
    -- Assert (as a compile-time constant) that i lies in [lo, hi].
    within :: Integral b => Bounds b -> FolderStmt ()
    within (lo, hi) = insert (I.CompilerAssert (T.unwrapExpr verdict))
      where
        inRange = fromIntegral lo <= i && i <= fromIntegral hi
        verdict = if inRange then true else false

    -- Full representable range of a bounded Haskell type.
    minMax :: forall t . Bounded t => Bounds t
    minMax = (minBound, maxBound)
litAssert _ _ = return ()
-- | Emit an overflow check for an application of an arithmetic operator.
--
-- For addition, subtraction, multiplication, division and modulo applied at
-- a fixed-width word, int or index type, 'mkCall' emits a call to the
-- matching runtime checker followed by an assertion on its result.  Any
-- other operator, or any other type, emits nothing.
--
-- The operator table ('checkerFor') and the type table ('isCheckedType')
-- each appear exactly once, so supporting another operator or width is a
-- one-line change instead of five parallel edits.
arithAssert' :: I.Type -> I.ExpOp -> [I.Expr] -> FolderStmt ()
arithAssert' ty op args =
  case checkerFor op of
    Just base | isCheckedType ty -> mkCall base ty args
    _                            -> return ()
  where
    -- Base name of the runtime checker for each checked operator.
    checkerFor :: I.ExpOp -> Maybe String
    checkerFor o = case o of
      I.ExpAdd -> Just addBase
      I.ExpSub -> Just subBase
      I.ExpMul -> Just mulBase
      I.ExpDiv -> Just divBase
      -- Modulo has no checker of its own and reuses the division one,
      -- presumably because it can fail on the same operands as division
      -- (zero divisor, minBound by -1) -- confirm against the runtime.
      I.ExpMod -> Just divBase
      _        -> Nothing

    -- Types for which a runtime checker call is emitted.
    isCheckedType :: I.Type -> Bool
    isCheckedType t = case t of
      I.TyWord I.Word8  -> True
      I.TyWord I.Word16 -> True
      I.TyWord I.Word32 -> True
      I.TyWord I.Word64 -> True
      I.TyInt  I.Int8   -> True
      I.TyInt  I.Int16  -> True
      I.TyInt  I.Int32  -> True
      I.TyInt  I.Int64  -> True
      I.TyIndex _       -> True
      _                 -> False
-- | Call the runtime checker @f@, specialised to @ty@, and assert its result.
--
-- The callee is named @f <+> ext ty@ (for example @add_ovf_u8@).  Every
-- argument is tagged with @ty@, and the 'I.TyBool' result is bound to a fresh
-- internal variable.  A 'I.CompilerAssert' on that variable is then emitted.
mkCall :: String -> I.Type -> [I.Expr] -> FolderStmt ()
mkCall f ty args =
  freshVar >>= \name ->
    let result    = I.VarInternal name
        checker   = I.NameSym (f <+> ext ty)
        typedArgs = [ I.Typed ty arg | arg <- args ]
    in  insert (I.Call I.TyBool (Just result) checker typedArgs)
          >> insert (I.CompilerAssert (I.ExpVar result))
-- | Join two name fragments with an underscore, e.g. @"add" <+> "ovf"@ is
-- @"add_ovf"@.
(<+>) :: String -> String -> String
(<+>) lhs rhs = concat [lhs, "_", rhs]
-- | Append the @ovf@ suffix to an operation name, e.g. @mkOvf "add"@ is
-- @"add_ovf"@.
mkOvf :: String -> String
mkOvf = (<+> "ovf")
-- | Base names of the runtime overflow checkers, one per operation.
-- 'mkCall' appends a type suffix (from 'ext') to form the full callee name.
addBase, subBase, mulBase, divBase :: String
(addBase, subBase, mulBase, divBase) =
  (mkOvf "add", mkOvf "sub", mkOvf "mul", mkOvf "div")
-- | Suffix that names the type-specialised variant of a runtime checker.
--
-- 'I.TyInt' widths map to @i8@, @i16@, @i32@ and @i64@.  'I.TyWord' widths map
-- to @u8@, @u16@, @u32@ and @u64@.  'I.TyChar', 'I.TyFloat' and 'I.TyDouble'
-- map to @char@, @float@ and @double@.  An index type takes the suffix of
-- 'I.ixRep'.  Any other type calls 'error'.
ext :: I.Type -> String
ext I.TyChar         = "char"
ext I.TyFloat        = "float"
ext I.TyDouble       = "double"
ext (I.TyInt width)  = case width of
  I.Int8  -> "i8"
  I.Int16 -> "i16"
  I.Int32 -> "i32"
  I.Int64 -> "i64"
ext (I.TyWord width) = case width of
  I.Word8  -> "u8"
  I.Word16 -> "u16"
  I.Word32 -> "u32"
  I.Word64 -> "u64"
ext (I.TyIndex _)    = ext I.ixRep
ext unexpected       =
  error ("Unexpected type " ++ show unexpected ++ " in ext.")