module Feldspar.Compiler.Optimization.PrimitiveInstructions where
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Options
import Data.Map hiding (filter,map)
-- | Program fragments in which primitive instructions can be lowered with
-- 'transformPrimitive'. Every instance is a structural traversal: it rebuilds
-- the fragment unchanged except that each 'Primitive' it reaches is rewritten.
class HandlePrimitives t where
    handlePrimitives :: Options -> t -> t

-- | Lists are processed element by element.
instance (HandlePrimitives a) => HandlePrimitives [a] where
    handlePrimitives opts xs = [handlePrimitives opts x | x <- xs]

-- | Only the body of a function is processed; its name and its input and
-- output lists are kept as they are.
instance HandlePrimitives ImpFunction where
    handlePrimitives opts (Fun name inputs outputs body) =
        Fun name inputs outputs (handlePrimitives opts body)

-- | The first component is kept; the second is processed.
instance HandlePrimitives CompleteProgram where
    handlePrimitives opts (CompPrg defs body) =
        CompPrg defs (handlePrimitives opts body)

-- | A 'Primitive' is lowered with 'transformPrimitive'. Every other node is
-- rebuilt with its sub-programs processed and its semantic info untouched.
-- Children go through the polymorphic 'handlePrimitives' because they are
-- not necessarily of type 'Program'.
instance HandlePrimitives Program where
    handlePrimitives _ Empty = Empty
    handlePrimitives opts (Primitive instruction inf) =
        transformPrimitive opts instruction inf
    handlePrimitives opts (Seq parts inf) =
        Seq [handlePrimitives opts part | part <- parts] inf
    handlePrimitives opts (IfThenElse cond thenPart elsePart inf) =
        IfThenElse cond (handlePrimitives opts thenPart) (handlePrimitives opts elsePart) inf
    handlePrimitives opts (SeqLoop cond condCalc body inf) =
        SeqLoop cond (handlePrimitives opts condCalc) (handlePrimitives opts body) inf
    handlePrimitives opts (ParLoop counter bound step body inf) =
        ParLoop counter bound step (handlePrimitives opts body) inf
-- | Lower a single primitive instruction.
--
-- A call ('CFun') to one of the built-in operators or functions recognised
-- below is handled in one of two ways:
--
-- * it becomes an inline C expression assigned to the output, when the
--   type is simple;
-- * otherwise it becomes a call to a C function whose name is suffixed with
--   'toFunName' of the operand type (see 'op1', 'op2', 'fun1', 'fun2' and
--   the @"copy"@ case).
--
-- Every generated instruction gets a 'varMap' built from the variables it
-- reads and writes. Unrecognised calls are kept as they are, with a
-- recomputed 'varMap'. Instructions that are not calls are returned
-- unchanged.
transformPrimitive :: Options -> Instruction -> SemInfPrim -> Program
-- Comparison operators.
transformPrimitive opts (CFun "(==)" [In in1, In in2, Out out]) s = op2 in1 in2 out "==" "equal" s
transformPrimitive opts (CFun "(/=)" [In in1, In in2, Out out]) s = op2 in1 in2 out "!=" "not_equal" s
transformPrimitive opts (CFun "(<)" [In in1, In in2, Out out]) s = op2 in1 in2 out "<" "less" s
transformPrimitive opts (CFun "(>)" [In in1, In in2, Out out]) s = op2 in1 in2 out ">" "greater" s
transformPrimitive opts (CFun "(<=)" [In in1, In in2, Out out]) s = op2 in1 in2 out "<=" "less_equal" s
transformPrimitive opts (CFun "(>=)" [In in1, In in2, Out out]) s = op2 in1 in2 out ">=" "greater_equal" s
-- Boolean operators.
transformPrimitive opts (CFun "not" [In in1, Out out]) s = op1 in1 out "!" "not" s
transformPrimitive opts (CFun "(&&)" [In in1, In in2, Out out]) s = op2 in1 in2 out "&&" "and" s
transformPrimitive opts (CFun "(||)" [In in1, In in2, Out out]) s = op2 in1 in2 out "||" "or" s
-- Arithmetic.
-- NOTE(review): Haskell's 'div' rounds towards negative infinity, while C's
-- integer '/' truncates towards zero. The two disagree when the operands have
-- opposite signs (div (-7) 2 == -4, but -7 / 2 == -3 in C). Confirm that this
-- mapping is intended.
transformPrimitive opts (CFun "div" [In in1, In in2, Out out]) s = op2 in1 in2 out "/" "divide" s
transformPrimitive opts (CFun "(^)" [In in1, In in2, Out out]) s = fun2 in1 in2 out "pow" s
transformPrimitive opts (CFun "abs" [In in1, Out out]) s = fun1 in1 out "abs" s
transformPrimitive opts (CFun "signum" [In in1, Out out]) s = fun1 in1 out "signum" s
transformPrimitive opts (CFun "(+)" [In in1, In in2, Out out]) s = op2 in1 in2 out "+" "add" s
transformPrimitive opts (CFun "(-)" [In in1, In in2, Out out]) s = op2 in1 in2 out "-" "sub" s
transformPrimitive opts (CFun "(*)" [In in1, In in2, Out out]) s = op2 in1 in2 out "*" "mult" s
transformPrimitive opts (CFun "(/)" [In in1, In in2, Out out]) s = op2 in1 in2 out "/" "divide" s
-- Array indexing. The element read is stored through the "copy" lowering
-- below, so the output kind is honoured the same way as for an explicit copy:
-- * 'PointedVal' target for a simple 'OutKind' output;
-- * a copy_* call for a non-simple element;
-- * a plain assignment otherwise.
transformPrimitive opts (CFun "(!)" [In arr, In idx, Out (kind,out)]) s
    = transformPrimitive opts (CFun "copy" [In element, Out (kind, out)]) s
  where
    element = Expr (LeftExpr $ ArrayElem (toLeftValue arr) idx) $ exprType out
