{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}

module LinearScan.Hoopl.DSL where

import           Compiler.Hoopl as Hoopl hiding ((<*>))
import           Control.Applicative
import           Control.Arrow (first)
import           Control.Monad
import           Control.Monad.Fix
import           Control.Monad.Free
import           Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Free as TF
import           Control.Monad.Trans.Free hiding (FreeF(..), Free)
import           Control.Monad.Trans.State
import qualified Data.Map as M
import           Data.Maybe (fromMaybe)
import           Data.Monoid
import           LinearScan
import           Unsafe.Coerce
import           Test.QuickCheck

-- | Bookkeeping for stack-slot assignment (presumably for spilled variables):
--   the offset the next new slot will get, the distance between slots, and
--   the offset already handed out for each variable key.
data SpillStack = SpillStack
    { stackPtr      :: Int
      -- ^ Offset the next newly assigned slot will receive.
    , stackSlotSize :: Int
      -- ^ Amount 'stackPtr' advances each time a new slot is assigned.
    , stackSlots    :: M.Map (Maybe Int) Int
      -- ^ Offset already assigned to each key (the @Maybe VarId@ passed to
      --   'getStackSlot').
    }
    deriving (Eq, Show)

-- | State threaded through stack-slot assignment: a list of 'Int's (named
--   @supply@ in 'getStackSlot', which passes it through unchanged) paired
--   with the current 'SpillStack'.
type Env = State ([Int], SpillStack)

-- | Build a 'SpillStack' with no slots assigned yet. The first slot handed
--   out will sit at @offset@, and each later one @slotSize@ beyond the last.
newSpillStack :: Int -> Int -> SpillStack
newSpillStack offset slotSize = SpillStack offset slotSize M.empty

-- | Return the stack offset for the given variable key. The first time a
--   key is seen it receives the current 'stackPtr', and the pointer then
--   advances by 'stackSlotSize'. Later calls with the same key return the
--   same offset without changing any state. The integer list in the 'Env'
--   state is carried through untouched.
getStackSlot :: Maybe VarId -> Env Int
getStackSlot vid = do
    slots <- gets (stackSlots . snd)
    maybe claimSlot return (M.lookup vid slots)
  where
    -- Record the current stack pointer as this key's slot and bump the
    -- pointer past it.
    claimSlot = do
        (supply, stack) <- get
        let off    = stackPtr stack
            ptr'   = off + stackSlotSize stack
            slots' = M.insert vid off (stackSlots stack)
        put (supply, stack { stackPtr = ptr', stackSlots = slots' })
        return off

-- | Table from label names to the Hoopl 'Label' allocated for each name.
type Labels = M.Map String Label
-- | The 'Asm' monad lets us create labels by name and refer to them later.
--   It threads a 'Labels' table over Hoopl's 'SimpleUniqueMonad', from which
--   fresh labels are drawn (see 'getLabel').
type Asm = StateT Labels SimpleUniqueMonad

-- | Return the Hoopl 'Label' registered under @str@. On the first request
--   for a name, a fresh 'Label' is allocated and recorded. Every later
--   request for that name, whether from a 'label' definition or a 'jump',
--   therefore yields the same 'Label'.
getLabel :: String -> Asm Label
getLabel str = gets (M.lookup str) >>= maybe register return
  where
    -- Allocate a new label for this name and remember it.
    register = do
        fresh <- lift freshLabel
        modify (M.insert str fresh)
        return fresh

-- | A series of 'Nodes' is a sequence of open/open ('O' 'O') assembly
--   instructions followed by a final result of type @a@. For an 'EndNode',
--   that result is the action producing the block's closing operation, such
--   as a jump, branch or return; for a 'BodyNode' it is just @()@.
type Nodes n a = Free ((,) (n O O)) a

