{-
 - Copyright (c) 2009, ERICSSON AB All rights reserved.
 - 
 - Redistribution and use in source and binary forms, with or without
 - modification, are permitted provided that the following conditions
 - are met:
 - 
 -     * Redistributions of source code must retain the above copyright
 -     notice,
 -       this list of conditions and the following disclaimer.
 -     * Redistributions in binary form must reproduce the above copyright
 -       notice, this list of conditions and the following disclaimer
 -       in the documentation and/or other materials provided with the
 -       distribution.
 -     * Neither the name of the ERICSSON AB nor the names of its
 -     contributors
 -       may be used to endorse or promote products derived from this
 -       software without specific prior written permission.
 - 
 - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 -}

module Feldspar.Compiler.Optimization.PrimitiveInstructions where

import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Options
import Data.Map hiding (filter,map)

-- Implements the mapping from high-level DSL primitives
-- to low-level primitive instructions.

-- | Structures that contain primitive instructions which this pass can
-- lower (see 'transformPrimitive').
class HandlePrimitives a where
    -- | Rewrite the primitive instructions contained in a value, using the
    -- given compiler options.
    handlePrimitives :: Options -> a -> a

-- | Lists are processed element by element.
instance HandlePrimitives a => HandlePrimitives [a] where
    handlePrimitives = map . handlePrimitives

-- | Functions: only the body program is rewritten; the remaining fields are
-- copied through untouched.
instance HandlePrimitives ImpFunction where
    handlePrimitives opts (Fun name ins outs body) =
        Fun name ins outs (handlePrimitives opts body)

-- | Complete programs: the first field is kept as-is and the second field is
-- processed recursively.
instance HandlePrimitives CompleteProgram where
    handlePrimitives opts (CompPrg defs body) =
        CompPrg defs (handlePrimitives opts body)

-- | Programs: every 'Primitive' leaf is lowered with 'transformPrimitive';
-- compound nodes are rebuilt with their sub-programs processed recursively
-- and all other fields (conditions, loop parameters, semantic info) kept.
instance HandlePrimitives Program where
    handlePrimitives opts prog = case prog of
        Empty -> Empty
        Primitive instr info -> transformPrimitive opts instr info
        Seq parts info -> Seq (map recurse parts) info
        IfThenElse cond thenPrg elsePrg info ->
            IfThenElse cond (recurse thenPrg) (recurse elsePrg) info
        SeqLoop cond condCalc body info ->
            SeqLoop cond (recurse condCalc) (recurse body) info
        ParLoop counter bound step body info ->
            ParLoop counter bound step (recurse body) info
      where
        recurse = handlePrimitives opts