-- Functional array update:
-- 1. copy the original array into the result;
-- 2. store the value into result[index].
transformPrimitive opts (CFun "setIx" [In original, In index, In value, Out (kind,result)]) s
    = Seq
        [ mkCopy opts (original,(kind, result)) s
        , mkCopy opts (value, (Normal,Expr (LeftExpr $ ArrayElem (toLeftValue result) index) $ exprType value)) s
        ] []
-- Copy of a single value.
transformPrimitive opts (CFun "copy" [In in1, Out (kind,out)]) s
    | simpleType (exprType in1) && kind == Normal
    = Primitive (Assign (toLeftValue out) $ in1) semInf
    | simpleType (exprType in1) && kind == OutKind
    = Primitive (Assign (PointedVal $ toLeftValue out) $ in1) semInf
    | otherwise = Primitive (CFun ("copy" ++ "_" ++ toFunName (exprType in1)) ([In in1] ++ arrayDim (exprType in1) ++ [Out (kind,out)])) semInf
  where
    semInf = s{ varMap = vm }
    vm = addVarMap (leftVarMap (toLeftValue out) $ Just in1) (rightVarMap in1)
    -- Each array dimension is passed as an extra signed 32-bit argument,
    -- outermost first.
    -- NOTE(review): an unknown dimension is passed as the constant 16, which
    -- looks like an arbitrary placeholder; confirm.
    arrayDim (ImpArrayType (Just n) t) = In (Expr (ConstExpr (IntConst n)) (Numeric ImpSigned S32)) : arrayDim t
    arrayDim (ImpArrayType Nothing t) = In (Expr (ConstExpr (IntConst 16)) (Numeric ImpSigned S32)) : arrayDim t
    arrayDim _ = []
-- Multi-argument copy: the i-th input is copied to the i-th output.
transformPrimitive opts c@(CFun "copy" pars) s
    | length ins /= length outs = error $ "Error: invalid arguments to 'copy':\n" ++ toC 0 c
    | otherwise = Seq (map genTwoParamCopy $ zip ins outs) []
  where
    ins = filter inparam pars
    outs = filter (not . inparam) pars
    genTwoParamCopy (i,o) = transformPrimitive opts (CFun "copy" [i,o]) s
-- Any other call is kept; only its variable map is recomputed. Its inputs are
-- reads, and its outputs are writes with no known right-hand side.
-- 'Prelude.foldr' is qualified because the file imports Data.Map unqualified,
-- and newer containers versions also export a 'foldr' from Data.Map.
transformPrimitive _ i@(CFun _ pars) s = Primitive i semInf
  where
    semInf = s{ varMap = Prelude.foldr addVarMap Data.Map.empty mapList }
    mapList
        = map rightVarMap (inParams pars)
        ++ map (\out -> leftVarMap (toLeftValue out) Nothing) (map snd $ outParams pars)
-- Instructions that are not calls (e.g. an 'Assign') need no lowering.
transformPrimitive _ instr s = Primitive instr s
-- | Lower a unary built-in function such as @abs@ or @signum@.
--
-- When the result type is simple, the output is assigned the C call
-- @<cFunName>_fun_<type of in1>(in1)@. For an 'OutKind' output the target is
-- wrapped in 'PointedVal'. Otherwise the primitive becomes a call to
-- @<cFunName>_<type of in1>@ taking the input and the output.
fun1 in1 (kind,out) cFunName s
    | scalar && kind == Normal  = assign dest
    | scalar && kind == OutKind = assign (PointedVal dest)
    | otherwise = Primitive (CFun (cFunName ++ "_" ++ typeTag) [In in1, Out (kind,out)]) semInf
  where
    scalar = simpleType (exprType out)
    typeTag = toFunName (exprType in1)
    dest = toLeftValue out
    call = Expr (FunCall SimpleFun (cFunName ++ "_fun_" ++ typeTag) [in1]) (exprType out)
    assign lhs = Primitive (Assign lhs call) semInf
    semInf = s{ varMap = addVarMap (leftVarMap dest $ Just call) (rightVarMap call) }