-- | Flatten a 'Nodes' sequence into its final result, paired with the
--   instructions in the order they were emitted. A 'Nodes' value is
--   essentially a list whose terminator carries a result; for an 'EndNode'
--   that result is the block's closing action.
nodesToList :: Nodes n a -> (a, [n O O])
nodesToList = collect []
  where
    -- Gather instructions newest-first, then restore emission order at
    -- the terminator.
    collect acc (Pure a)         = (a, reverse acc)
    collect acc (Free (n, rest)) = collect (n : acc) rest

-- | A run of body instructions that yields no result.
type BodyNode n = Nodes n ()

-- | Lift a single open/open instruction into a one-element 'BodyNode'.
bodyNode :: n O O -> BodyNode n
bodyNode n = liftF (n, ())

-- | A run of body instructions whose result is an 'Asm' action producing
--   the block's closing (open/closed) node.
type EndNode n = Nodes n (Asm (n O C))

-- | Finish a block with the given closing action, adding no further body
--   instructions.
endNode :: Asm (n O C) -> EndNode n
endNode = Pure

-- | One step of a program: a block's entry 'Label' paired with the
--   'EndNode' describing its body instructions and closing operation.
data ProgramF n = FreeBlock
    { labelEntry :: Label      -- ^ Label naming the block's entry point.
    , labelBody  :: EndNode n  -- ^ Body instructions plus closing action.
    }
-- | A program is a series of labelled blocks. Each 'FreeT' step emits one
--   'ProgramF' block (see 'label'), and label names are resolved in 'Asm'.
type Program n = FreeT ((,) (ProgramF n)) Asm ()

-- | Define a block named @str@ with the given body. The name is resolved
--   through 'getLabel', so the block shares its 'Label' with every 'jump'
--   to @str@, whether that jump appears earlier or later in the program.
label :: String -> EndNode n -> Program n
label str body =
    lift (getLabel str) >>= \entry -> liftF (FreeBlock entry body, ())

-- | End a block with a branch node (Hoopl's 'mkBranchNode') to the block
--   named @dest@. The target label is looked up only when the closing
--   action is run, which lets a jump refer to a block defined later.
jump :: HooplNode n => String -> EndNode n
jump = endNode . fmap mkBranchNode . getLabel

-- | When we compile a program, the result is a closed Hoopl Graph and the
--   label corresponding to the requested entry label name.
--
--   All labels are resolved through one shared 'Labels' table, so a 'jump'
--   may name a block that is only defined later with 'label'. Calls 'error'
--   if @name@ was never mentioned at all. It also errors if @name@ was
--   mentioned (e.g. as a 'jump' target) but never given a block with
--   'label'; the returned entry label is always a block in the graph.
compile :: (NonLocal n, HooplNode n)
        => String -> Program n -> SimpleUniqueMonad (Graph n C C, Label)
compile name prog
    = flip evalStateT (mempty :: Labels)
    $ do body  <- go prog
         entry <- gets (M.lookup name)
         case entry of
             Nothing  -> error $ "Missing label: " ++ name
             Just lbl
                 -- A name seen only as a 'jump' target still has a 'Label'
                 -- in the table (getLabel allocates on first use) but no
                 -- block in the body; returning it would hand callers an
                 -- entry point that lies outside the graph.
                 | mapMember lbl body -> return (bodyGraph body, lbl)
                 | otherwise          ->
                     error $ "Entry label has no block: " ++ name
  where
    -- Step through the program one 'label' at a time, compiling each block
    -- and adding it to the accumulated body.
    go m = do
        p <- runFreeT m
        case p of
            TF.Pure () -> return emptyBody
            TF.Free (blk, xs) -> addBlock <$> comp blk <*> go xs

    -- Turn one labelled block into a closed/closed Hoopl block: the label
    -- node, then the body instructions, then the closing node produced by
    -- running the block's 'Asm' action.
    comp (FreeBlock lbl body) = do
        let (close, blocks) = nodesToList body
        BlockCC (mkLabelNode lbl) (blockFromList blocks) <$> close