module Data.Tensort.Utils.Reduce (reduceTensorStacks) where
import Data.Tensort.Utils.Compose (createTensor)
import Data.Tensort.Utils.Split (splitEvery)
import Data.Tensort.Utils.Types (Memory (..), TensorStack, TensortProps (..))
-- | Collapse a list of 'TensorStack's into a single top-level 'TensorStack'.
--
-- Each pass ('reduceTensorStacksSinglePass') splits the stacks into chunks of
-- @bytesize tsProps@ and builds one new tensor per chunk. Passes repeat until
-- at most @bytesize tsProps@ stacks remain. Those stacks are then wrapped into
-- the final tensor with 'createTensor' (as 'TensorMem'), using the configured
-- 'subAlgorithm'.
--
-- NOTE(review): this assumes 'splitEvery' returns chunks of the requested
-- size, with the last chunk possibly shorter. Under that assumption a pass
-- only shrinks the list when @bytesize tsProps >= 2@. If a pass leaves at
-- least as many stacks as it received, and still too many, later passes
-- cannot make progress either. In that case this function raises an error
-- rather than recursing forever.
reduceTensorStacks :: [TensorStack] -> TensortProps -> TensorStack
reduceTensorStacks tensorStacks tsProps
  | newCount <= bytesize tsProps =
      createTensor (TensorMem newTensorStacks) (subAlgorithm tsProps)
  -- Without this guard a bytesize of 1 (or less) would loop forever.
  | newCount >= length tensorStacks =
      error "reduceTensorStacks: a reduction pass made no progress; bytesize must be >= 2"
  | otherwise = reduceTensorStacks newTensorStacks tsProps
  where
    newTensorStacks :: [TensorStack]
    newTensorStacks = reduceTensorStacksSinglePass tensorStacks tsProps

    newCount :: Int
    newCount = length newTensorStacks
-- | Run one reduction pass over a list of 'TensorStack's.
--
-- The input is split with 'splitEvery' into chunks of @bytesize tsProps@
-- stacks. Each chunk becomes one new 'TensorStack' via 'createTensor', wrapped
-- in 'TensorMem' and using the configured 'subAlgorithm'.
--
-- The new stacks are returned in the /reverse/ order of their chunks: for
-- chunks @[c1, c2, c3]@ the result is @[t3, t2, t1]@. This ordering is kept
-- deliberately for behavioural compatibility. NOTE(review): whether any
-- caller actually depends on it is unverified.
reduceTensorStacksSinglePass :: [TensorStack] -> TensortProps -> [TensorStack]
reduceTensorStacksSinglePass tensorStacks tsProps =
  -- Mapping and then reversing once is O(n). Appending each new tensor with
  -- @++ [x]@ inside a right fold would copy the accumulator every time (O(n^2)).
  reverse (map toTensor (splitEvery (bytesize tsProps) tensorStacks))
  where
    -- Turn one chunk of stacks into a single tensor.
    toTensor :: [TensorStack] -> TensorStack
    toTensor chunk = createTensor (TensorMem chunk) (subAlgorithm tsProps)