-- | Lower one primitive instruction into low-level imperative code.
--
-- 'CFun' calls whose name and parameter shape match a known DSL primitive
-- are rewritten:
--
--   * comparison, logical and arithmetic operators, @(^)@, @abs@ and
--     @signum@ are handed to 'op1' / 'op2' / 'fun1' / 'fun2', which emit an
--     assignment for simple output types or a type-suffixed call otherwise;
--   * @(!)@ becomes an array-element read;
--   * @setIx@ becomes a whole-array copy followed by an element write;
--   * @copy@ becomes an assignment for simple types, or a call to a
--     type-specific @copy_*@ routine otherwise.
--
-- Any other 'CFun' is kept, with the 'varMap' of its semantic info rebuilt
-- from its parameters. Instructions that are not 'CFun' calls are returned
-- unchanged.
transformPrimitive :: Options -> Instruction -> SemInfPrim -> Program
{- -- Do we still have a 'tuple' function?
transformPrimitive opt i@(CFun "tuple" ps) s
    | length ins == length outs
        = Seq (map (\pair -> mkCopy opt pair s) $ zip ins outs) []
    | otherwise
        = error ("Error: Number of parameters is odd in a 'tuple' call.\n\t" ++ toC 0 i)
        where
            ins = inParams ps
            outs = outParams ps
-}
-- Comparison operators.
transformPrimitive _ (CFun "(==)" [In in1, In in2, Out out]) s = op2 in1 in2 out "==" "equal" s
transformPrimitive _ (CFun "(/=)" [In in1, In in2, Out out]) s = op2 in1 in2 out "!=" "not_equal" s
transformPrimitive _ (CFun "(<)" [In in1, In in2, Out out]) s = op2 in1 in2 out "<" "less" s
transformPrimitive _ (CFun "(>)" [In in1, In in2, Out out]) s = op2 in1 in2 out ">" "greater" s
transformPrimitive _ (CFun "(<=)" [In in1, In in2, Out out]) s = op2 in1 in2 out "<=" "less_equal" s
transformPrimitive _ (CFun "(>=)" [In in1, In in2, Out out]) s = op2 in1 in2 out ">=" "greater_equal" s
-- Logical operators.
transformPrimitive _ (CFun "not" [In in1, Out out]) s         = op1 in1 out "!" "not" s
transformPrimitive _ (CFun "(&&)" [In in1, In in2, Out out]) s = op2 in1 in2 out "&&" "and" s
transformPrimitive _ (CFun "(||)" [In in1, In in2, Out out]) s = op2 in1 in2 out "||" "or" s
-- NOTE(review): Haskell's 'div' rounds towards negative infinity, whereas
-- C's '/' truncates towards zero, so this mapping differs for negative
-- signed operands ('quot' is the Haskell function matching C). Confirm which
-- semantics the DSL intends.
transformPrimitive _ (CFun "div" [In in1, In in2, Out out]) s = op2 in1 in2 out "/" "divide" s
transformPrimitive _ (CFun "(^)" [In in1, In in2, Out out]) s = fun2 in1 in2 out "pow" s

-- Numeric operations.
transformPrimitive _ (CFun "abs" [In in1, Out out]) s = fun1 in1 out "abs" s
transformPrimitive _ (CFun "signum" [In in1, Out out]) s = fun1 in1 out "signum" s
transformPrimitive _ (CFun "(+)" [In in1, In in2, Out out]) s = op2 in1 in2 out "+" "add" s
transformPrimitive _ (CFun "(-)" [In in1, In in2, Out out]) s = op2 in1 in2 out "-" "sub" s
transformPrimitive _ (CFun "(*)" [In in1, In in2, Out out]) s = op2 in1 in2 out "*" "mult" s
transformPrimitive _ (CFun "(/)" [In in1, In in2, Out out]) s = op2 in1 in2 out "/" "divide" s

-- Array indexing: out = arr[idx].
-- NOTE(review): unlike the operator and copy cases, the output kind is
-- ignored here, so an 'OutKind' output is assigned directly rather than
-- through 'PointedVal'. Confirm this is intended.
transformPrimitive _ (CFun "(!)" [In arr, In idx, Out (_kind,out)]) s
    = Primitive (Assign left right) semInf
        where
            left = toLeftValue out
            right = Expr (LeftExpr $ ArrayElem (toLeftValue arr) idx) $ exprType out
            semInf = s{ varMap = addVarMap (leftVarMap left $ Just right) (rightVarMap right) }
-- Functional array update: copy 'original' into 'result', then overwrite
-- result[index] with 'value'.
transformPrimitive opts (CFun "setIx" [In original, In index, In value, Out (kind,result)]) s
    = Seq
        [ mkCopy opts (original,(kind, result)) s
        , mkCopy opts (value, (Normal,Expr (LeftExpr $ ArrayElem (toLeftValue result) index) $ exprType value)) s
        ] []
-- Single-pair copy. Simple types become a plain assignment (through the
-- pointer for an 'OutKind' output); anything else becomes a call to
-- copy_<type> with one extra length input per array nesting level.
transformPrimitive _ (CFun "copy" [In in1, Out (kind,out)]) s
    | simpleType (exprType in1) && kind == Normal
        = Primitive (Assign (toLeftValue out) in1) semInf
    | simpleType (exprType in1) && kind == OutKind
        = Primitive (Assign (PointedVal $ toLeftValue out) in1) semInf
    | otherwise
        = Primitive (CFun ("copy_" ++ toFunName (exprType in1)) (In in1 : arrayDim (exprType in1) ++ [Out (kind,out)])) semInf
        where
            semInf = s{ varMap = vm }
            vm = addVarMap (leftVarMap (toLeftValue out) $ Just in1) (rightVarMap in1)
            -- One signed 32-bit length constant per array level, outermost
            -- first.
            -- NOTE(review): arrays of unknown length fall back to a length of
            -- 16; confirm this default.
            arrayDim (ImpArrayType (Just n) t) = In (Expr (ConstExpr (IntConst n)) (Numeric ImpSigned S32)) : arrayDim t
            arrayDim (ImpArrayType Nothing t) = In (Expr (ConstExpr (IntConst 16)) (Numeric ImpSigned S32)) : arrayDim t
            arrayDim _ = []

