{-# LANGUAGE CPP, ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Parallel.Eden.Topology
-- Copyright   :  (c) Philipps Universitaet Marburg 2009-2014
-- License     :  BSD-style (see the file LICENSE)
-- 
-- Maintainer  :  eden@mathematik.uni-marburg.de
-- Stability   :  beta
-- Portability :  not portable
--
-- This Haskell module defines topology skeletons for the parallel functional
-- language Eden. Topology skeletons are skeletons that implement a network of
-- processes interconnected by a characteristic communication topology.
--
-- Depends on GHC. Using standard GHC, you will get a threaded simulation of Eden. 
-- Use the forked GHC-Eden compiler from http:\/\/www.mathematik.uni-marburg.de/~eden 
-- for a parallel build.
--
-- Eden Group ( http:\/\/www.mathematik.uni-marburg.de/~eden )


module Control.Parallel.Eden.Topology (
  -- * Skeletons that are primarily characterized by their topology.
    
  -- ** Pipeline skeletons
  -- |
  pipe, pipeRD
  -- ** Ring skeletons
  -- |
  ,ringSimple, ring, ringFl, ringAt, ringFlAt
  -- ** Torus skeleton
  -- |  
  ,torus 
  -- ** The Hypercube skeleton
  -- |

    -- ** The All-To-All skeleton 
  -- |The allToAll skeleton allows distributed data exchange and
  -- transformation involving the data of all processes. Input and output
  -- are provided as remote data. A typical application is the
  -- distributed transposition of a distributed matrix.
  ,allToAllRDAt, allToAllRD, parTransposeRDAt, parTransposeRD, allGatherRDAt, allGatherRD     
  -- ** The All-Reduce skeleton   
  -- |The skeleton uses a butterfly topology to reduce the data of the
  -- participating processes P in ceiling(log2 |P|) communication stages.
  -- Input and output are provided as remote data.
  ,allReduceRDAt, allReduceRD, allGatherBuFlyRDAt, allGatherBuFlyRD  

  ) where
#if defined( __PARALLEL_HASKELL__ ) || defined (NOT_PARALLEL)
import Control.Parallel.Eden
#else
import Control.Parallel.Eden.EdenConcHs
#endif
import Control.Parallel.Eden.Auxiliary
import Control.Parallel.Eden.Map
import Data.List

   
   
-- |Simple process pipeline: every function of the list runs in its own
-- process, and the output of each stage is the input of the next one. The
-- caller passes in a plain value and receives a plain value. The skeleton is
-- built on top of 'pipeRD', so between the stages only remote-data ('RD')
-- handles pass through the caller process.
pipe :: forall a . Trans a => 
        [a -> a]    -- ^functions of the pipe
        -> a        -- ^input
        -> a        -- ^output
pipe = unLiftRD . pipeRD
  
-- |Process pipeline on remote data. One process is created per function.
-- The processes exchange their remote-data ('RD') handles via the caller
-- process, but fetch the actual data directly from their predecessor.
--
-- An empty function list yields the identity pipeline: the input handle is
-- returned unchanged, instead of failing on @last []@.
pipeRD :: forall a . Trans a => 
          [a -> a]    -- ^functions of the pipe
          -> RD a     -- ^remote input
          -> RD a     -- ^remote output
pipeRD [] xs = xs
pipeRD fs xs = last outs where
  -- Stage k is fed the handle produced by stage k-1; the first stage is fed
  -- the pipeline input. The list is self-referential, so it is passed
  -- through 'lazy' (presumably to avoid forcing it before the stages exist).
  outs = spawn ps $ lazy $ xs : outs
  ps :: [Process (RD a) (RD a)]
  ps = map (process . liftRD) fs


-- | Simple ring skeleton (tutorial version). One process is created per
-- element of the input list. The ring connections between the processes use
-- remote data, so ring values travel directly between neighbours. There is
-- no input distribution and no output combination.
ringSimple      :: (Trans i, Trans o, Trans r) =>
               (i -> r -> (o,r))  -- ^ ring process function
               -> [i] -> [o]      -- ^ input output mapping
ringSimple f is = outputs
  where
    node                   = toRD (uncurry f)
    -- process k is fed the ring output of process k-1 (process 0 is fed
    -- that of the last process)
    ringInputs             = rightRotate ringOutputs
    (outputs, ringOutputs) =
      unzip (parMap node (zip is (lazy ringInputs)))

-- | Lifts a ring process function on plain values to one on remote data.
-- The incoming ring value is fetched and the outgoing one is released; the
-- process's own input and output are passed unchanged.
toRD :: (Trans i, Trans o, Trans r) =>
        ((i,r) -> (o,r))          -- ^ ring process function
        -> ((i, RD r) -> (o, RD r)) -- ^ the same function on remote data
toRD f = \ (i, rdIn) ->
  let (o, next) = f (i, fetch rdIn)
  in  (o, release next)

-- | Rotates a list by one position to the right: the last element moves to
-- the front. The empty list is returned unchanged.
rightRotate    :: [a] -> [a]
rightRotate ys
  | null ys   = []
  | otherwise = last ys : init ys

-- | @ringFlAt@ establishes a ring topology with explicit placement.
--
-- The input is split by @distrib@, and one ring process is created per
-- resulting element. The list of process results is merged by @combine@.
-- Each ring process function maps the process's initial input and the value
-- received from the ring to the process's result and the value passed on
-- along the ring. Process k is fed by process k-1, and process 0 by the
-- last process.
--
-- The function list is repeated cyclically (@cycle fs@), so individual
-- functions can be given to the processes, e.g. to route individual offline
-- input into the ring. An empty function list makes @cycle@ fail.
ringFlAt :: (Trans a,Trans b,Trans r) =>
        Places                     -- ^where to put workers
        -> (i -> [a])              -- ^distribute input
        -> ([b] -> o)              -- ^combine output
        -> [(a -> r -> (b,r))]     -- ^ring process fcts
        -> i                       -- ^ring input
        -> o                       -- ^ring output
ringFlAt places distrib combine fs i = combine results
  where
    nodeFs              = map (toRD . uncurry) (cycle fs)
    procInputs          = zip (distrib i) (lazy ringIns)
    (results, ringOuts) = unzip (spawnFAt places nodeFs procInputs)
    ringIns             = rightRotate ringOuts

-- | @ringFl@ establishes a ring topology. It is 'ringFlAt' with the
-- placement list @[0]@.
--
-- Each ring process function transforms the process's initial input and the
-- value received from the ring into the process's result and the value
-- passed on along the ring. The function list is used cyclically, which
-- allows e.g. routing individual offline input into the ring processes. Use
-- 'ringFlAt' if explicit placement is desired.
ringFl  :: (Trans a,Trans b,Trans r) =>
        (i -> [a])                 -- ^distribute input
        -> ([b] -> o)              -- ^combine output
        -> [(a -> r -> (b,r))]     -- ^ring process fcts
        -> i                       -- ^ring input
        -> o                       -- ^ring output
ringFl distrib combine fs i = ringFlAt [0] distrib combine fs i

-- | Skeleton @ringAt@ establishes a ring topology with explicit placement,
-- where every ring process runs the same function.
--
-- The function transforms the process's initial input and the value received
-- from the ring into the process's result and the value passed on along the
-- ring. It is 'ringFlAt' with a one-element function list. Use 'ringFlAt' if
-- the processes need different functions.
ringAt :: (Trans a,Trans b,Trans r) =>
        Places                    -- ^where to put workers
        -> (i -> [a])             -- ^distribute input
        -> ([b] -> o)             -- ^combine output
        -> (a -> r -> (b,r))      -- ^ring process fct
        -> i                      -- ^ring input
        -> o                      -- ^ring output 
ringAt places distrib combine f = ringFlAt places distrib combine [f]

-- | @ring@ establishes a ring topology where every ring process runs the same
-- function. It is 'ringAt' with the placement list @[0]@.
--
-- The function transforms the process's initial input and the value received
-- from the ring into the process's result and the value passed on along the
-- ring. Use 'ringFl' if the processes need different functions, and 'ringAt'
-- if explicit placement is desired.
ring :: (Trans a,Trans b,Trans r) =>
        (i -> [a])                -- ^distribute input
        -> ([b] -> o)             -- ^combine output
        -> (a -> r -> (b,r))      -- ^ring process fct
        -> i                      -- ^ring input
        -> o                      -- ^ring output
ring distrib combine f i = ringAt [0] distrib combine f i



-- | Parallel torus skeleton (tutorial version) with stream rotation in two
-- directions. One process is created per element of the input matrix.
--
-- The node function maps an element's initial input and the streams arriving
-- from both directions to the element's result and the streams sent on in
-- both directions:
--
-- * A-streams rotate within each row: element (k,j) is fed by (k,j-1).
-- * B-streams rotate across the rows: element (k,j) is fed by (k-1,j).
--
-- Both directions wrap around. The input matrix is expected to have the same
-- size in both dimensions.
torus :: (Trans a, Trans b, Trans c, Trans d) =>
         (c -> [a] -> [b] -> (d,[a],[b])) -- ^ node function
         -> [[c]] -> [[d]]                -- ^ input-output mapping
torus f inss = results
  where
    procOuts                = spawnPss (repeat (repeat (ptorus f))) procIns
    (results, outsA, outsB) = unzip3 (map unzip3 procOuts)
    insA                    = map rightRotate outsA
    insB                    = rightRotate outsB
    procIns                 = zipWith3 zipRow inss (lazy insA) (lazy insB)
    zipRow cs as bs         = zip3 cs (lazy as) (lazy bs)

-- A single torus node process (tutorial version). It fetches the two
-- incoming remote streams and applies the node function. It then releases
-- the two outgoing streams; the value for the parent is returned directly.
ptorus :: (Trans a, Trans b, Trans c, Trans d) =>
          (c -> [a] -> [b] -> (d,[a],[b])) ->
          Process (c,RD [a],RD [b])
                  (d,RD [a],RD [b])
ptorus f = process node
  where
    node (fromParent, inA, inB) = (toParent, release outA, release outB)
      where
        (inA', inB')           = fetch2 inA inB
        (toParent, outA, outB) = f fromParent inA' inB'

-- | All-to-all skeleton, i.e. 'allToAllRDAt' with the placement list @[0]@.
--
-- One process is created per element of the input list (@np@ processes).
-- The first function turns each process input into @np@ intermediate
-- values, and the @i@-th of these is sent to process @i@. The second
-- function combines the process input with the @np@ values the process
-- received, yielding the process output.
allToAllRD :: forall a b i. (Trans a, Trans b, Trans i) 
                => (Int -> a -> [i]) -- ^transform before bcast (num procs, input, sync-data out)
                -> (a -> [i] ->b)    -- ^transform after bcast (input, sync-data in, output)
                -> [RD a]            -- ^remote input for each process
                -> [RD b]            -- ^remote output for each process
allToAllRD t1 t2 = allToAllRDAt [0] t1 t2

-- | All-to-all skeleton with explicit placement.
--
-- One process is created per element of the input list (@np@ processes).
-- Each process fetches its input and uses the first function to turn it into
-- intermediate values; the @j@-th value is sent to process @j@, so the
-- function is expected to return @np@ values. Each process then uses the
-- second function to combine its input with the values it received, which
-- are ordered by sender, into its output. Inputs and outputs are remote data.
allToAllRDAt :: forall a b i. (Trans a, Trans b, Trans i) 
                => Places            -- ^where to instantiate
                -> (Int -> a -> [i]) -- ^transform before bcast (num procs, input, sync-data out)
                -> (a -> [i] ->b)    -- ^transform after bcast (input, sync-data in, output)
                -> [RD a]            -- ^remote input for each process
                -> [RD b]            -- ^remote output for each process
allToAllRDAt places t1 t2 xs = outRDs where
  np = length xs  -- one process per input element
  (outRDs, sentRDss) =
    np `pseq` unzip $ parMapAt places (uncurry node) procInputs
  -- process j gets, from every process p (in order of p), the j-th handle p sent
  procInputs = zip xs $ lazy $ transpose sentRDss

  node :: RD a -> [RD i] -> (RD b, [RD i])
  node xRD received = (release (t2 x (fetchAll received)), releaseAll (t1 np x))
    where x = fetch xRD

-- NOTE(review): an earlier TODO asked whether this also works for data
-- distributed with splitIntoN / unsplit (concat) instead of round robin.
-- |Parallel transposition of a matrix whose rows are distributed round robin
-- among the processes. The transposed result matrix is distributed the same
-- way. This is 'parTransposeRDAt' with the placement list @[0]@.
parTransposeRD :: Trans b
                  => [RD [[b]]] -- ^input list of remote partial matrices
                  -> [RD [[b]]] -- ^output list of remote partial matrices
parTransposeRD rdss = parTransposeRDAt [0] rdss


-- NOTE(review): an earlier TODO asked whether this also works for data
-- distributed with splitIntoN / unsplit (concat) instead of round robin.
-- |Parallel transposition of a matrix whose rows are distributed round robin
-- among the processes, with explicit placement. The transposed result matrix
-- is distributed the same way. Built on 'allToAllRDAt'; this assumes
-- 'unshuffle' / 'shuffle' are the round-robin split and merge.
parTransposeRDAt :: Trans b
                    => Places
                    -> [RD [[b]]] -- ^input list of remote partial matrices
                    -> [RD [[b]]] -- ^output list of remote partial matrices
parTransposeRDAt places = allToAllRDAt places splitCols mergeCols
  where
    -- deal the local (partial) columns round robin to the np processes
    splitCols np = unshuffle np . transpose
    -- regroup the received partial columns per column, then interleave them
    mergeCols _ = map shuffle . transpose

-- | Performs an all-gather using all-to-all communication. This is
-- 'allGatherRDAt' with the placement list @[0]@.
--
-- The initial transformation is applied in the processes to obtain the values
-- that are gathered. The final combine function creates each process's
-- output from its initial input and the gathered values.
allGatherRD :: forall a b c. (Trans a, Trans b, Trans c)
               => (a -> b)         -- ^initial transform function
               -> (a -> [b] -> c)  -- ^final combine function
               -> [RD a] -> [RD c]
allGatherRD t1 t2 = allGatherRDAt [0] t1 t2

-- | Performs an all-gather using all-to-all communication (based on
-- 'allToAllRDAt'), with explicit placement.
--
-- Each process applies the initial transformation once to its input and
-- sends the result to every process. The final combine function creates each
-- process's output from its initial input and the gathered values.
allGatherRDAt :: forall a b c. (Trans a, Trans b, Trans c)
                      => Places           -- ^where to instantiate
                      -> (a -> b)         -- ^initial transform function
                      -> (a -> [b] -> c)  -- ^final combine function
                      -> [RD a] -> [RD c]
allGatherRDAt places t1 t2 = allToAllRDAt places sendToAll t2 where
  sendToAll :: Int -> a -> [b]
  sendToAll np = replicate np . t1


-- | Performs an all-reduce with the reduce function using a butterfly scheme.
-- This is 'allReduceRDAt' with the placement list @[0]@.
--
-- The initial transformation is applied in the processes to obtain the values
-- that are reduced. The final combine function creates each process's output
-- from its initial input and the reduced value.
--
-- NOTE(review): the butterfly exchange appears to leave some processes with
-- incomplete reductions when the process count is not a power of two —
-- confirm before relying on such process counts.
allReduceRD :: forall a b c. (Trans a, Trans b, Trans c)
               => (a -> b)       -- ^initial transform function
               -> (b -> b -> b)  -- ^reduce function
               -> (a -> b -> c)  -- ^final combine function
               -> [RD a] -> [RD c]
allReduceRD initF redF resF = allReduceRDAt [0] initF redF resF


-- | Performs an all-reduce with the reduce function using a butterfly scheme,
-- with explicit placement.
--
-- One process is created per input. Each process applies the initial
-- transformation to its fetched input to obtain the value to be reduced. It
-- then takes part in @steps@ exchange stages, where @steps@ is the smallest
-- k with @2^k >= length rdAs@. In each stage it combines its current value
-- with its partner's value using the reduce function. The value coming from
-- the left-hand position is always the left argument, so the reduce
-- function need not be commutative. Finally the combine function creates
-- the process output from the initial input and the reduced value.
--
-- NOTE(review): when the process count is not a power of two, a process
-- whose partner position is padding keeps its value for that stage. Tracing
-- 'fillF' / 'shiftFlipF' (assuming 'unshuffle' / 'shuffle' are round-robin
-- split and merge, and 'transposeRt' behaves like 'transpose') suggests that
-- with 3 processes, process 1 ends with only the reduction of processes 0
-- and 1. Confirm before relying on non-power-of-two process counts.
allReduceRDAt :: forall a b c. (Trans a, Trans b, Trans c)
               => Places         -- ^where to instantiate
               -> (a -> b)       -- ^initial transform function
               -> (b -> b -> b)  -- ^reduce function
               -> (a -> b -> c)  -- ^final combine function
               -> [RD a] -> [RD c]
allReduceRDAt places initF redF resF rdAs = rdCs where
  -- Number of butterfly stages: the smallest k with 2^k >= length rdAs.
  -- Computed exactly on integers. The floating-point form
  -- @ceiling (logBase 2 n)@ can round up past an exact power of two and add
  -- a superfluous stage.
  steps = length $ takeWhile (< length rdAs) $ iterate (* 2) 1
  (rdBss,rdCs) = steps `pseq` unzip $ parMapAt places (uncurry p) inp
  -- each process gets its remote input plus, per stage, the optional and
  -- side-tagged handle of its butterfly partner
  inp          = zip rdAs $ lazy $ buflyF $ transposeRt rdBss
  -- per-stage rows: pad to 2^steps positions, exchange partner handles,
  -- then regroup per process
  buflyF       = transposeRt . shiftFlipF steps . fillF steps
  
  p :: RD a -> [Maybe (Both (RD b))] -> ([RD b], RD c)
  p rdA rdBs = (rdBs'', res) where
    -- output: combine the input with the value after the last stage
    res      = release $ resF a $ reduced !! steps
    -- the value released for stage k is the value reached after k stages
    rdBs''   = (releaseAll . take steps . lazy) reduced
    -- reduced !! k is the value after k stages
    reduced  = scanl redF' b toReduce
    toReduce = fetchAll' rdBs'
    -- without a partner, a process fetches back its own released value
    rdBs'    = zipWith (flip maybe Left) (map Right rdBs'') rdBs
    b        = initF a
    a        = fetch rdA
  
  --List encoding:
  -- Right: No Partner present, use value b without reduction
  -- Left: RD value comes from partner, then inner encoding:
  --       Right: Partner is positioned at the right hand side
  --       Left: Partner is positioned at the left hand side
  -- needed such that redF does not need to be commutative
  redF' :: b -> Either (Both b) b -> b
  redF' _ (Right b) = b
  redF' b (Left (Right b')) = redF b b'
  redF' b (Left (Left b'))  = redF b' b

