{-# LANGUAGE FlexibleInstances, TypeFamilies #-}

-- | C back-end plugin that lowers Feldspar primitive function calls (array
-- indexing, pair construction/projection, @trace@, and the platform-specific
-- primitives listed in 'primitives') into plain C-level expressions,
-- declarations and statements.
--
-- Two passes are run over a procedure:
--
--   1. 'HandleTraceFunctions' detects calls to @trace@ and, if there are any,
--      brackets the procedure body with @traceStart@ / @traceEnd@ calls.
--   2. 'HandlePrimitives' rewrites primitive calls. Declarations and
--      statements produced while rewriting an expression travel upwards and
--      are spliced into the enclosing block / statement list.
module Feldspar.Compiler.Backend.C.Plugin.HandlePrimitives
    ( HandlePrimitives(..)
    , completeFunProcName
    ) where

import Data.List (find, foldl')
import Data.Maybe (fromJust)

import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Backend.C.CodeGeneration (typeof, defaultMemberName)
import Feldspar.Transformation
import Feldspar.Compiler.Backend.C.Options
import Feldspar.Compiler.Error

-- | Raise an internal compiler error attributed to this plugin.
handlePrimitivesError = handleError "PluginArch/HandlePrimitives" InternalError

------------------------------------------------------------------------------
-- Trace handling pass
------------------------------------------------------------------------------

-- | Pass that wraps the body of every procedure containing a @trace@ call
-- with @traceStart@ / @traceEnd@ procedure calls.
data HandleTraceFunctions = HandleTraceFunctions

-- The upward value of 'HandleTraceFunctions' is "a trace call was seen":
-- it starts as False and is or-ed together across sub-terms.
instance Default Bool where
    def = False

instance Combine Bool where
    combine x y = x || y

instance Transformation HandleTraceFunctions where
    type From HandleTraceFunctions  = ()
    type To HandleTraceFunctions    = ()
    type Down HandleTraceFunctions  = ()
    type Up HandleTraceFunctions    = Bool
    type State HandleTraceFunctions = ()

-- | Report upwards whether this expression (or any sub-expression) is a call
-- to @trace@.
instance Transformable HandleTraceFunctions Expression where
    transform t s d p = tr { up = combine isTrace (up tr) }
      where
        tr = defaultTransform t s d p
        isTrace = case p of
            FunctionCall "trace" _ _ _ _ _ -> True
            _                              -> False

-- | If a procedure contains a @trace@ call, prepend @traceStart@ and append
-- @traceEnd@ to its body; all other definitions are left as they are.
instance Transformable HandleTraceFunctions Definition where
    transform t s d p@(Procedure _ _ _ _ _ _)
        | up tr     = tr { result = traced }
        | otherwise = tr
      where
        tr      = defaultTransform t s d p
        procDef = result tr
        body    = procBody procDef
        traced  = procDef { procBody = body { blockBody = addTraceSE (blockBody body) } }
        -- Extend an existing sequence in place; otherwise wrap the single
        -- statement in a new sequence.
        addTraceSE sequ@(Sequence _ _ _) =
            sequ { sequenceProgs = [traceStart] ++ sequenceProgs sequ ++ [traceEnd] }
        addTraceSE prg = Sequence [traceStart, prg, traceEnd] () ()
        traceStart = ProcedureCall "traceStart" [] () ()
        traceEnd   = ProcedureCall "traceEnd" [] () ()
    transform t s d p = defaultTransform t s d p

------------------------------------------------------------------------------
-- Primitive handling pass
------------------------------------------------------------------------------

-- | Pass that lowers primitive function calls.
--
-- * Down:  (default array size, target platform, optional expression into
--          which a @pair@ / @trace@ result is written directly).
-- * Up:    declarations and statements that must be emitted before the
--          construct currently being transformed.
-- * State: counter used to generate fresh variable names ('makeVariable').
data HandlePrimitives = HandlePrimitives

instance Transformation HandlePrimitives where
    type From HandlePrimitives  = ()
    type To HandlePrimitives    = ()
    type Down HandlePrimitives  = (Int, Platform, Maybe (Expression ()))
    type Up HandlePrimitives    = ([Declaration ()], [Program ()])
    type State HandlePrimitives = Int

-- | Entry point. When primitive handling is disabled by the debug option the
-- procedure is returned unchanged; otherwise the trace pass runs first, then
-- the primitive pass with a fresh-name counter starting at 0.
instance Plugin HandlePrimitives where
    type ExternalInfo HandlePrimitives = (Int, DebugOption, Platform)
    executePlugin _ (_, NoPrimitiveInstructionHandling, _) procedure = procedure
    executePlugin _ (defArrSize, _, platform) procedure =
        result $ transform HandlePrimitives 0 (defArrSize, platform, Nothing)
               $ result $ transform HandleTraceFunctions () {- state -} () {- down -} procedure

instance Combine ([Declaration ()], [Program ()]) where
    combine (xl, xi) (yl, yi) = (xl ++ yl, xi ++ yi)

instance Default [Declaration ()] where
    def = []

instance Default [Program ()] where
    def = []

-- | A block absorbs the declarations bubbling up from its contents into its
-- locals. Statements must already have been consumed by the statement list,
-- so leftover upward statements are an internal error.
instance Transformable HandlePrimitives Block where
    transform t s d b = tr { result = addToBlock (result tr) (up tr), up = ([], []) }
      where
        tr = case up tr' of
            (_, []) -> tr'
            _       -> handlePrimitivesError $ "transform Block: upwards program list is not empty."
        tr' = defaultTransform t s d b

instance Transformable HandlePrimitives Program where
    -- @copy(out, in)@ becomes @copyArray@ for array-typed targets and a plain
    -- assignment otherwise. The state is threaded through both operands.
    transform t s d (ProcedureCall "copy" [o@(Out out _), i@(In inp _)] _ _) =
        case typeof out of
            ArrayType _ _ -> Result (ProcedureCall "copyArray" [out', inp'] () ()) arrS' arrU'
            _             -> Result (Assign lhs rhs () ()) assS' assU'
      where
        (Result out' arrS arrU1)  = transform t s d o
        (Result inp' arrS' arrU2) = transform t arrS d i
        arrU' = arrU1 `combine` arrU2
        (Result lhs assS assU1)   = transform t s d out
        (Result rhs assS' assU2)  = transform t assS d inp
        assU' = assU1 `combine` assU2
    -- Earlier version of the "copy" rule, kept for reference:
    -- transform t s d@(das, pfm, _) p@(ProcedureCall "copy" [Out out _, In _ _] _ _) = tr { result = makeAssignment (das, pfm) inp' out' } where
    --     tr = case out of
    --         e@(VarExpr v _) -> defaultTransform t s (das, pfm, Just e) p
    --         e@(ArrayElem _ _ _ _) -> defaultTransform t s (das, pfm, Just e) p
    --         e@(StructField _ _ _ _) -> defaultTransform t s (das, pfm, Just e) p
    --         _ -> defaultTransform t s (das, pfm, Nothing) p
    --     inp' = aToE $ head $ filter isInparam $ procCallParams $ result tr
    --     out' = aToE $ head $ filter (not . isInparam) $ procCallParams $ result tr

    -- The statements needed to evaluate the loop condition are appended to
    -- the condition-calculation block, so they are re-run before every test.
    transform t s d (SeqLoop c cc p inf1 inf2) =
        Result (SeqLoop (result tr1) cc' (result tr3) (convert inf1) (convert inf2)) (state tr3) ([], [])
      where
        tr1 = transform t s d c
        tr2 = transform t (state tr1) d cc
        tr3 = transform t (state tr2) d p
        cc' = addToBlock (result tr2) (up tr1)
    transform t s d p = defaultTransform t s d p

-- | In a statement list, the statements produced for each element are placed
-- immediately before it. Only declarations continue upwards.
instance Transformable1 HandlePrimitives [] Program where
    transform1 _ s _ [] = Result1 [] s def
    transform1 t s d (x:xs) =
        Result1 (snd (up tr1) ++ [result tr1] ++ result1 tr2)
                (state1 tr2)
                (concatMap fst [up tr1, up1 tr2], [])
      where
        tr1 = transform t s d x
        tr2 = transform1 t (state tr1) d xs

-- | If evaluating the initializer needs extra statements, the initializer is
-- removed. Those statements, followed by an assignment of the initial value,
-- are passed upwards instead.
instance Transformable HandlePrimitives Declaration where
    transform t s d@(das, pfm, _) (Declaration v i inf) =
        Result (Declaration (result tr1) i' (convert inf)) (state1 tr2) u'
      where
        tr1 = transform t s d v
        tr2 = transform1 t (state tr1) d i
        (i', u') = case up1 tr2 of
            u@(_, []) -> (result1 tr2, combine (up tr1) u)
            (ls, is)  ->
                -- Non-empty statements imply the initializer was present,
                -- so fromJust cannot fail here.
                ( Nothing
                , combine (up tr1)
                          (ls, is ++ [makeAssignment (das, pfm) (fromJust $ result1 tr2) (vToE $ result tr1)])
                )

instance Transformable HandlePrimitives Expression where
    transform t s (das, pfm, me) f@(FunctionCall nameS ot _ origInps _ _) =
        case (nameS, origInps) of
            -- Projections of a literal pair collapse to the selected component.
            ("getFst", [FunctionCall "pair" _ _ [fs, _] _ _]) -> transform t s (das, pfm, Nothing) fs
            ("getSnd", [FunctionCall "pair" _ _ [_, sn] _ _]) -> transform t s (das, pfm, Nothing) sn
            _ -> Result e' s' (combine (up tr) (l', p'))
      where
        -- Arguments are transformed first, without a target expression.
        tr = defaultTransform t s (das, pfm, Nothing) f
        s2 = state tr
        inps = funCallParams $ result tr
        d' = (das, pfm)

        -- (next counter, new declarations, new statements, replacement expr)
        (s', l', p', e') = case (nameS, inps, me) of
            ("(!)", [arr, idx], _) -> (s2, [], [], ArrayElem arr idx () ())
            ("setIx", [arr, idx, val], _) ->
                (s2, [], [makeAssignment d' val (ArrayElem arr idx () ())], arr)
            ("getFst", [l], _) -> (s2, [], [], StructField l (defaultMemberName ++ "1") () ())
            ("getSnd", [l], _) -> (s2, [], [], StructField l (defaultMemberName ++ "2") () ())
            -- A pair with a known target writes its fields straight into it.
            ("pair", [a, b], Just e) ->
                ( s2
                , []
                , [ makeAssignment d' a (StructField e (defaultMemberName ++ "1") () ())
                  , makeAssignment d' b (StructField e (defaultMemberName ++ "2") () ())
                  ]
                , e
                )
            -- Otherwise a fresh struct variable is declared and filled.
            ("pair", [a, b], Nothing) ->
                ( s3
                , [makeDeclaration stc Nothing]
                , [ makeAssignment d' a (StructField (VarExpr stc ()) (defaultMemberName ++ "1") () ())
                  , makeAssignment d' b (StructField (VarExpr stc ()) (defaultMemberName ++ "2") () ())
                  ]
                , VarExpr stc ()
                )
              where
                (s3, stc) = makeVariable ot "stc" s2
            ("trace", [lab, orig], Just e) ->
                ( s2
                , []
                , [ makeAssignment d' orig e
                  , makeProcedureCall pfm (Proc "trace" firstInFP) [e, lab] []
                  ]
                , e
                )
            ("trace", [lab, orig], Nothing) ->
                ( s3
                , [makeDeclaration trc Nothing]
                , [ makeAssignment d' orig (VarExpr trc ())
                  , makeProcedureCall pfm (Proc "trace" firstInFP) [VarExpr trc (), lab] []
                  ]
                , VarExpr trc ()
                )
              where
                (s3, trc) = makeVariable ot "trc" s2
            -- Anything else is looked up in the platform's primitive table.
            _ -> case find matchPrimitive (primitives pfm) of
                Just (fd, Right tp) -> transformPrgDesc d' s2 (tp fd inps ot)
                Just (_, Left cd)   -> transformCPrimDesc d' s2 cd inps ot
                Nothing             -> (s2, [], [], result tr)

        matchPrimitive (fd, _) = fName fd == nameS && matchTypes' (inputs fd) inps
    transform t s (das, pfm, _) p = defaultTransform t s (das, pfm, Nothing) p

------------------------------------------------------------------------------
-- Helpers
------------------------------------------------------------------------------

-- | Append declarations to a block's locals and statements to the end of its
-- body. A non-sequence body is wrapped in a new sequence.
addToBlock :: Block () -> ([Declaration ()], [Program ()]) -> Block ()
addToBlock b (ls, is) = b { locals = locals b ++ ls, blockBody = extend (blockBody b) }
  where
    extend (Sequence ps () ()) = Sequence (ps ++ is) () ()
    extend prg                 = Sequence (prg : is) () ()

-- | Lower a C primitive description applied to the given arguments.
-- Operators, functions, casts and plain assignments become a single
-- expression. Any other description (or an arity mismatch) becomes a
-- procedure call writing into a fresh variable named @vhp<n>@.
transformCPrimDesc :: (Int, Platform) -> Int -> CPrimDesc -> [Expression ()] -> Type
                   -> (Int, [Declaration ()], [Program ()], Expression ())
transformCPrimDesc (_, pfm) serial cd inps ot = case (cd, inps) of
    (Op1 op, [_])    -> exprOnly (FunctionCall op ot PrefixOp inps () ())
    (Op2 op, [_, _]) -> exprOnly (FunctionCall op ot InfixOp inps () ())
    (Fun _ _, _)     -> exprOnly (FunctionCall (completeFunProcName pfm cd (map typeof inps) [ot]) ot SimpleFun inps () ())
    (Cas, [x])       -> exprOnly (Cast ot x () ())
    (Assig, [x])     -> exprOnly x
    _                -> (serial', [makeDeclaration ov Nothing], [makeProcedureCall pfm cd inps [vToE ov]], vToE ov)
  where
    exprOnly e    = (serial, [], [], e)
    (serial', ov) = makeVariable ot "vhp" serial

-- | Expand a program-description primitive into: the next fresh-name
-- counter, declarations for the variables it creates, the statements of its
-- body, and the expression holding its result.
transformPrgDesc :: (Int, Platform) -> Int -> PrgDesc -> (Int, [Declaration ()], [Program ()], Expression ())
transformPrgDesc down@(_, pfm) serial (PrgDesc crts lns rgt) =
    (serial', map (\(_, _, v, me) -> makeDeclaration v me) vars, ins, transformRgt vars rgt)
  where
    -- Each entry: (label, has-a-value flag, C variable, initializer).
    (serial', vars') = foldl' transformCrtFold (serial, []) (map searchDuplicateLabels crts)
    (vars, ins)      = foldl' transformLineFold (vars', []) lns

    -- Two creations of the same label would make label lookups ambiguous
    -- (the first would always win), so only the label is compared.
    searchDuplicateLabels c@(Crt _ v _)
        | length (filter (\(Crt _ v' _) -> v' == v) crts) > 1 =
            handlePrimitivesError $ "multiple declaration: " ++ show c
        | otherwise = c

    transformCrtFold (n, vs) (Crt ty v@(Var s) (Just r)) = (n', vs ++ [(v, True, mv, Just $ transformRgt vs r)])
      where
        (n', mv) = makeVariable ty s n
    transformCrtFold (n, vs) (Crt ty v@(Var s) Nothing) = (n', vs ++ [(v, False, mv, Nothing)])
      where
        (n', mv) = makeVariable ty s n

    -- Translate one body line, marking its written labels as having a value.
    transformLineFold (vs, is) ln = case ln of
        Asg v r ->
            (updateVars [v], is ++ [makeAssignment down (transformRgt' r) (transformVarL' v)])
        Prc cd inps outs ->
            (updateVars outs, is ++ [makeProcedureCall pfm cd (map transformRgt' inps) (map transformVarL' outs)])
      where
        updateVars xs  = map (\y@(v', _, vv, mr) -> if v' `elem` xs then (v', True, vv, mr) else y) vs
        transformRgt'  = transformRgt vs
        transformVarL' = vToE . transformVarL vs

    transformRgt _  (Exp e)          = e
    transformRgt vs (Fnc cd rgts ot) = makeFunctionCallOrCast down cd (map (transformRgt vs) rgts) ot
    transformRgt vs (VarR v)         = vToE $ transformVarR vs v

    transformVarL vs v@(Var _) = case find (\(v', _, _, _) -> v' == v) vs of
        Just (_, _, vv, _) -> vv
        Nothing            -> handlePrimitivesError $ "Not declared: " ++ show v

    -- Reads deliberately do not check that the variable already has a value
    -- (quick bugfix for the pair / set_pair macros), so they resolve exactly
    -- like write targets. The disabled check was:
    --   Just _ -> handlePrimitivesError $ "The variable hasn't got value yet: " ++ show v
    transformVarR = transformVarL

-- | Lower a C primitive that must yield a single expression; it is an
-- internal error if it would need declarations or statements.
makeFunctionCallOrCast :: (Int, Platform) -> CPrimDesc -> [Expression ()] -> Type -> Expression ()
makeFunctionCallOrCast down cd inps ot = case transformCPrimDesc down (-1) cd inps ot of
    (_, [], [], ed) -> ed
    _ -> handlePrimitivesError $
            "it's not a FunctionCall: " ++ show cd ++ ", number of inputs: " ++ show (length inps)

-- | Create a fresh value variable named @s ++ show n@; returns the next counter.
makeVariable :: Type -> String -> Int -> (Int, Variable ())
makeVariable t s n = (n + 1, Variable (s ++ show n) t Value ())

makeDeclaration :: Variable () -> Maybe (Expression ()) -> Declaration ()
makeDeclaration v me = Declaration v me ()

-- | Build the statement @out := inp@.
--
-- * Self-assignment (same variable, or same array element) is an empty
--   statement.
-- * Array-typed values are copied with @copyArray@.
-- * Anything else becomes a plain assignment.
makeAssignment :: (Int, Platform) -> Expression () -> Expression () -> Program ()
makeAssignment _ inp out
    | sameVariable inp out        = Empty () ()
    | ArrayType _ _ <- typeof inp = ProcedureCall "copyArray" [eToOut out, eToIn inp] () ()
    -- Former explicit-size computation for copyArray, kept for reference:
    --   size = prod_const (arraySize (typeof out) defArrSize) (SizeOf (Left $ baseType t) () ())
    --   baseType (ArrayType _ t) = baseType t
    --   baseType t = t
    | otherwise                   = Assign out inp () ()
  where
    sameVariable (VarExpr v1 _) (VarExpr v2 _)               = v1 == v2
    sameVariable (ArrayElem a1 i1 _ _) (ArrayElem a2 i2 _ _) = a1 == a2 && i1 == i2
    sameVariable _ _                                         = False

-- | Build a call to a 'Proc' primitive. Its C name is completed from the
-- argument types; inputs are passed before outputs.
makeProcedureCall :: Platform -> CPrimDesc -> [Expression ()] -> [Expression ()] -> Program ()
makeProcedureCall pfm cd@(Proc _ _) inps outs =
    ProcedureCall (completeFunProcName pfm cd (map typeof inps) (map typeof outs))
                  (map eToIn inps ++ map eToOut outs) () ()
makeProcedureCall _ cd _ _ =
    handlePrimitivesError $ "Wrong C primitive description in makeProcedureCall:\n" ++ show cd

-- | True when both lists have the same length and each argument's type
-- matches the corresponding type description.
matchTypes' :: [TypeDesc] -> [Expression ()] -> Bool
matchTypes' []       []       = True
matchTypes' (x : xs) (y : ys) = machTypes x (typeof y) && matchTypes' xs ys
matchTypes' _        _        = False

-- | Complete the C name of a function / procedure primitive.
--
-- Unless its name-postfix setting is 'noneFP', the name gets:
--
-- * @_fun@ for functions, and
-- * one @_<type>@ suffix per selected input and output type.
completeFunProcName :: Platform -> CPrimDesc -> [Type] -> [Type] -> String
completeFunProcName pfm desc its ots
    | funPf desc == noneFP = cName desc
    | otherwise            = cName desc ++ ifFun ++ apsToName
  where
    ifFun = case desc of
        Fun _ _  -> "_fun"
        Proc _ _ -> ""
    apsToName     = concatMap (("_" ++) . toFunName pfm) apsToNameList
    apsToNameList = take (useInputs $ funPf desc) its ++ take (useOutputs $ funPf desc) ots

-- | Name fragment for a type.
--
-- * Nested arrays collapse to a single @arrayOf_@ prefix.
-- * Other types use the platform's type name, with spaces replaced by
--   underscores.
toFunName :: Platform -> Type -> String
toFunName pfm (ArrayType _ t@(ArrayType _ _)) = toFunName pfm t
toFunName pfm (ArrayType _ t) = "arrayOf_" ++ toFunName pfm t
toFunName pfm t = case find (\(t', _, _) -> t == t') (types pfm) of
    Just (_, _, s) -> map (\c -> if c == ' ' then '_' else c) s
    Nothing        -> handlePrimitivesError $ "Unhandled type in platform " ++ name pfm

-- Disabled array-size computation, kept for reference:
-- arraySize :: Type -> Int -> Expression ()
-- arraySize a@(ArrayType _ t) defaultArraySize = toExp $ arraySize' a
--   where
--     arraySize' :: Type -> (Int,Int)
--     arraySize' (ArrayType (LiteralLen n) t) = (n * fst at, snd at) where
--         at = arraySize' t
--     arraySize' (ArrayType UndefinedLen t) = (fst at, 1 + snd at) where
--         at = arraySize' t
--     arraySize' _ = (1,0)
--     toExp :: (Int,Int) -> Expression ()
--     toExp (c, 0) = intToCe $ toInteger c
--     toExp (c, i) = prod_const (toExp (c, i-1)) (vToE $ Variable defaultArraySizeConstantName (NumType Unsigned S32) Value ())

-- | Unsigned 32-bit multiplication expression @a * b@.
prod_const a b = FunctionCall "*" (NumType Unsigned S32) InfixOp [a, b] () ()

isInparam (In _ _)  = True
isInparam (Out _ _) = False

-- | Strip the direction from an actual parameter.
aToE (In x ())  = x
aToE (Out x ()) = x

eToIn x  = In x ()
eToOut x = Out x ()

-- ceToInt (Expression (ConstantExpression (Constant (IntConstant (IntConstantType x _)) _)) _) = x
intToCe x = ConstExpr (IntConst x () ()) ()

vToE v = VarExpr v ()