{-# OPTIONS -fplugin=Rattus.Plugin #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE CPP #-}
-- | Programming with streams.

module Rattus.Stream
  ( map
  , hd
  , tl
  , const
  , constBox
  , shift
  , shiftMany
  , scan
  , scanMap
  , scanMap2
  , Str(..)
  , zipWith
  , zip
  , unfold
  , filter
  , integral
  )

where

import Rattus
import Prelude hiding (map, const, zipWith, zip, filter)

import Data.VectorSpace

-- | @Str a@ is a stream of values of type @a@. A stream @x ::: xs@
-- consists of a head @x@ and a tail @xs@ of type @O (Str a)@; both
-- constructor fields are strict.
data Str a = !a ::: !(O (Str a))

-- All functions in this module are Rattus functions.
{-# ANN module Rattus #-}

-- | Get the first element (= head) of a stream.
--
-- > hd (x ::: xs) == x
hd :: Str a -> a
hd (x ::: _) = x


-- | Get the tail of a stream, i.e. the remainder after removing the
-- first element. The result has type @O (Str a)@, exactly as it is
-- stored in the second field of ':::'.
tl :: Str a -> O (Str a)
tl (_ ::: xs) = xs

-- | Apply a function to each element of a stream. The function is
-- passed boxed so that it can be used again in the delayed recursive
-- call that processes the tail.
map :: Box (a -> b) -> Str a -> Str b
map f (x ::: xs) = unbox f x ::: delay (map f (adv xs))


-- | Construct a stream that has the same given value at each step.
-- The value is reused inside the delayed tail, hence the 'Stable'
-- constraint; see 'constBox' for a variant without that constraint.
const :: Stable a => a -> Str a
const a = a ::: delay (const a)

-- | Variant of 'const' that allows any type @a@ as argument as long
-- as it is boxed. The box is unboxed anew at every step to produce
-- the head of the stream.
constBox :: Box a -> Str a
constBox a = unbox a ::: delay (constBox a)

-- | Construct a stream by repeatedly applying a function to a given
-- start element. That is, @unfold (box f) x@ will produce the stream
-- @x ::: f x ::: f (f x) ::: ...@
unfold :: Stable a => Box (a -> a) -> a -> Str a
unfold f x = x ::: delay (unfold f (unbox f x))

-- | Similar to Haskell's 'scanl'.
--
-- > scan (box f) x (v1 ::: v2 ::: v3 ::: ... ) == (x `f` v1) ::: ((x `f` v1) `f` v2) ::: ...
--
-- Note: Unlike 'scanl', 'scan' starts with @x `f` v1@, not @x@.
scan :: (Stable b) => Box(b -> a -> b) -> b -> Str a -> Str b
scan f acc (a ::: as) =  acc' ::: delay (scan f acc' (adv as))
  where -- the updated accumulator is both the current output and the
        -- start value for the next step
        acc' = unbox f acc a

-- | 'scanMap' is a composition of 'map' and 'scan':
--
-- > scanMap f g x === map g . scan f x
--
-- The accumulator itself is threaded through unchanged; the second
-- function is only applied to produce the output at each step.
scanMap :: (Stable b) => Box(b -> a -> b) -> Box (b -> c) -> b -> Str a -> Str c
scanMap f p acc (a ::: as) =  unbox p acc' ::: delay (scanMap f p acc' (adv as))
  where acc' = unbox f acc a


-- | 'scanMap2' is similar to 'scanMap' but takes two input streams.
-- At each step the accumulator is updated with the current heads of
-- both streams, and the projection of the updated accumulator is
-- emitted.
scanMap2 :: (Stable b) => Box(b -> a1 -> a2 -> b) -> Box (b -> c) -> b -> Str a1 -> Str a2 -> Str c
scanMap2 f p acc (a1 ::: as1) (a2 ::: as2) =
    unbox p acc' ::: delay (scanMap2 f p acc' (adv as1) (adv as2))
  where acc' = unbox f acc a1 a2

-- | Similar to 'Prelude.zipWith' on Haskell lists: combine the two
-- streams element by element using the given (boxed) function.
zipWith :: Box(a -> b -> c) -> Str a -> Str b -> Str c
zipWith f (a ::: as) (b ::: bs) = unbox f a b ::: delay (zipWith f (adv as) (adv bs))

-- | Similar to 'Prelude.zip' on Haskell lists: pair up the elements
-- of the two streams step by step, using the pair constructor ':*'.
zip :: Str a -> Str b -> Str (a:*b)
zip (a ::: as) (b ::: bs) =  (a :* b) ::: delay (zip (adv as) (adv bs))