-- | A value tagged with one of two sides, where both sides carry the same type.
type Both x = Either x x

-- Custom fetchAll for remote-data handles nested inside the Either encoding
-- used by allReduceRDAt. All handles are fetched within one runPA, and their
-- Left/Right tags are kept.
fetchAll' :: Trans a => [Either (Both (RD a)) (RD a)] -> [Either (Both a) a]
fetchAll' = runPA . mapM fetchTagged
  where
    fetchTagged (Left (Left rda))  = retag (Left . Left)  rda
    fetchTagged (Left (Right rda)) = retag (Left . Right) rda
    fetchTagged (Right rda)        = retag Right          rda
    -- fetch one handle and wrap the value in the given tag
    retag tag rda = fetchPA rda >>= return . tag

-- Pads every row to length 2^ldn with Nothing, wrapping the existing elements
-- in Just. Rows longer than 2^ldn are truncated.
fillF :: Int -> [[a]] -> [[Maybe a]]
fillF ldn = map padRow
  where
    width      = 2 ^ ldn
    padRow row = take width (map Just row ++ replicate width Nothing)

-- Butterfly exchange pattern used by allReduceRDAt. Only the first ldn rows
-- are used, and row s (1-based) belongs to stage s. In stage s the row is
-- split into 2^s groups with unshuffle. The first and second halves of the
-- groups are swapped, and the row is merged back with shuffle. Entries from
-- the former second half are tagged Right, those from the first half Left.
-- Assuming unshuffle / shuffle are the round-robin split and merge, each
-- position receives the entry 2^(s-1) positions away. The tag tells on which
-- side the partner sits.
shiftFlipF :: Int -> [[Maybe a]] -> [[Maybe (Both a)]]
shiftFlipF ldn rows = zipWith exchange [1 .. ldn] rows
  where
    exchange stage row = shuffle (swapHalves (unshuffle groups row))
      where
        groups = 2 ^ stage
        swapHalves gs =
          let (lower, upper) = splitAt (groups `div` 2) gs
          in  map (map (fmap Right)) upper ++ map (map (fmap Left)) lower


