{-|
  This module defines the syntax of the simple, continuation-passing-style
  functional language used here, as well as some examples.
-}
{-# LANGUAGE GeneralizedNewtypeDeriving, OverloadedStrings, FlexibleInstances #-}
module CPSScheme where

import Data.String( IsString(..) )
import qualified Data.Map as M
import Control.Monad.State
import Control.Applicative ((<$>))

import Common

-- * The CPS syntax

-- | A program is a lambda abstraction. By convention it takes exactly one
-- parameter: the final continuation.
type Prog = Lambda

-- | Labels name positions in the code and are used throughout to refer to
-- them.
--
-- Underneath they are integers, but the newtype keeps that representation
-- hidden from the rest of the implementation.
newtype Label = Label Integer
    deriving (Eq, Ord, Enum, Num, Show)

-- | A variable: the label of its binding position together with its name.
-- Names are plain strings, wrapped so that they can be treated abstractly.
data Var = Var Label String
    deriving (Eq, Ord, Show)

-- | The label of the 'Lambda' or 'Let' that binds this variable.
binder :: Var -> Label
binder v = case v of
    Var bindingSite _ -> bindingSite

-- | A lambda abstraction: its label, its parameters and its body.
data Lambda = Lambda Label [Var] Call
    deriving (Eq, Ord, Show)

-- | The body of a lambda abstraction. It is either
data Call
    = App Label Val [Val]
      -- ^ the application of a value to a list of argument values, or
    | Let Label [(Var, Lambda)] Call
      -- ^ a group of (possibly mutually recursive) lambda definitions that
      -- are in scope for the call that follows.
    deriving (Eq, Ord, Show)

-- | A value is one of
data Val
    = L Lambda       -- ^ a lambda abstraction,
    | R Label Var    -- ^ a reference to a variable (for convenience the
                     -- 'Var' carries the label of its binding position),
    | C Label Const  -- ^ a constant, or
    | P Prim         -- ^ a primitive operation.
    deriving (Eq, Ord, Show)

-- | The only constants are integers.
type Const = Integer

-- | Primitive operations. Each is annotated with one label per continuation;
-- these label the (invisible) call sites from which the primitive calls its
-- continuations.
data Prim
    = Plus Label
      -- ^ Integer addition. Expected arguments: two integers, then one
      -- continuation.
    | If Label Label
      -- ^ Conditional branch. Expected arguments: one integer, then the
      -- continuation to call if it is nonzero, then the continuation to
      -- call if it is zero (i.e. false).
    deriving (Eq, Ord, Show)

-- * Smart constructors

-- | A string literal used as a 'Var' is a variable name whose binding label
-- is still a placeholder; 'prog' fills it in.
instance IsString Var where
    fromString = Var noProg

-- | A string literal used as a 'Val' is a reference to the variable of that
-- name. Both the reference label and the binding label are placeholders
-- until 'prog' replaces them.
instance IsString Val where
    fromString = R noProg . Var noProg

-- | String literals also work for wrapped values, by lifting the instance of
-- the underlying type.
instance IsString a => IsString (Inv a) where
    fromString = Inv . fromString

-- | Integer literals can stand for @Inv Val@ constants; their labels are
-- placeholders until 'prog' assigns them. Negating a constant is supported
-- because a literal such as @-1@ desugars to 'negate' applied to a literal.
-- Every other arithmetic operation fails at runtime.
instance Num (Inv Val) where
    fromInteger = Inv . C noProg
    negate v = case v of
        Inv (C _ k) -> Inv (C noProg (negate k))
        _           -> error "Do not use the Num Val instance"
    (+)    = error "Do not use the Num Val instance"
    (*)    = error "Do not use the Num Val instance"
    abs    = error "Do not use the Num Val instance"
    signum = error "Do not use the Num Val instance"


-- | Marks a value built with the smart constructors that has not yet been
-- passed through 'prog'. Its labels are still placeholders, so it is not a
-- valid syntax tree; 'unsafeFinish' unwraps it without filling them in.
newtype Inv a = Inv
    { unsafeFinish :: a
    } deriving (Eq, Show)

-- | Convert code built with the smart constructors below into a fully
-- annotated 'CPSScheme' syntax tree, by assigning labels and resolving
-- variable references.
--
-- Labels are handed out from 1 upwards in a left-to-right pre-order walk.
-- Every 'Lambda', 'App', 'Let', variable reference ('R') and constant ('C')
-- gets one label, 'Plus' gets one and 'If' gets two. Each 'Var', whether a
-- binding occurrence or a reference, receives the label of the 'Lambda' or
-- 'Let' that binds its name. The names bound by a 'Let' are visible in all
-- of its lambdas (so they may be mutually recursive) as well as in its body.
--
-- A reference to an unbound name yields an 'error' naming the variable. The
-- error is raised lazily, when that variable's binding label is forced.
prog :: Inv Lambda -> Prog
prog (Inv p) = evalState (pLambda M.empty p) 1
  where
    -- Fresh labels come from a plain counter: 1, 2, 3, ...
    next = state (\n -> (n, n + 1))

    -- Make all given names resolve to the label of their binder.
    -- NOTE(review): assumes 'upd' from Common inserts these (name, label)
    -- pairs into the scope map -- confirm against Common.
    bindNames env lbl names = env `upd` [ (name, lbl) | name <- names ]

    pLambda env (Lambda _ params body) = do
        lbl <- next
        let env' = bindNames env lbl [ n | Var _ n <- params ]
        params' <- mapM (pVar env') params
        body' <- pCall env' body
        return (Lambda lbl params' body')

    pCall env (App _ fun args) = do
        lbl <- next
        fun' <- pVal env fun
        args' <- mapM (pVal env) args
        return (App lbl fun' args')
    pCall env (Let _ binds body) = do
        lbl <- next
        -- Every name of the group is bound before any lambda is processed,
        -- so each lambda can refer to all bindings of this 'Let'.
        let env' = bindNames env lbl [ n | (Var _ n, _) <- binds ]
        binds' <- forM binds $ \(var, lam) -> do
            var' <- pVar env' var
            lam' <- pLambda env' lam
            return (var', lam')
        body' <- pCall env' body
        return (Let lbl binds' body')

    pVal env (L lam) = L <$> pLambda env lam
    pVal env (R _ var) = do
        lbl <- next
        var' <- pVar env var
        return (R lbl var')
    pVal _ (C _ i) = do
        lbl <- next
        return (C lbl i)
    pVal _ (P (Plus _)) = do
        lbl <- next
        return (P (Plus lbl))
    pVal _ (P (If _ _)) = do
        -- One label per continuation: nonzero branch first, then zero branch.
        lNonZero <- next
        lZero <- next
        return (P (If lNonZero lZero))

    -- Resolve a name to the label of its binder. The label stays lazy; an
    -- unbound name produces an error that names the variable.
    pVar env (Var _ n) =
        let unbound = error ("prog: unbound variable " ++ show n)
        in return (Var (M.findWithDefault unbound n env) n)


-- | Build a lambda abstraction from its parameters and body. Its label is a
-- placeholder until 'prog' assigns one.
lambda :: [Inv Var] -> Inv Call -> Inv Lambda
lambda params body =
    Inv (Lambda noProg (unsafeFinish <$> params) (unsafeFinish body))

-- | Build a call site that applies a value to a list of argument values.
-- Its label is a placeholder until 'prog' assigns one.
app :: Inv Val -> [Inv Val] -> Inv Call
app fun args = Inv (App noProg (unsafeFinish fun) (unsafeFinish <$> args))

-- | Build a 'Let' that binds each name to its lambda, followed by the given
-- call. When labelled by 'prog', the names are in scope in every bound
-- lambda and in the call.
let_ :: [(Inv Var, Inv Lambda)] -> Inv Call -> Inv Call
let_ binds body = Inv (Let noProg (map unwrap binds) (unsafeFinish body))
  where
    unwrap (Inv name, Inv lam) = (name, lam)

-- | Use a lambda abstraction as a value, e.g. as an argument of 'app'.
l :: Inv Lambda -> Inv Val
l (Inv lam) = Inv (L lam)

-- | An integer constant as a value. Integer literals do the same through the
-- 'Num' instance for @Inv Val@.
c :: Const -> Inv Val
c k = Inv (C noProg k)

-- | The addition primitive; its continuation label is assigned by 'prog'.
plus :: Inv Val
plus = Inv . P . Plus $ noProg

-- | The conditional primitive; both of its continuation labels are assigned
-- by 'prog'.
if_ :: Inv Val
if_ = Inv . P $ If noProg noProg

-- | Placeholder for labels the smart constructors cannot know yet. 'prog'
-- replaces every occurrence, so forcing one means a tree built with the
-- smart constructors was used without passing it through 'prog'.
noProg :: a
noProg = error message
  where
    message = "Smart constructors used without calling prog"

-- * Some example programs

-- | Passes the constant 0 straight to the final continuation.
ex1 :: Prog
ex1 = prog (lambda ["cont"] (app "cont" [0]))

-- | Computes 1 + 1 with the 'plus' primitive, which passes the sum to the
-- final continuation.
ex2 :: Prog
ex2 = prog (lambda ["cont"] (app plus [1, 1, "cont"]))

-- | Computes the sum 10 + 9 + ... + 1 (= 55) with a recursive loop and
-- passes it to the final continuation.
ex3 :: Prog
ex3 = prog $ lambda ["cont"] $
        let_ [("rec", loop)] $ app "rec" [0, 10, "cont"]
  where
    -- rec p i c': p is the running sum, i the counter, and c' the
    -- continuation that eventually receives the sum.
    loop = lambda ["p", "i", "c'"] $
        app if_ ["i", l countDown, l finish]
    -- i is nonzero: p' = p + i, i' = i + (-1), then loop again.
    countDown = lambda [] $
        app plus ["p", "i", l $ lambda ["p'"] $
            app plus ["i", -1, l $ lambda ["i'"] $
                app "rec" ["p'", "i'", "c'"]]]
    -- i is zero: pass the accumulated sum to c'.
    finish = lambda [] $ app "c'" ["p"]

-- | Never terminates: @rec@ keeps calling itself with the same continuation,
-- so the final continuation is never invoked.
ex4 :: Prog
ex4 = prog (lambda ["cont"] (let_ [("rec", spin)] (app "rec" ["cont"])))
  where
    spin = lambda ["c"] (app "rec" ["c"])

-- | The puzzle from Shivers' dissertation.
--
-- The outer lambda applies @caller@ to @callee@. @caller@ calls its argument
-- @f@ twice, the second time from inside the continuation of the first call.
puzzle :: Prog
puzzle = prog $ lambda ["k"] $ app (l caller) [l callee]
  where
    caller = lambda ["f"] $
        app "f" [0, 42, l $ lambda ["v"] $ app "f" [1, "v", "k"]]
    callee = lambda ["x", "h", "k1"] $
        app if_
            [ "x"
            , l $ lambda [] $ app "h" ["k1"]
            , l $ lambda [] $ app "k1" [l $ lambda ["k2"] $ app "k2" ["x"]]
            ]