{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
module Control.Scanl (
Scan(..)
, ScanM(..)
, scan
, scanM
, prescan
, postscan
, purely
, purely_
, impurely
, impurely_
, generalize
, simplify
, hoists
, arrM
, premap
, premapM
) where
import Control.Applicative
import Control.Arrow
import Control.Category
import Control.Foldl (Fold(..))
import Control.Foldl.Internal (Pair(..))
import Control.Monad ((<=<))
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict
import Data.Functor.Identity
import Data.Monoid hiding ((<>))
import Data.Profunctor
import Data.Semigroup (Semigroup(..))
import Data.Traversable
import Data.Tuple (swap)
import Prelude hiding ((.), id)
-- | A pure scan from inputs of type @a@ to outputs of type @b@.
--
-- It packages a step function, which runs in the strict 'State' monad over
-- a hidden (existentially quantified) state type, together with the
-- starting value of that state.
data Scan a b = forall s. Scan (a -> State s b) s
-- | Maps a function over the value emitted at each step; the starting state
-- is passed through unchanged.
instance Functor (Scan a) where
    fmap g (Scan advance seed) = Scan (\a -> g <$> advance a) seed
    {-# INLINE fmap #-}
-- | 'pure' emits the same value for every input and carries @()@ as state.
-- '<*>' feeds each input to both scans, keeps their two states side by side
-- in a 'Pair', and applies the left output to the right output.
instance Applicative (Scan a) where
    pure r = Scan (const (pure r)) ()
    {-# INLINE pure #-}

    (Scan advanceF seedF) <*> (Scan advanceX seedX) =
        Scan (state . both) (Pair seedF seedX)
      where
        -- Run both step functions on the same input, each against its own
        -- half of the paired state.
        both a (Pair sF sX) =
            let (f, sF') = runState (advanceF a) sF
                (x, sX') = runState (advanceX a) sX
            in  (f x, Pair sF' sX')
    {-# INLINE (<*>) #-}
-- | 'lmap' pre-processes inputs (via 'premap'); 'rmap' post-processes
-- outputs (via 'fmap').
instance Profunctor Scan where
    lmap f s = premap f s
    rmap g s = fmap g s
-- | 'id' emits every input unchanged. Composition runs the right-hand scan
-- first and feeds its output into the left-hand scan; both states are kept
-- together in a 'Pair' (right-hand scan's state first).
instance Category Scan where
    id = Scan (\a -> pure a) ()
    {-# INLINE id #-}

    (Scan advanceOuter seedOuter) . (Scan advanceInner seedInner) =
        Scan (state . chained) (Pair seedInner seedOuter)
      where
        chained a (Pair sInner sOuter) =
            let (mid, sInner') = runState (advanceInner a) sInner
                (out, sOuter') = runState (advanceOuter mid) sOuter
            in  (out, Pair sInner' sOuter')
    {-# INLINE (.) #-}
-- | 'arr' lifts a pure function into a scan with @()@ state. 'first' and
-- 'second' run the scan on one component of an input pair and copy the
-- other component to the output untouched.
instance Arrow Scan where
    arr f = Scan (\a -> pure (f a)) ()
    {-# INLINE arr #-}

    first (Scan advance seed) = Scan onFirst seed
      where
        -- The lazy @let@ mirrors the lazy tuple handling of 'first' on
        -- plain functions.
        onFirst (a, extra) = state $ \s ->
            let (b, s') = runState (advance a) s
            in  ((b, extra), s')
    {-# INLINE first #-}

    second (Scan advance seed) = Scan onSecond seed
      where
        onSecond (extra, a) = state $ \s ->
            let (b, s') = runState (advance a) s
            in  ((extra, b), s')
    {-# INLINE second #-}
-- | Runs both scans on the same input and combines their outputs with '<>'.
instance Semigroup b => Semigroup (Scan a b) where
    l <> r = (<>) <$> l <*> r
    {-# INLINE (<>) #-}
-- | 'mempty' is the scan that always emits 'mempty'; 'mappend' combines the
-- outputs of two scans pointwise.
instance Monoid b => Monoid (Scan a b) where
    mempty = pure mempty
    {-# INLINE mempty #-}

    mappend l r = mappend <$> l <*> r
    {-# INLINE mappend #-}
-- | Arithmetic is applied pointwise to the outputs of the scans; an integer
-- literal becomes a scan that emits that constant.
instance Num b => Num (Scan a b) where
    fromInteger n = pure (fromInteger n)
    {-# INLINE fromInteger #-}

    negate s = negate <$> s
    {-# INLINE negate #-}

    abs s = abs <$> s
    {-# INLINE abs #-}

    signum s = signum <$> s
    {-# INLINE signum #-}

    l + r = (+) <$> l <*> r
    {-# INLINE (+) #-}

    l * r = (*) <$> l <*> r
    {-# INLINE (*) #-}

    l - r = (-) <$> l <*> r
    {-# INLINE (-) #-}
-- | Division and reciprocal are applied pointwise to scan outputs; a
-- rational literal becomes a scan that emits that constant.
instance Fractional b => Fractional (Scan a b) where
    fromRational q = pure (fromRational q)
    {-# INLINE fromRational #-}

    recip s = recip <$> s
    {-# INLINE recip #-}

    l / r = (/) <$> l <*> r
    {-# INLINE (/) #-}
-- | Every 'Floating' operation is applied pointwise to the scan outputs;
-- 'pi' is the scan that always emits 'pi'.
instance Floating b => Floating (Scan a b) where
    pi = pure pi
    {-# INLINE pi #-}

    exp s = exp <$> s
    {-# INLINE exp #-}

    sqrt s = sqrt <$> s
    {-# INLINE sqrt #-}

    log s = log <$> s
    {-# INLINE log #-}

    sin s = sin <$> s
    {-# INLINE sin #-}

    tan s = tan <$> s
    {-# INLINE tan #-}

    cos s = cos <$> s
    {-# INLINE cos #-}

    asin s = asin <$> s
    {-# INLINE asin #-}

    atan s = atan <$> s
    {-# INLINE atan #-}

    acos s = acos <$> s
    {-# INLINE acos #-}

    sinh s = sinh <$> s
    {-# INLINE sinh #-}

    tanh s = tanh <$> s
    {-# INLINE tanh #-}

    cosh s = cosh <$> s
    {-# INLINE cosh #-}

    asinh s = asinh <$> s
    {-# INLINE asinh #-}

    atanh s = atanh <$> s
    {-# INLINE atanh #-}

    acosh s = acosh <$> s
    {-# INLINE acosh #-}

    l ** r = (**) <$> l <*> r
    {-# INLINE (**) #-}

    logBase base s = logBase <$> base <*> s
    {-# INLINE logBase #-}
-- | An effectful scan from inputs of type @a@ to outputs of type @b@ in the
-- base monad @m@.
--
-- It packages a step function running in strict 'StateT' over a hidden
-- (existentially quantified) state type, together with an @m@-action that
-- produces the starting state.
data ScanM m a b = forall s. ScanM (a -> StateT s m b) (m s)
-- | Maps a function over the value emitted at each step; the action that
-- produces the starting state is passed through unchanged.
instance Functor m => Functor (ScanM m a) where
    fmap g (ScanM advance seed) = ScanM (\a -> g <$> advance a) seed
    {-# INLINE fmap #-}
-- | 'pure' emits the same value for every input and carries @()@ as state.
-- '<*>' feeds each input to both scans (left effects first), keeps their
-- states side by side in a 'Pair', and applies the left output to the right
-- output.
instance Applicative m => Applicative (ScanM m a) where
    pure r = ScanM (const (StateT (\() -> pure (r, ())))) (pure ())
    {-# INLINE pure #-}

    (ScanM advanceF seedF) <*> (ScanM advanceX seedX) =
        ScanM (StateT . both) (Pair <$> seedF <*> seedX)
      where
        both a (Pair sF sX) =
            merge <$> runStateT (advanceF a) sF <*> runStateT (advanceX a) sX

        -- Apply the left output to the right one and re-pair the new states.
        merge (f, sF') (x, sX') = (f x, Pair sF' sX')
    {-# INLINE (<*>) #-}
-- | 'lmap' applies a pure function to each input before the step runs;
-- 'rmap' post-processes outputs via 'fmap'.
instance Functor m => Profunctor (ScanM m) where
    lmap f (ScanM advance seed) = ScanM (\a -> advance (f a)) seed
    rmap g s = fmap g s
-- | 'id' emits every input unchanged. Composition runs the right-hand scan
-- first and feeds its output into the left-hand scan; both states are kept
-- together in a 'Pair' (right-hand scan's state first).
instance Monad m => Category (ScanM m) where
    id = ScanM (\a -> pure a) (pure ())
    {-# INLINE id #-}

    (ScanM advanceOuter seedOuter) . (ScanM advanceInner seedInner) =
        ScanM (StateT . chained) (Pair <$> seedInner <*> seedOuter)
      where
        chained a (Pair sInner sOuter) =
            runStateT (advanceInner a) sInner >>= \(mid, sInner') ->
            runStateT (advanceOuter mid) sOuter >>= \(out, sOuter') ->
            pure (out, Pair sInner' sOuter')
    {-# INLINE (.) #-}
-- | 'arr' lifts a pure function into a scan with @()@ state. 'first' and
-- 'second' run the scan on one component of an input pair and copy the
-- other component to the output untouched.
instance Monad m => Arrow (ScanM m) where
    arr f = ScanM (\a -> lift (pure (f a))) (pure ())
    {-# INLINE arr #-}

    first (ScanM advance seed) = ScanM onFirst seed
      where
        -- The inner 'first' is the one for plain functions: it tags the
        -- step's output inside the (output, state) tuple.
        onFirst (a, extra) = StateT $ \s ->
            first (,extra) <$> runStateT (advance a) s
    {-# INLINE first #-}

    second (ScanM advance seed) = ScanM onSecond seed
      where
        onSecond (extra, a) = StateT $ \s ->
            first (extra,) <$> runStateT (advance a) s
    {-# INLINE second #-}
-- | Runs both scans on the same input and combines their outputs with '<>'.
--
-- Only @Applicative m@ is required: the instance is defined entirely through
-- 'liftA2', which needs nothing more than the 'Applicative' instance of
-- 'ScanM'. Any @Monad m@ still satisfies this constraint.
instance (Applicative m, Semigroup b) => Semigroup (ScanM m a b) where
    (<>) = liftA2 (<>)
    {-# INLINE (<>) #-}

-- | 'mempty' is the scan that always emits 'mempty'; 'mappend' combines the
-- outputs of two scans pointwise.
--
-- Only @Applicative m@ is required, since the methods use just 'pure' and
-- 'liftA2'.
instance (Applicative m, Monoid b) => Monoid (ScanM m a b) where
    mempty = pure mempty
    {-# INLINE mempty #-}

    mappend = liftA2 mappend
    {-# INLINE mappend #-}
-- | Arithmetic is applied pointwise to the outputs of the scans; an integer
-- literal becomes a scan that emits that constant.
--
-- Only @Applicative m@ is required: every method is built from 'pure',
-- 'fmap' or 'liftA2'. Any @Monad m@ still satisfies this constraint.
instance (Applicative m, Num b) => Num (ScanM m a b) where
    fromInteger = pure . fromInteger
    {-# INLINE fromInteger #-}

    negate = fmap negate
    {-# INLINE negate #-}

    abs = fmap abs
    {-# INLINE abs #-}

    signum = fmap signum
    {-# INLINE signum #-}

    (+) = liftA2 (+)
    {-# INLINE (+) #-}

    (*) = liftA2 (*)
    {-# INLINE (*) #-}

    (-) = liftA2 (-)
    {-# INLINE (-) #-}

-- | Division and reciprocal are applied pointwise to scan outputs; a
-- rational literal becomes a scan that emits that constant.
instance (Applicative m, Fractional b) => Fractional (ScanM m a b) where
    fromRational = pure . fromRational
    {-# INLINE fromRational #-}

    recip = fmap recip
    {-# INLINE recip #-}

    (/) = liftA2 (/)
    {-# INLINE (/) #-}

-- | Every 'Floating' operation is applied pointwise to the scan outputs;
-- 'pi' is the scan that always emits 'pi'.
instance (Applicative m, Floating b) => Floating (ScanM m a b) where
    pi = pure pi
    {-# INLINE pi #-}

    exp = fmap exp
    {-# INLINE exp #-}

    sqrt = fmap sqrt
    {-# INLINE sqrt #-}

    log = fmap log
    {-# INLINE log #-}

    sin = fmap sin
    {-# INLINE sin #-}

    tan = fmap tan
    {-# INLINE tan #-}

    cos = fmap cos
    {-# INLINE cos #-}

    asin = fmap asin
    {-# INLINE asin #-}

    atan = fmap atan
    {-# INLINE atan #-}

    acos = fmap acos
    {-# INLINE acos #-}

    sinh = fmap sinh
    {-# INLINE sinh #-}

    tanh = fmap tanh
    {-# INLINE tanh #-}

    cosh = fmap cosh
    {-# INLINE cosh #-}

    asinh = fmap asinh
    {-# INLINE asinh #-}

    atanh = fmap atanh
    {-# INLINE atanh #-}

    acosh = fmap acosh
    {-# INLINE acosh #-}

    (**) = liftA2 (**)
    {-# INLINE (**) #-}

    logBase = liftA2 logBase
    {-# INLINE logBase #-}
-- | Runs a 'Scan' over a 'Traversable' container, threading the scan's state
-- from element to element and returning the container of outputs. The final
-- state is discarded.
scan :: Traversable t => Scan a b -> t a -> t b
scan (Scan advance seed) inputs = evalState (traverse advance inputs) seed
{-# INLINE scan #-}
-- | Runs a 'ScanM' over a 'Traversable' container in the monad @m@: first
-- the initial-state action, then the step for each element in traversal
-- order. Returns the container of outputs; the final state is discarded.
scanM :: (Traversable t, Monad m) => ScanM m a b -> t a -> m (t b)
scanM (ScanM advance seed) inputs =
    fst <$> (seed >>= runStateT (traverse advance inputs))
{-# INLINE scanM #-}
-- | Converts a 'Fold' into a 'Scan' whose output at each step is the fold's
-- result /before/ the current element has been incorporated (@done@ is
-- applied to the old accumulator).
prescan :: Fold a b -> Scan a b
prescan (Fold step begin done) =
    Scan (\a -> state (\acc -> (done acc, step acc a))) begin
{-# INLINE prescan #-}
-- | Converts a 'Fold' into a 'Scan' whose output at each step is the fold's
-- result /after/ the current element has been incorporated (@done@ is
-- applied to the updated accumulator).
postscan :: Fold a b -> Scan a b
postscan (Fold step begin done) = Scan (\a -> state (advance a)) begin
  where
    -- The updated accumulator is shared between the output and the new state.
    advance a acc =
        let acc' = step acc a
        in  (done acc', acc')
{-# INLINE postscan #-}
-- | Lifts a monadic function into a 'ScanM' that carries no state (@()@):
-- each input is simply passed through the function.
arrM :: Monad m => (b -> m c) -> ScanM m b c
arrM f = ScanM (\b -> lift (f b)) (pure ())
{-# INLINE arrM #-}
-- | Unpacks a 'Scan' and hands its step function and initial state to a
-- continuation that must work for any state type.
purely :: (forall x . (a -> State x b) -> x -> r) -> Scan a b -> r
purely k scanner = case scanner of
    Scan advance seed -> k advance seed
{-# INLINABLE purely #-}
-- | Like 'purely', but presents the step as a plain function
-- @state -> input -> (newState, output)@ instead of a 'State' action.
purely_ :: (forall x . (x -> a -> (x, b)) -> x -> r) -> Scan a b -> r
purely_ k (Scan advance seed) = k plainStep seed
  where
    -- 'runState' yields (output, state); 'swap' puts the state first.
    plainStep s a = swap (runState (advance a) s)
{-# INLINABLE purely_ #-}
-- | Unpacks a 'ScanM' and hands its step function and initial-state action
-- to a continuation that must work for any state type.
impurely
    :: (forall x . (a -> StateT x m b) -> m x -> r)
    -> ScanM m a b
    -> r
impurely k scanner = case scanner of
    ScanM advance seed -> k advance seed
{-# INLINABLE impurely #-}
-- | Like 'impurely', but presents the step as a plain monadic function
-- @state -> input -> m (newState, output)@ instead of a 'StateT' action.
--
-- 'runStateT' yields @(output, state)@, so the result is 'swap'ped to put
-- the state first.
impurely_
    :: Monad m
    => (forall x . (x -> a -> m (x, b)) -> m x -> r)
    -> ScanM m a b
    -> r
impurely_ f (ScanM step begin) = f (\s a -> swap <$> runStateT (step a) s) begin
-- Marked INLINABLE like 'purely', 'purely_' and 'impurely'.
{-# INLINABLE impurely_ #-}
-- | Turns a pure 'Scan' into a 'ScanM' over any monad: the scan is first
-- viewed as a 'ScanM' over 'Identity', then moved into @m@ with 'hoists' by
-- unwrapping each 'Identity' result and re-wrapping it with 'pure'.
generalize :: Monad m => Scan a b -> ScanM m a b
generalize (Scan advance seed) =
    hoists (pure . runIdentity) (ScanM advance (Identity seed))
{-# INLINABLE generalize #-}
-- | Turns a 'ScanM' over 'Identity' into a pure 'Scan' by unwrapping the
-- initial state; the step function is reused as is.
simplify :: ScanM Identity a b -> Scan a b
simplify (ScanM advance seed) = Scan advance (runIdentity seed)
{-# INLINABLE simplify #-}
-- | Changes the base monad of a 'ScanM' by applying the given transformation
-- to the result of every step and to the initial-state action.
hoists :: (forall x . m x -> n x) -> ScanM m a b -> ScanM n a b
hoists phi (ScanM advance seed) = ScanM advance' (phi seed)
  where
    advance' a = StateT (\s -> phi (runStateT (advance a) s))
{-# INLINABLE hoists #-}
-- | Applies a pure function to every input before it reaches the scan's
-- step; the initial state is unchanged.
premap :: (a -> b) -> Scan b r -> Scan a r
premap f (Scan advance seed) = Scan (\a -> advance (f a)) seed
{-# INLINABLE premap #-}
-- | Applies a monadic function to every input before it reaches the scan's
-- step: the function runs in the base monad, and its result is then fed to
-- the step. The initial-state action is unchanged.
premapM :: Monad m => (a -> m b) -> ScanM m b r -> ScanM m a r
premapM f (ScanM advance seed) = ScanM (\a -> lift (f a) >>= advance) seed
{-# INLINABLE premapM #-}