-- | Filter out elements from a stream according to a predicate.
-- The resulting stream has an element at every step: @Just' a@ if
-- the input element @a@ satisfies the predicate and @Nothing'@
-- otherwise.
filter :: Box(a -> Bool) -> Str a -> Str(Maybe' a)
filter p = map (box (\a -> if unbox p a then Just' a else Nothing'))

{-| Given a value @a@ and a stream @as@, this function produces a
  stream that behaves like @as@ delayed by one step, with @a@ as its
  first element:

  > shift a (x1 ::: x2 ::: x3 ::: ...) == a ::: x1 ::: x2 ::: ...
-}
shift :: Stable a => a -> Str a -> Str a
shift a (x ::: xs) = a ::: delay (shift x (adv xs))


{-| Given a list @[a1, ..., an]@ of elements and a stream @xs@ this
  function constructs a stream that starts with the elements @a1, ...,
  an@, and then proceeds as @xs@. In particular, this means that the
  ith element of the original stream @xs@ is the (i+n)th element of
  the new stream. In other words @shiftMany@ behaves like repeatedly
  applying @shift@ for each element in the list. -}
shiftMany :: Stable a => List a -> Str a -> Str a
shiftMany l xs = run l Nil xs where
  -- @run out buf s@ emits the elements of @out@ one per step while
  -- collecting the elements read from @s@ in @buf@ (newest first).
  run :: Stable a => List a -> List a -> Str a -> Str a
  run (b :! bs) buf (x ::: xs) = b ::: delay (run bs (x :! buf) (adv xs))
  run Nil buf (x ::: xs) =
    -- The output list is exhausted: continue with the buffered
    -- elements in the order they were read, and start a fresh buffer.
    case reverse' buf of
      b :! bs -> b ::: delay (run bs (x :! Nil) (adv xs))
      -- Nothing is buffered, so the input stream is passed through.
      Nil -> x ::: xs
    
-- | Calculates an approximation of an integral of the stream of type
-- @Str a@ (the y-axis), where the stream of type @Str s@ provides the
-- distance between measurements (i.e. the distance along the x-axis).
--
-- The first argument is the start value of the integral. At each
-- step, the current sample scaled by the current distance is added
-- to the accumulated value, and the new accumulated value is emitted.
integral :: (Stable a, VectorSpace a s) => a -> Str s -> Str a -> Str a
integral acc (t ::: ts) (a ::: as) = acc' ::: delay (integral acc' (adv ts) (adv as))
  where acc' = acc ^+^ (t *^ a)


-- Prevent functions from being inlined too early for the rewrite
-- rules to fire. The rewrite rules in this module match on
-- applications of these functions, so the functions must not be
-- inlined before the rules have had a chance to apply; the phase
-- annotation @[1]@ delays inlining until simplifier phase 1.

{-# NOINLINE [1] map #-}
{-# NOINLINE [1] const #-}
{-# NOINLINE [1] constBox #-}
{-# NOINLINE [1] scan #-}
{-# NOINLINE [1] scanMap #-}
{-# NOINLINE [1] zip #-}


-- Rewrite rules that fuse compositions of stream functions into a
-- single traversal:
--
-- * "const/map": mapping over a constant stream is the constant
--   stream of the mapped value.
--
-- * "map/map": two consecutive maps become one map of the composed
--   function.
--
-- * "map/scan": a map applied to a scan becomes a single 'scanMap'.
--
-- * "zip/map": a map applied to a zip becomes a single 'zipWith'.
{-# RULES

  "const/map" forall (f :: Stable b => Box (a -> b))  x.
    map f (const x) = let x' = unbox f x in const x' ;

  "map/map" forall f g xs.
    map f (map g xs) = map (box (unbox f . unbox g)) xs ;

  "map/scan" forall f p acc as.
    map p (scan f acc as) = scanMap f p acc as ;

  "zip/map" forall xs ys f.
    map f (zip xs ys) = let f' = unbox f in zipWith (box (\ x y -> f' (x :* y))) xs ys
#-}


-- Rewrite rules that fuse two consecutive scans into a single
-- 'scanMap'. The fused scan keeps the accumulators of both scans in a
-- strict pair @(b :* c)@ and projects out the second component, which
-- is the output of the outer scan.
--
-- * "scan/scan": the inner accumulator is advanced with @f@, and the
--   new inner accumulator is fed to @g@ to advance the outer one.
--
-- * "scan/scanMap": like "scan/scan", except that the inner 'scanMap'
--   emits @p@ applied to its new accumulator. Therefore the inner
--   accumulator is advanced with @f@ on the raw accumulator, and it
--   is the projected value @p' b'@ that is fed to @g@.
--
-- NOTE(review): these rules are only enabled for GHC >= 8.8; the
-- reason for the version guard is not visible here.
#if __GLASGOW_HASKELL__ >= 808
{-# RULES
  "scan/scan" forall f g b c as.
    scan g c (scan f b as) =
      let f' = unbox f; g' = unbox g in
      scanMap (box (\ (b:*c) a -> let b' = f' b a in (b':* g' c b'))) (box snd') (b:*c) as ;

  "scan/scanMap" forall f g p b c as.
    scan g c (scanMap f p b as) =
      let f' = unbox f; g' = unbox g; p' = unbox p in
      scanMap (box (\ (b:*c) a -> let b' = f' b a in (b':* g' c (p' b')))) (box snd') (b:*c) as ;

#-}
#endif