-- | Performs an all-gather using a butterfly scheme. This is
-- 'allGatherBuFlyRDAt' with the placement list @[0]@.
--
-- The initial transformation is applied in the processes to obtain the values
-- that are gathered. The final combine function creates each process's
-- output from its initial input and the gathered values.
allGatherBuFlyRD :: forall a b c. (Trans a, Trans b, Trans c)
                    => (a -> b)         -- ^initial transform function
                    -> (a -> [b] -> c)  -- ^final combine function
                    -> [RD a] -> [RD c]
allGatherBuFlyRD t1 t2 = allGatherBuFlyRDAt [0] t1 t2

-- | Performs an all-gather using a butterfly scheme (based on
-- 'allReduceRDAt'), with explicit placement.
--
-- Each process wraps its transformed input in a singleton list, and the lists
-- are reduced with @(++)@. The final combine function creates each process's
-- output from its initial input and the gathered values.
--
-- NOTE(review): the butterfly exchange appears to leave some processes with
-- incomplete lists when the process count is not a power of two — confirm.
allGatherBuFlyRDAt :: forall a b c. (Trans a, Trans b, Trans c)
                      => Places           -- ^where to instantiate
                      -> (a -> b)         -- ^initial transform function
                      -> (a -> [b] -> c)  -- ^final combine function
                      -> [RD a] -> [RD c]
allGatherBuFlyRDAt places t1 t2 = allReduceRDAt places asList (++) t2 where
  asList :: a -> [b]
  asList = (: []) . t1