module Data.Tensort.Utils.Compose
  ( createInitialTensors,
    createTensor,
  )
where

import Data.Tensort.Utils.Split (splitEvery)
import Data.Tensort.Utils.Types (Byte, Memory (..), Record, SortAlg, Sortable (..), Tensor, TensortProps (..), fromSortRec, Bit)

-- | Convert a list of Bytes to a list of TensorStacks.
--
--   The Bytes are grouped with @splitEvery (bytesize tsProps)@, and each group
--   is turned into one Tensor by 'getTensorFromBytes' using the
--   'subAlgorithm' from the given props. A Tensor on its own is a
--   TensorStack (these are equivalent terms - see the type definitions for
--   more info), so the result is a list of TensorStacks.
--
--   The resulting list is in the reverse order of the groups: the Tensor
--   built from the last group comes first (see the example below).
--
-- ==== __Examples__
-- With @tsProps@ having @bytesize = 2@ and a 'subAlgorithm' that sorts
-- Records by TopBit:
--
-- >>> createInitialTensors [[2,4],[6,8],[1,3],[5,7]] tsProps
-- [([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]
createInitialTensors :: [Byte] -> TensortProps -> [Tensor]
createInitialTensors bytes tsProps =
  -- 'reverse' reproduces the established last-group-first ordering in O(n).
  -- NOTE(review): confirm whether any caller actually relies on this order.
  reverse (map toTensor (splitEvery (bytesize tsProps) bytes))
  where
    -- Build the Tensor for a single group of Bytes.
    toTensor :: [Byte] -> Tensor
    toTensor group = getTensorFromBytes group (subAlgorithm tsProps)

-- | Create a Tensor from a Memory.
--
--   Dispatches on the kind of Memory: a 'ByteMem' is handed to
--   'getTensorFromBytes' and a 'TensorMem' to 'getTensorFromTensors'. In
--   both cases the given 'SortAlg' is passed through unchanged.
createTensor :: Memory -> SortAlg -> Tensor
createTensor memory subAlg =
  case memory of
    ByteMem bytes -> getTensorFromBytes bytes subAlg
    TensorMem tensors -> getTensorFromTensors tensors subAlg

-- | Convert a list of Bytes to a Tensor.
--
--   The Bytes are stored unchanged as the new Tensor's Memory (a 'ByteMem').
--   A Register is built alongside them with one Record per non-empty Byte.
--   Each Record pairs an Address (the index of the referenced Byte in
--   Memory) with a TopBit (the last Bit of that Byte, which is its highest
--   Bit assuming the Byte is already sorted). Empty Bytes get no Record but
--   still occupy an index, so Addresses always line up with Memory.
--
--   The Register is then passed through the given 'SortAlg', which is
--   expected to order the Records by TopBit.
--
-- ==== __Examples__
-- With a @subAlg@ that sorts Records by TopBit:
--
-- >>> getTensorFromBytes [[2,4,6,8],[1,3,5,7]] subAlg
-- ([(1,7),(0,8)],ByteMem [[2,4,6,8],[1,3,5,7]])
getTensorFromBytes :: [Byte] -> SortAlg -> Tensor
getTensorFromBytes bytes subAlg = (sortedRegister, ByteMem bytes)
  where
    -- Enumerate before filtering so that skipped (empty) Bytes still consume
    -- an Address. The non-empty pattern guarantees 'last' is safe, and a
    -- failed pattern match simply drops that Byte from the comprehension.
    register :: [Record]
    register =
      [ (address, last byte)
        | (address, byte@(_ : _)) <- zip [0 ..] bytes
      ]

    sortedRegister :: [Record]
    sortedRegister = fromSortRec (subAlg (SortRec register))

-- | Create a TensorStack from a list of Tensors.
--
--   The new Tensor's Register comes from 'getRegisterFromTensors' (one Record
--   per Tensor with a non-empty Register) and is passed through the given
--   'SortAlg'. Its Memory is the original list of Tensors, wrapped in
--   'TensorMem'.
--
-- ==== __Examples__
-- With a @subAlg@ that sorts Records by TopBit:
--
-- >>> getTensorFromTensors [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(1,14),(0,17)],ByteMem [[16,17],[12,14]])] subAlg
-- ([(1,17),(0,18)],TensorMem [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(1,14),(0,17)],ByteMem [[16,17],[12,14]])])
getTensorFromTensors :: [Tensor] -> SortAlg -> Tensor
getTensorFromTensors tensors subAlg = (sortedRegister, TensorMem tensors)
  where
    sortedRegister :: [Record]
    sortedRegister =
      fromSortRec . subAlg . SortRec $ getRegisterFromTensors tensors

-- | Build an (unsorted) Register for a list of Tensors.
--
--   For each Tensor with a non-empty Register, produce a Record that pairs
--   the Tensor's index in the input list (its Address) with its top Bit, as
--   given by 'getTopBitFromTensorStack'. Tensors with an empty Register get
--   no Record but still occupy an index. As a result, every Address points at
--   the right Tensor in the 'TensorMem' that 'getTensorFromTensors' builds
--   from the same list. This mirrors how 'getTensorFromBytes' handles empty
--   Bytes.
--
--   The output is not sorted; sorting is done in 'getTensorFromTensors'.
--
-- ==== __Examples__
-- >>> getRegisterFromTensors [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]]),([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]
-- [(0,18),(1,17),(2,7),(3,8)]
getRegisterFromTensors :: [Tensor] -> [Record]
getRegisterFromTensors tensors =
  -- Enumerate before filtering so that skipped Tensors do not shift the
  -- Addresses of the Tensors after them. The non-empty Register pattern also
  -- keeps 'getTopBitFromTensorStack' (which calls 'last') safe.
  [ (address, getTopBitFromTensorStack tensor)
    | (address, tensor@(_ : _, _)) <- zip [0 ..] tensors
  ]

-- | Get the top Bit from a TensorStack.
--
--   This is the TopBit of the last Record in the TensorStack's top-level
--   Register. For Tensors built by this module, each Record's TopBit was
--   itself taken from the last Bit of a Byte or the top Bit of a sub-Tensor.
--   So this is the last Bit in the last Byte referenced in the last Record of
--   the Tensor referenced in the last Record of ... and so on, down through
--   the TensorStack.
--
--   Assuming every Register is sorted by TopBit, this is also the highest
--   value in the TensorStack.
--
--   Uses 'last' on the Register, so it fails on a Tensor whose Register is
--   empty. 'getRegisterFromTensors' skips such Tensors before calling this.
--
-- ==== __Examples__
-- >>> getTopBitFromTensorStack (([(0,28),(1,38)],TensorMem [([(0,27),(1,28)],TensorMem [([(0,23),(1,27)],ByteMem [[21,23],[25,27]]),([(0,24),(1,28)],ByteMem [[22,24],[26,28]])]),([(1,37),(0,38)],TensorMem [([(0,33),(1,38)],ByteMem [[31,33],[35,38]]),([(0,34),(1,37)],ByteMem [[32,14],[36,37]])])]))
-- 38
getTopBitFromTensorStack :: Tensor -> Bit
getTopBitFromTensorStack = snd . last . fst