{-# LANGUAGE FlexibleInstances #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Arrow.Transformer.Writer
-- Copyright   :  (c) Ross Paterson 2003
-- License     :  BSD-style (see the LICENSE file in the distribution)
--
-- Maintainer  :  ross@soi.city.ac.uk
-- Stability   :  experimental
-- Portability :  non-portable (multi-parameter type classes)
--
-- Arrow transformer that adds accumulation of output.

module Control.Arrow.Transformer.Writer(
		WriterArrow(WriterArrow),
		runWriter,
		ArrowAddWriter(..),
	) where

import Control.Arrow.Internals
import Control.Arrow.Operations
import Control.Arrow.Transformer

import Control.Applicative
import Control.Arrow
import Control.Category
import Data.Monoid
import Data.Semigroup (Semigroup(..))

import Prelude hiding (id,(.))

-- | An arrow type that augments an existing arrow with accumulating
-- output.  The 'ArrowWriter' class contains the relevant operations.
--
-- A value of type @'WriterArrow' w a b c@ is an arrow from @b@ to @c@ built
-- on the underlying arrow @a@.  Its representation is an @a b (c, w)@: each
-- result is paired with the output @w@ produced while computing it.

newtype WriterArrow w a b c = WriterArrow (a b (c, w))

-- | Encapsulation of a writer computation, providing the accumulated output.
--
-- Strips the 'WriterArrow' wrapper, exposing the underlying arrow whose
-- result is paired with everything the computation wrote.
--
-- Typical usage in arrow notation:
--
-- >	proc p -> do
-- >		...
-- >		(result, output) <- (|runWriter cmd|)

runWriter :: (Arrow a, Monoid w) => WriterArrow w a e b -> a e (b,w)
runWriter wrapped = case wrapped of
	WriterArrow underlying -> underlying

-- | Move the output out of the left component of a pair, so that it becomes
-- the outermost second component: @((a, w), b)@ becomes @((a, b), w)@.
rstrength :: ((a, w), b) -> ((a, b), w)
rstrength pair = case pair of
	((value, out), other) -> ((value, other), out)

-- | Pair a value with the empty output 'mempty', i.e. a result that wrote
-- nothing.
unit :: Monoid w => a -> (a, w)
unit = \value -> (value, mempty)

-- | Flatten a result carrying two outputs into one.  The outer output goes
-- on the left of 'mappend' and the inner output on the right.  In this
-- module the outer output is the one produced first (e.g. by the first
-- arrow of a composition).
join :: Monoid w => ((a, w), w) -> (a, w)
join nested = case nested of
	((value, inner), outer) -> (value, mappend outer inner)

-- arrow transformer

instance (Arrow a, Monoid w) => ArrowTransformer (WriterArrow w a) a where
	-- A lifted arrow runs unchanged and writes nothing: its result is
	-- paired with 'mempty'.
	lift f = WriterArrow (arr unit . f)

-- liftings of standard classes

instance (Arrow a, Monoid w) => Category (WriterArrow w a) where
	-- The identity returns its input and writes nothing.
	id = WriterArrow (arr unit)
	-- Composition runs @inner@ first, then @outer@ on its result.  'join'
	-- concatenates the two outputs, with @inner@'s output on the left.
	WriterArrow outer . WriterArrow inner =
		WriterArrow ((inner >>> first outer) >>> arr join)

instance (Arrow a, Monoid w) => Arrow (WriterArrow w a) where
	-- Pure functions write nothing.
	arr f = WriterArrow (arr (\x -> unit (f x)))
	-- Run the wrapped arrow on the first component, then move its output
	-- outside the pair so it remains the outermost second component.
	first (WriterArrow f) = WriterArrow (arr rstrength . first f)

instance (ArrowChoice a, Monoid w) => ArrowChoice (WriterArrow w a) where
	-- Only 'Left' inputs reach the wrapped arrow, so only they can carry
	-- output.  'Right' inputs pass through with empty output.
	left (WriterArrow f) = WriterArrow (arr pullOut . left f)
		where
			pullOut tagged = case tagged of
				Left (x, w) -> (Left x, w)
				Right y -> unit (Right y)

instance (ArrowApply a, Monoid w) => ArrowApply (WriterArrow w a) where
	-- Unwrap the supplied writer arrow and apply the underlying arrow.
	-- Its result already carries the output, so nothing more is needed.
	app = WriterArrow (app . arr unwrap)
		where
			unwrap (WriterArrow f, x) = (f, x)

-- | Wraps the underlying arrow's 'zeroArrow' unchanged.
instance (ArrowZero a, Monoid w) => ArrowZero (WriterArrow w a) where
	zeroArrow = WriterArrow zeroArrow

instance (ArrowPlus a, Monoid w) => ArrowPlus (WriterArrow w a) where
	-- Combine the underlying arrows with their own '<+>'.  The output
	-- travels inside each alternative's result, so no merging is done here.
	(<+>) (WriterArrow p) (WriterArrow q) = WriterArrow ((<+>) p q)

-- | Feedback is delegated to the underlying 'loop'.  The wrapped arrow
-- returns @((c, d), w)@.  'swapenv' rearranges this to @((c, w), d)@, so the
-- fed-back component @d@ sits where the underlying 'loop' expects it and
-- the output @w@ stays with the visible result @c@.
instance (ArrowLoop a, Monoid w) => ArrowLoop (WriterArrow w a) where
	loop (WriterArrow f) = WriterArrow (loop (f >>> arr swapenv))
		-- Both patterns are irrefutable, so 'swapenv' builds its result
		-- without forcing its argument.  The fed-back @y@ is typically
		-- defined in terms of that result, and a strict match could make
		-- the feedback diverge.
		where	swapenv ~(~(x, y), w) = ((x, w), y)

-- Other instances

instance (Arrow a, Monoid w) => Functor (WriterArrow w a b) where
	-- Post-compose a pure function.  Pure functions write nothing, so the
	-- output is unaffected.
	fmap f g = arr f . g

instance (Arrow a, Monoid w) => Applicative (WriterArrow w a b) where
	-- A constant arrow: ignores its input and writes nothing.
	pure = arr . const
	-- Run both arrows on the same input, then apply the first result
	-- (a function) to the second.
	f <*> g = arr (uncurry ($)) . (f &&& g)

instance (ArrowPlus a, Monoid w) => Alternative (WriterArrow w a b) where
	-- Failure and choice come straight from 'ArrowZero' and 'ArrowPlus'.
	empty = zeroArrow
	(<|>) = (<+>)

-- | Writer arrows form a semigroup under '<+>', combining alternatives of
-- the underlying 'ArrowPlus'.  This instance is required because, since
-- base 4.11 (GHC 8.4), 'Semigroup' is a superclass of 'Monoid'.
instance (ArrowPlus a, Monoid w) => Semigroup (WriterArrow w a b c) where
	(<>) = (<+>)

-- | The monoid of alternatives: 'mempty' is 'zeroArrow' and 'mappend' is
-- '<+>'.  'mappend' is bound to '<+>' rather than '<>' because on base
-- 4.9/4.10 an unqualified '<>' would be ambiguous with "Data.Monoid"'s.
instance (ArrowPlus a, Monoid w) => Monoid (WriterArrow w a b c) where
	mempty = zeroArrow
	mappend = (<+>)

-- new instances

instance (Arrow a, Monoid w) => ArrowWriter w (WriterArrow w a) where
	-- Emit the input as output and return ().
	write = WriterArrow (arr ((,) ()))
	-- Run the wrapped arrow and also return what it wrote.  That output
	-- is still passed on as this arrow's own output.
	newWriter (WriterArrow f) = WriterArrow (arr tee . f)
		where
			tee (x, w) = ((x, w), w)

instance (Arrow a, Monoid w) => ArrowAddWriter w (WriterArrow w a) a where
	-- Embedding an underlying arrow is the transformer's 'lift'.
	liftWriter inner = lift inner
	-- Eliminating the writer just unwraps it, exposing the
	-- (result, output) pair.
	elimWriter (WriterArrow body) = body

-- liftings of other classes

instance (ArrowCircuit a, Monoid w) => ArrowCircuit (WriterArrow w a) where
	-- Delay in the underlying arrow.  The lifted delay writes nothing.
	delay = lift . delay

instance (ArrowError ex a, Monoid w) => ArrowError ex (WriterArrow w a) where
	-- Raising is lifted from the underlying arrow.
	raise = lift raise
	-- Body and handler already produce (result, output) pairs, so the
	-- underlying 'handle' is used directly.
	handle (WriterArrow body) (WriterArrow handler) =
		WriterArrow (handle body handler)
	-- The success continuation receives the body's (result, output) pair
	-- nested in its input.  'reassoc' gives it a plain (input, result)
	-- instead.  'join' then puts the body's output before the
	-- continuation's own output.
	tryInUnless (WriterArrow body) (WriterArrow onOk) (WriterArrow onErr) =
		WriterArrow (tryInUnless body onOk' onErr)
		where
			onOk' = (arr join . first onOk) . arr reassoc
			reassoc (x, (y, w)) = ((x, y), w)
	-- An error carries empty output; a success keeps its output.
	newError (WriterArrow body) = WriterArrow (arr tagged . newError body)
		where
			tagged outcome = case outcome of
				Left ex -> unit (Left ex)
				Right (c, w) -> (Right c, w)

instance (ArrowReader r a, Monoid w) => ArrowReader r (WriterArrow w a) where
	-- Reading the environment is lifted, so it writes nothing.
	readState = lift readState
	-- Delegate to the underlying 'newReader'.  The output stays inside the
	-- wrapped arrow's result.
	newReader wrapped = case wrapped of
		WriterArrow body -> WriterArrow (newReader body)

-- | State operations are those of the underlying arrow, lifted so that they
-- write nothing.
instance (ArrowState s a, Monoid w) => ArrowState s (WriterArrow w a) where
	fetch = lift fetch
	store = lift store

-- promotions of encapsulation operators

instance (ArrowAddError ex a a', Monoid w) =>
		ArrowAddError ex (WriterArrow w a) (WriterArrow w a') where
	-- Both operations delegate to the underlying arrows.  The accumulated
	-- output travels untouched inside their results.
	liftError (WriterArrow inner) = WriterArrow $ liftError inner
	elimError (WriterArrow body) (WriterArrow handler) =
		WriterArrow $ elimError body handler

instance (ArrowAddReader r a a', Monoid w) =>
		ArrowAddReader r (WriterArrow w a) (WriterArrow w a') where
	-- Both operations delegate to the underlying arrows.  The accumulated
	-- output travels untouched inside their results.
	liftReader (WriterArrow inner) = WriterArrow $ liftReader inner
	elimReader (WriterArrow body) = WriterArrow $ elimReader body

instance (ArrowAddState s a a', Monoid w) =>
		ArrowAddState s (WriterArrow w a) (WriterArrow w a') where
	-- Lifting delegates directly to the underlying 'liftState'.
	liftState (WriterArrow inner) = WriterArrow (liftState inner)
	-- The underlying 'elimState' yields @((b, w), s)@.  'rstrength' moves
	-- the output outward, giving @((b, s), w)@ as a writer result.
	elimState (WriterArrow body) = WriterArrow (arr rstrength . elimState body)