-- | Lower a binary built-in function such as @(^)@.
--
-- When the result type is simple, the output is assigned the C call
-- @<cFunName>_fun_<type of in1>(in1, in2)@. For an 'OutKind' output the
-- target is wrapped in 'PointedVal'. Otherwise the primitive becomes a call to
-- @<cFunName>_<type of in1>@ taking both inputs and the output.
fun2 in1 in2 (kind,out) cFunName s
    | isSimple && kind == Normal  = storeInto lhs
    | isSimple && kind == OutKind = storeInto (PointedVal lhs)
    | otherwise = Primitive (CFun (cFunName ++ "_" ++ suffix) [In in1, In in2, Out (kind,out)]) semInf
  where
    isSimple = simpleType (exprType out)
    suffix = toFunName (exprType in1)
    lhs = toLeftValue out
    rhs = Expr (FunCall SimpleFun (cFunName ++ "_fun_" ++ suffix) [in1, in2]) (exprType out)
    storeInto target = Primitive (Assign target rhs) semInf
    semInf = s{ varMap = addVarMap (leftVarMap lhs $ Just rhs) (rightVarMap rhs) }
-- | Lower a unary built-in operator such as @not@.
--
-- When the result type is simple, the operator is emitted inline as the C
-- prefix expression @cOpName in1@ and assigned to the output. For an 'OutKind'
-- output the target is wrapped in 'PointedVal'. Otherwise the primitive
-- becomes a call to @<cFunName>_<type of in1>@.
op1 in1 (kind,out) cOpName cFunName s
    | simpleResult && kind == Normal  = assignTo target
    | simpleResult && kind == OutKind = assignTo (PointedVal target)
    | otherwise = Primitive (CFun specialised [In in1, Out (kind,out)]) semInf
  where
    resultType = exprType out
    simpleResult = simpleType resultType
    target = toLeftValue out
    specialised = cFunName ++ "_" ++ toFunName (exprType in1)
    prefixExpr = Expr (FunCall PrefixOp cOpName [in1]) resultType
    assignTo lhs = Primitive (Assign lhs prefixExpr) semInf
    semInf = s{ varMap = addVarMap (leftVarMap target $ Just prefixExpr) (rightVarMap prefixExpr) }
-- | Lower a binary built-in operator such as @(+)@ or @(==)@.
--
-- When the result type is simple, the operator is emitted inline as the C
-- infix expression @in1 cOpName in2@ and assigned to the output. For an
-- 'OutKind' output the target is wrapped in 'PointedVal'. Otherwise the
-- primitive becomes a call to @<cFunName>_<type of in1>@ taking both inputs
-- and the output.
op2 in1 in2 (kind,out) cOpName cFunName s
    | simpleResult && kind == Normal  = assignTo target
    | simpleResult && kind == OutKind = assignTo (PointedVal target)
    | otherwise = Primitive (CFun specialised [In in1, In in2, Out (kind,out)]) semInf
  where
    resultType = exprType out
    simpleResult = simpleType resultType
    target = toLeftValue out
    specialised = cFunName ++ "_" ++ toFunName (exprType in1)
    infixExpr = Expr (FunCall InfixOp cOpName [in1,in2]) resultType
    assignTo lhs = Primitive (Assign lhs infixExpr) semInf
    semInf = s{ varMap = addVarMap (leftVarMap target $ Just infixExpr) (rightVarMap infixExpr) }
-- | The payloads of all 'In' parameters, in order.
inParams ps = [x | In x <- ps]
-- | The payloads of all 'Out' parameters, in order.
outParams ps = [x | Out x <- ps]
-- | 'True' for an 'In' parameter, 'False' for an 'Out' parameter.
inparam (In _) = True
inparam (Out _) = False
-- | Build and lower a two-argument @"copy"@ call from @src@ to @dst@.
mkCopy opt (src, dst) s = transformPrimitive opt copyCall s
  where
    copyCall = CFun "copy" [In src, Out dst]
-- | Render a type as the suffix used to name type-specialised C functions.
-- Array and pointer types prefix the name of their element or target type.
toFunName :: Type -> String
toFunName ty = case ty of
    BoolType -> "bool"
    FloatType -> "float"
    Numeric signedness width -> concat [toC 0 signedness, "_", toC 0 width]
    ImpArrayType _ elemTy -> "arrayOf_" ++ toFunName elemTy
    Feldspar.Compiler.Imperative.Representation.Pointer target -> "pointerTo_" ++ toFunName target