{-# LANGUAGE CPP #-}
-- |
-- Module      : Data.Vector.Fusion.Util
-- Copyright   : (c) Roman Leshchinskiy 2009
-- License     : BSD-style
--
-- Maintainer  : Roman Leshchinskiy <rl@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : portable
--
-- Fusion-related utility types
--

module Data.Vector.Fusion.Util (
  Id(..), Box(..),

  delay_inline, delayed_min
) where

#if !MIN_VERSION_base(4,8,0)
import Control.Applicative (Applicative(..))
#endif

-- | Identity monad.
--
-- A @newtype@ wrapper, so it has the same runtime representation as the
-- wrapped value. Its 'Monad' instance just applies functions; use 'unId'
-- to get the value back out.
newtype Id a = Id { unId :: a }

-- Mapping over 'Id' applies the function directly to the wrapped value.
instance Functor Id where
  fmap f (Id x) = Id (f x)

-- 'pure' is the constructor itself; '<*>' unwraps both sides and applies.
instance Applicative Id where
  pure = Id
  Id f <*> Id x = Id (f x)

-- Binding unwraps the value and hands it straight to the continuation.
instance Monad Id where
  -- Defined explicitly because on base < 4.8 (see the CPP import above)
  -- 'return' had no default; 'return = pure' is the canonical form.
  return = pure
  Id x >>= f = f x

-- | Box monad.
--
-- Unlike 'Id', this is a @data@ type with a lazy field. Pattern matching
-- on 'Box' forces the box itself but leaves the wrapped value unevaluated.
-- NOTE(review): the lazy field is presumably deliberate (it controls
-- evaluation in fusion code); do not add a strictness annotation without
-- checking callers.
data Box a = Box { unBox :: a }

-- Mapping over 'Box' matches the box and rewraps the mapped value.
instance Functor Box where
  fmap f (Box x) = Box (f x)

-- 'pure' is the constructor itself; '<*>' matches both boxes and applies.
instance Applicative Box where
  pure = Box
  Box f <*> Box x = Box (f x)

-- Binding matches the 'Box' constructor, which forces the box (but not its
-- contents), then passes the wrapped value to the continuation.
instance Monad Box where
  -- Defined explicitly because on base < 4.8 (see the CPP import above)
  -- 'return' had no default; 'return = pure' is the canonical form.
  return = pure
  Box x >>= f = f x

-- | Apply a function, but do not let GHC inline this wrapper until the
-- last simplifier phase (phase 0).
--
-- Semantically @delay_inline f = f@. The phase-controlled @INLINE@ keeps
-- the call opaque to earlier phases, such as those where rewrite rules fire.
delay_inline :: (a -> b) -> a -> b
{-# INLINE [0] delay_inline #-}
-- The argument is kept on the left-hand side so that GHC inlines only
-- applications that supply it.
delay_inline f = f

-- | 'min' on 'Int', kept opaque until simplifier phase 0.
--
-- Returns the smaller of the two arguments. The phase-controlled @INLINE@
-- stops earlier phases from seeing through to 'min'.
delayed_min :: Int -> Int -> Int
{-# INLINE [0] delayed_min #-}
-- Both arguments are kept on the left-hand side (rather than
-- @delayed_min = min@) so that GHC inlines only saturated calls.
delayed_min m n = min m n