-- Multi-pair copy: the i-th input parameter is copied to the i-th output
-- parameter, each pair lowered by the single-pair case above.
transformPrimitive opts c@(CFun "copy" pars) s
    | length ins /= length outs = error $ "Error: invalid arguments to 'copy':\n" ++ toC 0 c
    | otherwise = Seq (map genTwoParamCopy $ zip ins outs) []
        where
            ins = filter inparam pars
            outs = filter (not . inparam) pars
            genTwoParamCopy (i,o) = transformPrimitive opts (CFun "copy" [i,o]) s

-- Any other function call is kept. Its varMap is rebuilt: inputs via
-- 'rightVarMap', outputs via 'leftVarMap' with no known value.
transformPrimitive _ i@(CFun _ pars) s = Primitive i semInf where
    -- 'Prelude.foldr' is written qualified because the unqualified
    -- 'Data.Map' import also exports a 'foldr' in current containers.
    semInf = s{ varMap = Prelude.foldr addVarMap Data.Map.empty mapList }
    mapList
        = map rightVarMap (inParams pars)
        ++ [ leftVarMap (toLeftValue out) Nothing | (_, out) <- outParams pars ]

-- Instructions that are not function calls (e.g. assignments) are already
-- low-level. Return them unchanged instead of failing with a pattern-match
-- error when this pass encounters them.
transformPrimitive _ instr s = Primitive instr s

-- | Lower a unary primitive implemented by a named C helper function.
--
-- If the output has a simple type and its kind is 'Normal' or 'OutKind', an
-- assignment of @<cFunName>_fun_<type>(in)@ to the output is emitted
-- (through 'PointedVal' for 'OutKind'). Otherwise the primitive stays a call
-- to @<cFunName>_<type>@ taking the input and the output parameter. In both
-- cases @<type>@ comes from the input's type via 'toFunName'.
fun1 arg (outKind, dst) cFunName semIn = Primitive instr semOut
    where
        target = toLeftValue dst
        isSimple = simpleType (exprType dst)
        argName = toFunName (exprType arg)
        callExpr = Expr (FunCall SimpleFun (cFunName ++ "_fun_" ++ argName) [arg]) (exprType dst)
        instr
            | isSimple && outKind == Normal = Assign target callExpr
            | isSimple && outKind == OutKind = Assign (PointedVal target) callExpr
            | otherwise = CFun (cFunName ++ "_" ++ argName) [In arg, Out (outKind, dst)]
        -- The varMap always describes the helper-call form, whichever
        -- instruction is emitted.
        semOut = semIn { varMap = addVarMap (leftVarMap target (Just callExpr)) (rightVarMap callExpr) }

-- | Lower a binary primitive implemented by a named C helper function.
--
-- If the output has a simple type and its kind is 'Normal' or 'OutKind', an
-- assignment of @<cFunName>_fun_<type>(in1, in2)@ to the output is emitted
-- (through 'PointedVal' for 'OutKind'). Otherwise the primitive stays a call
-- to @<cFunName>_<type>@ with both inputs and the output parameter. @<type>@
-- is taken from the first input's type via 'toFunName'.
fun2 arg1 arg2 (outKind, dst) cFunName semIn =
    Primitive chosen (semIn { varMap = vars })
    where
        dstLv = toLeftValue dst
        suffix = toFunName (exprType arg1)
        helperCall = Expr (FunCall SimpleFun (cFunName ++ "_fun_" ++ suffix) [arg1, arg2]) (exprType dst)
        vars = addVarMap (leftVarMap dstLv (Just helperCall)) (rightVarMap helperCall)
        simpleOut = simpleType (exprType dst)
        chosen
            | simpleOut && outKind == Normal = Assign dstLv helperCall
            | simpleOut && outKind == OutKind = Assign (PointedVal dstLv) helperCall
            | otherwise = CFun (cFunName ++ "_" ++ suffix) [In arg1, In arg2, Out (outKind, dst)]

