-- Copyright 2013 Kevin Backhouse.

{- |
The 'OrdCons' instrument uses two passes to implement hash-consing.
The values are added to the table during the first pass and a unique
index for each value is returned during the second pass.

'OrdCons' is implemented using 'Data.Map', so it can be used on any
datatype which is an instance of 'Ord'.
-}

module Control.Monad.MultiPass.Instrument.OrdCons
  ( -- * Instrument
    OrdCons
  , initOrdCons, ordCons, getOrdConsTable
    -- * Table
  , OrdConsTable
  , lookupOrdConsTable, insertOrdConsTable, growOrdConsTable
  ) where

import Control.Exception ( assert )
import Control.Monad.ST2
import Control.Monad.Writer.Strict
import Control.Monad.MultiPass
import Control.Monad.MultiPass.ThreadContext.MonoidTC
import qualified Data.Map as FM
import Data.Maybe ( isJust, fromJust )

-- | Abstract datatype for the instrument.
--
-- The parameters @p1@ and @p2@ wrap the arguments and results of the
-- operations; the instances in this module instantiate each of them
-- to either @On@ or @Off@. All fields are strict.
data OrdCons a r w p1 p2 tc
  = OrdCons
      { -- | Prologue action which receives the initial table.
        initInternal
          :: !(p1 (OrdConsTable a) -> MultiPassPrologue r w tc ())

        -- | Maps a @p1@-wrapped value to a @p2@-wrapped index.
      , ordConsInternal
          :: !(p1 a -> MultiPass r w tc (p2 Int))

        -- | Epilogue action which yields the @p2@-wrapped table.
      , getOrdConsTableInternal
          :: !(MultiPassEpilogue r w tc (p2 (OrdConsTable a)))
      }

-- | Initialise the 'OrdCons' instrument with an 'OrdConsTable'.  This
-- method is optional. If this method is not used then the instrument
-- will be initialised with an empty 'OrdConsTable'.
initOrdCons
  :: (Ord a, Monad p1, Monad p2)
  => OrdCons a r w p1 p2 tc       -- ^ Instrument
  -> p1 (OrdConsTable a)          -- ^ Initial table
  -> MultiPassPrologue r w tc ()
initOrdCons =
  -- The record selector already has exactly the required type.
  initInternal

-- | Get a unique index for the value.
ordCons
  :: (Ord a, Monad p1, Monad p2)
  => OrdCons a r w p1 p2 tc       -- ^ Instrument
  -> p1 a                         -- ^ Value
  -> MultiPass r w tc (p2 Int)    -- ^ Unique index
ordCons =
  -- The record selector already has exactly the required type.
  ordConsInternal

-- | Get the final 'OrdConsTable'.
getOrdConsTable
  :: OrdCons a r w p1 p2 tc       -- ^ Instrument
  -> MultiPassEpilogue r w tc (p2 (OrdConsTable a))
getOrdConsTable =
  -- The record selector already has exactly the required type.
  getOrdConsTableInternal

-- | A table which assigns an 'Int' index to each key of type @a@. It
-- is a newtype wrapper around @'FM.Map' a 'Int'@, and the indices
-- form a permutation of the integers @0..n-1@, where @n@ is the
-- number of keys.
newtype OrdConsTable a = OrdConsTable (FM.Map a Int)

-- | The 'OrdConsTable' which contains no entries.
emptyOrdConsTable :: OrdConsTable a
emptyOrdConsTable = OrdConsTable FM.empty

-- | Find the index assigned to an element. The result is 'Nothing'
-- if the element is not in the table.
lookupOrdConsTable :: Ord a => OrdConsTable a -> a -> Maybe Int
lookupOrdConsTable (OrdConsTable indices) key = key `FM.lookup` indices

-- | Add an element to the table. An element which is already present
-- leaves the table unchanged; a new element receives index @n@,
-- where @n@ is the size of the table before the insertion.
insertOrdConsTable :: Ord a => OrdConsTable a -> a -> OrdConsTable a
insertOrdConsTable original@(OrdConsTable indices) key
  | key `FM.member` indices = original
  | otherwise = OrdConsTable (FM.insert key nextIndex indices)
  where
    -- The first unused index equals the current number of keys.
    nextIndex = FM.size indices

-- | Add multiple elements. The new elements are assigned indices
-- @n..n+k-1@, where @n@ is the original size of the table and @k@ is
-- the number of new elements to be added. This function will assert
-- if any of the new elements are already in the table.
growOrdConsTable
  :: Ord a => OrdConsTable a -> FM.Map a () -> OrdConsTable a
growOrdConsTable (OrdConsTable table) xs =
  -- The new keys must be disjoint from the existing ones.
  assert (FM.null (FM.intersection table xs)) $
  let n = FM.size table in
  -- Thread a counter starting at n through the new keys, replacing
  -- each unit value with the counter's current value.
  let xs' = snd $ FM.mapAccum (\i () -> (i+1, i)) n xs in
  OrdConsTable $ FM.union table xs'

-- Global context which wraps a mutable reference to an
-- 'OrdConsTable'. It is the global context of the 'Instrument'
-- instance for @OrdCons a r w On Off tc@.
newtype GC1 r w a = GC1 (ST2Ref r w (OrdConsTable a))

