module Data.Tensort.Utils.Reduce (reduceTensorStacks) where

import Data.Tensort.Utils.Compose (createTensor)
import Data.Tensort.Utils.Split (splitEvery)
import Data.Tensort.Utils.Types (Memory (..), TensorStack, TensortProps (..))

-- | Take a list of TensorStacks and repeatedly group them together into new
--   TensorStacks, each containing bytesize number of Tensors (former
--   TensorStacks), until the number of TensorStacks is at most the bytesize;
--   the remaining TensorStacks are then wrapped in a single TensorStack

-- | The Registers of the new TensorStacks are sorted with the SortAlg given
--   by the subAlgorithm field of the TensortProps

-- | ==== __Examples__
-- >>> reduceTensorStacks [([(0, 33), (1, 38)], ByteMem [[31, 33], [35, 38]]), ([(0, 34), (1, 37)], ByteMem [[32, 14], [36, 37]]), ([(0, 23), (1, 27)], ByteMem [[21, 23], [25, 27]]), ([(0, 24), (1, 28)], ByteMem [[22, 24], [26, 28]]),([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]]),([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])] 2
-- ([(1,18),(0,38)],TensorMem [([(0,28),(1,38)],TensorMem [([(0,27),(1,28)],TensorMem [([(0,23),(1,27)],ByteMem [[21,23],[25,27]]),([(0,24),(1,28)],ByteMem [[22,24],[26,28]])]),([(1,37),(0,38)],TensorMem [([(0,33),(1,38)],ByteMem [[31,33],[35,38]]),([(0,34),(1,37)],ByteMem [[32,14],[36,37]])])]),([(0,8),(1,18)],TensorMem [([(0,7),(1,8)],TensorMem [([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]),([(1,17),(0,18)],TensorMem [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]])])])])
reduceTensorStacks :: [TensorStack] -> TensortProps -> TensorStack
reduceTensorStacks tensorStacks tsProps
  -- Few enough TensorStacks remain: wrap them all in one top-level Tensor.
  | newCount <= bytesize tsProps =
      createTensor (TensorMem newTensorStacks) (subAlgorithm tsProps)
  -- A pass that does not shrink the list (e.g. bytesize 1) means every later
  -- pass yields the same count, so recursing would never terminate. Fail
  -- loudly instead of hanging.
  | newCount >= length tensorStacks =
      error $
        "reduceTensorStacks: a reduction pass did not shrink the list of "
          ++ "TensorStacks (is bytesize less than 2?)"
  -- Otherwise keep grouping the (now shorter) list.
  | otherwise = reduceTensorStacks newTensorStacks tsProps
  where
    -- One grouping pass over the current TensorStacks.
    newTensorStacks :: [TensorStack]
    newTensorStacks = reduceTensorStacksSinglePass tensorStacks tsProps

    -- Number of TensorStacks produced by that pass, computed once.
    newCount :: Int
    newCount = length newTensorStacks

-- | Take a list of TensorStacks and, in a single pass, group them together
--   into new TensorStacks, each containing bytesize number of Tensors (former
--   TensorStacks)

-- | The Registers of the new TensorStacks are sorted with the SortAlg given
--   by the subAlgorithm field of the TensortProps

-- | ==== __Examples__
-- >>> reduceTensorStacksSinglePass [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]]),([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])] 2
-- [([(0,7),(1,8)],TensorMem [([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]),([(1,17),(0,18)],TensorMem [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]])])]
reduceTensorStacksSinglePass :: [TensorStack] -> TensortProps -> [TensorStack]
reduceTensorStacksSinglePass tensorStacks tsProps =
  -- The new TensorStacks come out in reverse order of their groups (as the
  -- example shows). Building them left-to-right and reversing once keeps
  -- that order in O(n), instead of the quadratic cost of repeatedly
  -- appending each new Tensor to the end of an accumulator.
  reverse (map chunkToTensor (splitEvery (bytesize tsProps) tensorStacks))
  where
    -- Wrap one group of TensorStacks as the memory of a new TensorStack,
    -- using the configured sub-algorithm for its Register.
    chunkToTensor :: [TensorStack] -> TensorStack
    chunkToTensor chunk =
      createTensor (TensorMem chunk) (subAlgorithm tsProps)