-- | Lower a unary primitive that corresponds to a C prefix operator.
--
-- If the output has a simple type and its kind is 'Normal' or 'OutKind', an
-- assignment of @<cOpName> in@ to the output is emitted (through
-- 'PointedVal' for 'OutKind'). Otherwise the primitive stays a call to
-- @<cFunName>_<type>@, where @<type>@ comes from the input's type.
op1 operand (outKind, dst) cOpName cFunName semIn =
    Primitive instr (semIn { varMap = vars })
    where
        dstLv = toLeftValue dst
        opExpr = Expr (FunCall PrefixOp cOpName [operand]) (exprType dst)
        vars = addVarMap (leftVarMap dstLv (Just opExpr)) (rightVarMap opExpr)
        simpleOut = simpleType (exprType dst)
        fallbackName = cFunName ++ "_" ++ toFunName (exprType operand)
        instr
            | simpleOut && outKind == Normal = Assign dstLv opExpr
            | simpleOut && outKind == OutKind = Assign (PointedVal dstLv) opExpr
            | otherwise = CFun fallbackName [In operand, Out (outKind, dst)]

-- | Lower a binary primitive that corresponds to a C infix operator.
--
-- If the output has a simple type and its kind is 'Normal' or 'OutKind', an
-- assignment of @in1 <cOpName> in2@ to the output is emitted (through
-- 'PointedVal' for 'OutKind'). Otherwise the primitive stays a call to
-- @<cFunName>_<type>@ with both inputs and the output parameter, where
-- @<type>@ comes from the first input's type.
op2 lhs rhs (outKind, dst) cOpName cFunName semIn = Primitive instr semOut
    where
        target = toLeftValue dst
        infixExpr = Expr (FunCall InfixOp cOpName [lhs, rhs]) (exprType dst)
        isSimple = simpleType (exprType dst)
        instr
            | isSimple && outKind == Normal = Assign target infixExpr
            | isSimple && outKind == OutKind = Assign (PointedVal target) infixExpr
            | otherwise = CFun (cFunName ++ "_" ++ toFunName (exprType lhs)) [In lhs, In rhs, Out (outKind, dst)]
        semOut = semIn { varMap = addVarMap (leftVarMap target (Just infixExpr)) (rightVarMap infixExpr) }

-- | Payloads of the 'In' parameters, in their original order. The pattern
-- in the generator silently skips non-'In' parameters, so no partial lambda
-- is needed.
inParams ps = [ x | In x <- ps ]
-- | Payloads of the 'Out' parameters (the (kind, expression) pairs), in their
-- original order. Non-'Out' parameters are skipped by the generator pattern.
outParams ps = [ x | Out x <- ps ]

-- | 'True' for an 'In' parameter, 'False' for an 'Out' parameter.
inparam (In _)  = True
inparam (Out _) = False

-- | Build and lower a two-parameter @copy@ from a source expression to an
-- output (kind, expression) pair.
mkCopy opt (src, dst) = transformPrimitive opt (CFun "copy" [In src, Out dst])


-- | Type-derived suffix used to name generated helper functions, e.g.
-- @"arrayOf_bool"@ for an array of booleans. Numeric types combine the
-- C renderings of their signedness and size.
toFunName :: Type -> String
toFunName ty = case ty of
    BoolType                -> "bool"
    FloatType               -> "float"
    Numeric signedness width -> toC 0 signedness ++ "_" ++ toC 0 width
    ImpArrayType _ elemTy   -> "arrayOf_" ++ toFunName elemTy
    Feldspar.Compiler.Imperative.Representation.Pointer pointee ->
        "pointerTo_" ++ toFunName pointee