-- Thread context payload: a map from keys to unit, i.e. a set of
-- keys.
newtype OrdConsTC a = OrdConsTC (FM.Map a ())

-- Two 'OrdConsTC' values are combined by taking the union of their
-- key sets; the empty set is the identity.
instance Ord a => Monoid (OrdConsTC a) where
  mempty = OrdConsTC FM.empty

  mappend (OrdConsTC lhs) (OrdConsTC rhs) = OrdConsTC (lhs `FM.union` rhs)

-- Instance with both @p1@ and @p2@ set to 'Off': every operation
-- accepts 'Off' arguments, does nothing, and yields 'Off' results.
instance Instrument tc () ()
                    (OrdCons a r w Off Off tc) where
  createInstrument _ _ () =
    wrapInstrument $ OrdCons
      { initInternal = \Off -> return ()
      , ordConsInternal = \Off -> return Off
      , getOrdConsTableInternal = return Off
      }

-- Instance with @p1@ set to 'On' and @p2@ set to 'Off': values are
-- collected, but no indices are produced yet. The global context
-- holds a reference to the initial table and the thread context
-- accumulates the set of values which are not in that table.
instance Ord a =>
         Instrument tc (MonoidTC (OrdConsTC a)) (GC1 r w a)
                    (OrdCons a r w On Off tc) where
  createInstrument st2ToMP updateCtx (GC1 initTableRef) =
    wrapInstrument $ OrdCons
      { initInternal = \(On initTable) ->
          mkMultiPassPrologue $
          do -- Check that the initTableRef has not been initialised
             -- already.
             OrdConsTable xs <- st2ToMP $ readST2Ref initTableRef
             assert (FM.null xs) $ return ()
             st2ToMP $ writeST2Ref initTableRef initTable

      , ordConsInternal = \(On x) ->
          -- Record x in the thread context, unless the initial table
          -- already contains it.
          let updateTable initTable (MonoidTC (OrdConsTC table)) =
                MonoidTC $ OrdConsTC $
                if isJust (lookupOrdConsTable initTable x)
                   then table
                   else FM.insert x () table
          in
          mkMultiPass $
          do initTable <- st2ToMP $ readST2Ref initTableRef
             _ <- updateCtx (updateTable initTable)
             return Off

      , getOrdConsTableInternal =
          -- The table is not produced by this instance.
          return Off
      }

-- Global context which holds two tables, both stored strictly.
--
-- The gc2_newTable field is a superset of gc2_initTable. (The
-- initTable is only used if back-tracking occurs.)
data GC2 a
  = GC2
      { gc2_initTable :: !(OrdConsTable a)
      , gc2_newTable  :: !(OrdConsTable a)
      }

-- Instance with both @p1@ and @p2@ set to 'On': indices are looked
-- up in the gc2_newTable field of the global context.
instance Ord a => Instrument tc () (GC2 a)
                             (OrdCons a r w On On tc) where
  createInstrument _ _ gc =
    let newTable = gc2_newTable gc in
    wrapInstrument $ OrdCons
      { -- The initial table is ignored by this instance.
        initInternal = \(On _) -> return ()

      , ordConsInternal = \(On x) ->
          let m = lookupOrdConsTable newTable x in
          -- The value is expected to be in the table already.
          assert (isJust m) $
          return $ On $ fromJust m

      , getOrdConsTableInternal =
          return (On newTable)
      }

-- This instrument never needs to back-track. Both instances are
-- declared with empty bodies, so no 'BackTrack' methods are
-- overridden here.
instance BackTrack r w tc (GC1 r w a)
instance BackTrack r w () (GC2 a)

-- Moving from the unit global context to 'GC1' allocates a fresh
-- reference which holds 'emptyOrdConsTable'.
instance NextGlobalContext r w () () (GC1 r w a) where
  nextGlobalContext _ _ () () =
    newST2Ref emptyOrdConsTable >>= return . GC1

-- A 'GC1' global context is passed on unchanged when the next
-- global context is also a 'GC1'.
instance NextGlobalContext r w tc (GC1 r w a) (GC1 r w a) where
  nextGlobalContext _ _ _ ctx = return ctx

-- Moving from 'GC1' to 'GC2' reads the initial table from the
-- reference and extends it with the keys accumulated in the thread
-- context. The initial table is kept alongside the extended one.
instance Ord a =>
         NextGlobalContext r w (MonoidTC (OrdConsTC a))
                           (GC1 r w a) (GC2 a) where
  nextGlobalContext _ _ tc gc =
    let GC1 initTableRef = gc in
    let MonoidTC (OrdConsTC table) = tc in
    do initTable <- readST2Ref initTableRef
       return $ GC2
         { gc2_initTable = initTable
         , gc2_newTable  = growOrdConsTable initTable table
         }

-- A 'GC2' global context is passed on unchanged when the next
-- global context is also a 'GC2'.
instance NextGlobalContext r w tc (GC2 a) (GC2 a) where
  nextGlobalContext _ _ _ ctx = return ctx

-- Moving from 'GC2' back to 'GC1' allocates a fresh reference which
-- holds the gc2_initTable field; the gc2_newTable field is not used.
instance NextGlobalContext r w tc (GC2 a) (GC1 r w a) where
  nextGlobalContext _ _ _ ctx =
    newST2Ref (gc2_initTable ctx) >>= return . GC1