{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE UndecidableInstances   #-}

module Control.Monad.Weighted.Class
  ( MonadWeighted(..)
  , collect
  , toCovector
  ) where

import qualified Control.Monad.Trans.Identity as Identity
import qualified Control.Monad.Trans.Except as Except
import qualified Control.Monad.Trans.State.Strict as StateStrict
import qualified Control.Monad.Trans.State.Lazy as StateLazy
import qualified Control.Monad.Trans.Maybe as Maybe
import qualified Control.Monad.Trans.Reader as Reader

import Control.Monad.Trans (lift)

import Data.Semiring

import Data.Coerce

-- | Monads whose computations carry a weight drawn from a 'Semiring' @w@.
--
-- This is the weighted counterpart of 'Control.Monad.Writer.Writer' taken
-- over the 'Data.Monoid.Product' 'Monoid': weights are accumulated
-- multiplicatively rather than appended. The functional dependency
-- @m -> w@ means every such monad determines its weight type.
--
-- A minimal instance defines either 'weighted' or 'weight', together with
-- 'weigh' and 'scale'.
class (Semiring w, Monad m) => MonadWeighted w m | m -> w where
    {-# MINIMAL (weighted | weight), weigh, scale #-}

    -- | @'weighted' (a, w)@ is the simplest weighted action: it records the
    -- weight @w@ and then yields @a@. The pair is matched lazily.
    weighted :: (a, w) -> m a
    weighted ~(a, w) = weight w >> return a

    -- | @'weight' w@ records the weight @w@ and yields @()@.
    weight :: w -> m ()
    weight = weighted . (,) ()

    -- | @'weigh' m@ runs @m@ and pairs its result with the weight that @m@
    -- carries (the analogue of @listen@ for writers).
    weigh :: m a -> m (a, w)

    -- | @'scale' m@ runs @m@, which yields a result together with a function
    -- on weights. It returns the result and applies that function to the
    -- weight (the analogue of @pass@ for writers).
    scale :: m (a, w -> w) -> m a

-- | Recording a weight inside 'Except.ExceptT' is lifted from the inner
-- monad. 'weigh' and 'scale' are transported with the transformer's own
-- 'Except.liftListen' and 'Except.liftPass'.
instance MonadWeighted w m => MonadWeighted w (Except.ExceptT e m) where
    weight w   = lift (weight w)
    weighted p = lift (weighted p)
    weigh      = Except.liftListen weigh
    scale      = Except.liftPass scale

-- | 'Identity.IdentityT' delegates every weight operation to the inner
-- monad. Recording is lifted, and 'weigh' and 'scale' are applied
-- underneath the wrapper with 'Identity.mapIdentityT'.
instance MonadWeighted w m => MonadWeighted w (Identity.IdentityT m) where
    weight w   = lift (weight w)
    weighted p = lift (weighted p)
    weigh      = Identity.mapIdentityT weigh
    scale      = Identity.mapIdentityT scale

-- | Weight operations for the strict 'StateStrict.StateT'. Recording a
-- weight is lifted from the inner monad, while 'weigh' and 'scale' go
-- through 'StateStrict.liftListen' and 'StateStrict.liftPass'.
instance MonadWeighted w m => MonadWeighted w (StateStrict.StateT s m) where
    weight w   = lift (weight w)
    weighted p = lift (weighted p)
    weigh      = StateStrict.liftListen weigh
    scale      = StateStrict.liftPass scale

-- | Weight operations for the lazy 'StateLazy.StateT'. Recording a weight is
-- lifted from the inner monad, while 'weigh' and 'scale' go through
-- 'StateLazy.liftListen' and 'StateLazy.liftPass'.
instance MonadWeighted w m => MonadWeighted w (StateLazy.StateT s m) where
    weight w   = lift (weight w)
    weighted p = lift (weighted p)
    weigh      = StateLazy.liftListen weigh
    scale      = StateLazy.liftPass scale

-- | Recording a weight inside 'Maybe.MaybeT' is lifted from the inner monad.
-- 'weigh' and 'scale' are transported with 'Maybe.liftListen' and
-- 'Maybe.liftPass'.
instance MonadWeighted w m => MonadWeighted w (Maybe.MaybeT m) where
    weight w   = lift (weight w)
    weighted p = lift (weighted p)
    weigh      = Maybe.liftListen weigh
    scale      = Maybe.liftPass scale

-- | 'Reader.ReaderT' delegates every weight operation to the inner monad.
-- Recording is lifted, and 'weigh' and 'scale' are applied under the
-- environment with 'Reader.mapReaderT'.
instance MonadWeighted w m => MonadWeighted w (Reader.ReaderT r m) where
    weight w   = lift (weight w)
    weighted p = lift (weighted p)
    weigh      = Reader.mapReaderT weigh
    scale      = Reader.mapReaderT scale

-- | Total weight of a computation.
--
-- 'weigh' pairs every outcome with its weight, and the 'Foldable' instance of
-- @m@ visits those pairs. The weights are then combined with the additive
-- 'Add' monoid and unwrapped.
collect :: (Foldable m, MonadWeighted w m) => m a -> w
collect xs = getAdd (foldMap (Add #. snd) (weigh xs))

infixr 9 #.
-- | "Composition" with a newtype wrapper or unwrapper at zero cost:
-- @f #. g@ is just @g@ reinterpreted via 'coerce'.
--
-- The left operand is never called. It only fixes the result type, so pass
-- only a newtype constructor or field accessor, for which 'coerce' is the
-- same function.
(#.) :: Coercible b c => (b -> c) -> (a -> b) -> a -> c
_ #. g = coerce g

-- | View a weighted computation as a covector, i.e. a function from
-- weightings of outcomes to a single weight.
--
-- Given @f@, every outcome @x@ of @xs@ is additionally weighed by @f x@.
-- The weights of all resulting outcomes are then summed with 'collect'.
toCovector :: (Foldable m, MonadWeighted w m) => m a -> (a -> w) -> w
toCovector xs f = collect $ do
    x <- xs
    weight (f x)