{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables#-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}
module Numerical.Array.Layout.Builder where
import Control.Monad.Primitive ( PrimMonad, PrimState )
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import Numerical.Array.Layout.Base
import Numerical.Array.Layout.Dense as Dense
import Control.Monad.ST (runST)
import Data.Typeable
import qualified Data.Foldable as F
import Data.Traversable as T
import Control.Applicative as A
import Numerical.Data.Vector.Pair
import Numerical.Array.Layout.Sparse
import Data.Vector.Algorithms.Intro as IntroSort
import Data.List (group)
import Numerical.InternalUtils
import Prelude hiding (error)
-- | Initial contents for a builder: a declared element count plus the source
-- of the elements, either an explicit list ('Left') or an index-driven
-- generator ('Right').
data BatchInit v = BatchInit
  { batchInitSize :: !Int
    -- ^ Declared number of elements in the batch.
  , batchInitKV   :: !(Either [v] (IntFun v))
    -- ^ Element source: a list, or a generator called with an index.
  }
  deriving (Typeable)
-- | Allocate a fresh mutable vector of length 'batchInitSize' and fill it
-- from the batch.
--
-- * 'Left' list: slot @i@ receives the @i@-th list element. At most @size@
--   elements are consumed. If the list is shorter than @size@, the trailing
--   slots keep whatever 'VGM.new' put there (NOTE(review): for boxed vectors
--   that is an uninitialised-element placeholder; confirm no caller relies
--   on it).
-- * 'Right' generator: for @i@ in @[0 .. size - 1]@, in ascending order, the
--   generator is run at @i@ and its result is written to slot @i@.
materializeBatchMV
  :: (PrimMonad m, VGM.MVector mv a)
  => BatchInit a -> m (mv (PrimState m) a)
materializeBatchMV (BatchInit size (Left ls)) = do
    newMV <- VGM.new size
    -- mapM_ rather than mapM: every write returns (), so there is no point
    -- building and then discarding a list of those results.
    Prelude.mapM_ (\(ix, val) -> VGM.unsafeWrite newMV ix val)
                  (zip [0 ..] (take size ls))
    return newMV
materializeBatchMV (BatchInit size (Right (IntFun f))) = do
    newMV <- VGM.new size
    Prelude.mapM_ (\ix -> f ix >>= VGM.unsafeWrite newMV ix) [0 .. size - 1]
    return newMV
-- | A mutable structure that is polymorphic in its state token, so it can be
-- read at whatever state type the consuming monad requires.
newtype AnyMV mv e =
  AMV (forall s . mv s e)
-- | Debug rendering of a batch.
--
-- Shows the declared size and at most the first 100 elements. A list-backed
-- batch is tagged @Left@; a generator-backed batch is tagged @Right@, and
-- its preview is obtained by running the generator in 'runST' on the
-- leading indices. When the size exceeds the limit, a note on its own line
-- says that only a prefix is shown.
instance (Show a) => Show (BatchInit a) where
  show (BatchInit size kv) =
      "(BatchInit " ++ show size ++ truncationNote
        ++ " (" ++ tag ++ " " ++ show (preview kv) ++ "))\n"
    where
      previewLimit = 100 :: Int
      truncated = size > previewLimit
      -- Number of elements actually rendered: exactly previewLimit when
      -- truncating (the old code evaluated indices 0..100, i.e. 101).
      shown = if truncated then previewLimit else size
      truncationNote
        | truncated = " -- only showing the first "
                        ++ show previewLimit ++ " elements\n"
        | otherwise = ""
      -- The old code mislabelled generator-backed batches as "Left".
      tag = either (const "Left") (const "Right") kv
      preview (Left ls) = take shown ls
      preview (Right (IntFun f)) = runST (Prelude.mapM f [0 .. shown - 1])
-- | An index-driven element generator. The rank-2 type means one generator
-- can be run in any 'PrimMonad', including 'ST'.
newtype IntFun a =
  IntFun (forall m . (PrimMonad m) => Int -> m a)
  deriving (Typeable)
-- | Post-compose a pure function onto every value the generator produces.
instance Functor IntFun where
  {-# INLINE fmap #-}
  fmap f (IntFun gen) = IntFun (\ix -> f <$> gen ix)
-- | Map over a batch's elements, keeping its declared size and its
-- representation (list stays list, generator stays generator).
instance Functor BatchInit where
  {-# INLINE fmap #-}
  fmap f (BatchInit n src) =
    BatchInit n (either (Left . Prelude.map f) (Right . fmap f) src)
-- | A list-backed batch whose declared size is the length of the list.
fromListBI :: [a] -> BatchInit a
fromListBI xs =
  BatchInit { batchInitSize = length xs, batchInitKV = Left xs }
-- | Wrap an immutable vector as a generator-backed 'BatchInit' of the same
-- length. Element @i@ is read from the vector when the generator is run at
-- @i@. Indices outside @[0, size)@ raise an error that reports both the
-- offending index and the vector length.
fromVectorBI :: VG.Vector v e => v e -> BatchInit e
fromVectorBI v =
    BatchInit size (Right (IntFun (\i ->
      if i < 0 || i >= size
        then error $ "fromVectorBI: out of bounds index " ++ show i
                       ++ " on IntFun of size: " ++ show size
        else return (v VG.! i))))
  where
    size = VG.length v
-- | Wrap a state-polymorphic mutable vector as a generator-backed
-- 'BatchInit' of the same length. Element @i@ is read with 'VGM.read' when
-- the generator is run at @i@. Indices outside @[0, size)@ raise an error
-- that reports both the offending index and the vector length.
fromMVectorBI :: (VGM.MVector mv e) => AnyMV mv e -> BatchInit e
fromMVectorBI (AMV v) =
    BatchInit size (Right (IntFun (\i ->
      if i < 0 || i >= size
        then error $ "fromMVectorBI: out of bounds index " ++ show i
                       ++ " on IntFun of size: " ++ show size
        else v `VGM.read` i)))
  where
    size = VGM.length v
-- | Layout formats that can be constructed from a logical shape, a default
-- value and an optional batch of @(index, value)@ entries.
class Layout form (rank::Nat) => LayoutBuilder form (rank::Nat) | form -> rank where
  -- | Build the format together with its mutable value buffer.
  buildFormatM
    :: ( storage ~ FormatStorageRep form
       , Buffer storage Int
       , Buffer storage elt
       , PrimMonad mon )
    => Index rank                             -- ^ logical shape
    -> p form                                 -- ^ proxy selecting the format
    -> elt                                    -- ^ default / fill value
    -> Maybe (BatchInit (Index rank, elt))    -- ^ optional initial entries
    -> mon (form, BufferMut storage (PrimState mon) elt)
-- | Pure counterpart of 'buildFormatM'.
--
-- Runs the builder inside 'runST', freezes the resulting mutable buffer with
-- 'VG.unsafeFreeze', and returns the format and the frozen buffer in any
-- 'Monad'. Both components of the pair are forced to WHNF before the pair
-- is returned.
buildFormatPure
  :: forall store form rank proxy m a.
     ( LayoutBuilder form (rank::Nat)
     , store ~ FormatStorageRep form
     , Buffer store Int
     , Buffer store a
     , Monad m )
  => Index rank
  -> proxy form
  -> a
  -> Maybe (BatchInit (Index rank, a))
  -> m (form, BufferPure store a)
buildFormatPure shape prox defaultValue builder =
    do res@(!_, !_) <- return $! theComputation
       return res
  where
    theComputation :: (form, BufferPure store a)
    -- Banged where-binding: the ST computation is forced as soon as the
    -- body of buildFormatPure is evaluated.
    !theComputation = runST $
      do (fmt, buf) <- buildFormatM shape prox defaultValue builder
         -- 'buf' is not used after this point, so the unsafe freeze is sound.
         pureBuf <- VG.unsafeFreeze buf
         return (fmt, pureBuf)
-- | Dense rank-1 format: the buffer holds @size@ copies of the default value.
-- NOTE(review): the batch argument is ignored here, so any supplied entries
-- are dropped; confirm this is intended.
instance LayoutBuilder (Format Direct 'Contiguous ('S 'Z) rep) ('S 'Z) where
  buildFormatM (size :* _) _ defaultValue _ = do
      buf <- VGM.replicate size defaultValue
      return (FormatDirectContiguous size, buf)
-- | Dense row-major format of any rank. The buffer holds one copy of the
-- default value per element, i.e. the product of all extents in the shape.
-- NOTE(review): the batch argument is ignored here, so any supplied entries
-- are dropped; confirm this is intended.
instance (F.Foldable (Shape r), T.Traversable (Shape r), A.Applicative (Shape r))
    => LayoutBuilder (Format Row 'Contiguous r rep) r where
  buildFormatM ix _ defaultValue _ = do
      -- Element count = product of extents, so the fold must start from the
      -- multiplicative identity 1; starting from 0 always produced an empty
      -- buffer.
      buf <- VGM.replicate (F.foldl' (*) 1 ix) defaultValue
      return (FormatRowContiguous ix, buf)
-- | Dense column-major format of any rank. The buffer holds one copy of the
-- default value per element, i.e. the product of all extents in the shape.
-- NOTE(review): the batch argument is ignored here, so any supplied entries
-- are dropped; confirm this is intended.
instance (F.Foldable (Shape r), T.Traversable (Shape r), A.Applicative (Shape r))
    => LayoutBuilder (Format Column 'Contiguous r rep) r where
  buildFormatM ix _ defaultValue _ = do
      -- Element count = product of extents, so the fold must start from the
      -- multiplicative identity 1; starting from 0 always produced an empty
      -- buffer.
      buf <- VGM.replicate (F.foldl' (*) 1 ix) defaultValue
      return (FormatColumnContiguous ix, buf)
-- | Check that a vector is strictly increasing under @cmp@.
--
-- Returns 'Nothing' when every adjacent pair @(a, b)@ has @cmp a b == LT@.
-- This is vacuously true for vectors of length 0 or 1. Otherwise returns
-- @Just i@ for the first position @i@ whose element is not strictly less
-- than the element at @i + 1@.
isStrictlyMonotonicV :: (VG.Vector v e) => (e -> e -> Ordering) -> v e -> Maybe Int
isStrictlyMonotonicV cmp v = go 0
  where
    len = VG.length v
    go !i
      | i + 1 >= len = Nothing
      | cmp (v VG.! i) (v VG.! (i + 1)) == LT = go (i + 1)
      | otherwise = Just i
-- | Sparse rank-1 format. Without a batch, both the index buffer and the
-- value buffer are empty. With a batch, entries are materialised, sorted by
-- index, and the index vector is checked to be strictly increasing; a
-- repeated index raises an error naming that index. The default value is
-- not used by this instance.
instance (Buffer rep Int) => LayoutBuilder (Format DirectSparse 'Contiguous ('S 'Z) rep) ('S 'Z) where
  buildFormatM (size :* _) _ _ maybeBatch =
    case maybeBatch of
      Nothing -> do
        emptyIdxM <- VGM.new 0
        emptyIdx <- VG.unsafeFreeze emptyIdxM
        emptyVals <- VGM.new 0
        return $! (FormatDirectSparseContiguous size 0 emptyIdx, emptyVals)
      Just batch -> do
        pairs@(MVPair (MVLeaf idxBuf) (MVLeaf valBuf)) <-
          materializeBatchMV (fmap (\(i :* _, x) -> (i, x)) batch)
        IntroSort.sortBy (\p q -> compare (fst p) (fst q)) pairs
        sortedIdx <- VG.unsafeFreeze idxBuf
        case isStrictlyMonotonicV compare sortedIdx of
          Nothing ->
            return (FormatDirectSparseContiguous size 0 sortedIdx, valBuf)
          Just bad ->
            error $ "DirectSparse Index duplication at index " ++ show (sortedIdx VG.! bad)
-- | Contiguous compressed-sparse-row format for rank 2, with shape
-- @x :* y :* _@.
--
-- With a batch, entries keyed by @(x, y)@ are materialised, sorted with
-- 'basicCompareIndex', and checked for duplicate coordinates (a duplicate
-- raises an error). A start-offset table of length @y + 1@ is then built from
-- the run lengths of the sorted y-coordinates. Without a batch, the index and
-- value buffers are empty, but the start table still has @y + 1@ zero
-- entries so it has the same shape in both cases. The default value is not
-- used.
instance (Buffer rep Int)
    => LayoutBuilder (Format CompressedSparseRow 'Contiguous ('S ('S 'Z)) rep) ('S ('S 'Z)) where
  buildFormatM (x :* y :* _) _ _ Nothing = do
      mvi <- VGM.new 0
      vi <- VG.unsafeFreeze mvi
      mvval <- VGM.new 0
      -- Previously an empty vector was used here, while the Just branch
      -- builds y + 1 start offsets; use the same shape (all zeros).
      let rowStarts = VG.replicate (y + 1) (0 :: Int)
      return
        ( FormatContiguousCompressedSparseRow
            (FormatContiguousCompressedSparseInternal y x vi rowStarts)
        , mvval )
  buildFormatM (x :* y :* _) proxyFormat _ (Just builder) = do
      mvtup@(MVPair (MVPair (MVLeaf mvectYs) (MVLeaf mvectXs)) (MVLeaf mvectVals)) <-
        materializeBatchMV $ fmap (\(xix :* yix :* _, val) -> ((yix, xix), val)) builder
      -- Sort the (coordinates, value) triples into the format's index order.
      IntroSort.sortBy
        (\((y1, x1), _) ((y2, x2), _) ->
            basicCompareIndex proxyFormat (x1 :* y1 :* Nil) (x2 :* y2 :* Nil))
        mvtup
      vectXs <- unsafeBufferFreeze mvectXs
      vectYs <- unsafeBufferFreeze mvectYs
      let -- Start offset of every y in [0 .. y], from run lengths of sorted ys.
          yRunsVect =
            VG.replicate (y + 1) (0 :: Int)
              VG.// computeStarts (computeRunLengths vectYs) 0 y
          xyVect = VPair (VLeaf vectXs) (VLeaf vectYs)
          cmpXY (x1, y1) (x2, y2) =
            basicCompareIndex proxyFormat (x1 :* y1 :* Nil) (x2 :* y2 :* Nil)
      case isStrictlyMonotonicV cmpXY xyVect of
        Nothing -> return
          ( FormatContiguousCompressedSparseRow
              (FormatContiguousCompressedSparseInternal y x vectXs yRunsVect)
          , mvectVals )
        Just i ->
          error $ "illegal duplication in CSR builder (x,y) coordinates "
            ++ show (xyVect VG.! i) ++ " and " ++ show (xyVect VG.! (i + 1))
            ++ " starting at position " ++ show i
-- | Run-length encode a vector. Each maximal run of equal adjacent elements
-- becomes @(element, runLength)@, in order of appearance.
computeRunLengths :: (VG.Vector v e, Eq e) => v e -> [(e, Int)]
computeRunLengths vec =
  [ (firstElem, 1 + length others) | (firstElem : others) <- group (VG.toList vec) ]
-- | Turn sorted run lengths into a table of exclusive prefix sums.
--
-- The runs @[(pos, count)]@ must have strictly increasing positions in
-- @[start, end)@. The result has one entry @(p, s)@ for every @p@ in
-- @[start .. end]@, where @s@ is the sum of the counts of all runs whose
-- position is strictly less than @p@.
--
-- >>> computeStarts [(0,2),(2,1)] 0 3
-- [(0,0),(1,2),(2,2),(3,3)]
--
-- Calls 'error' when @start > end@, when a run position is @>= end@, or
-- when positions are unsorted, repeated, or below @start@.
computeStarts :: (Enum a, Ord a, Num b) => [(a, b)] -> a -> a -> [(a, b)]
{-# SPECIALIZE INLINE computeStarts :: [(Int,Int)]->Int->Int ->[(Int,Int)] #-}
computeStarts runs start end
  | start <= end = go start 0 runs
  | otherwise = error "bad start end arguments to computeStarts"
  where
    -- No runs left: every remaining position shares the final prefix sum.
    go !pos !acc [] = [ (p, acc) | p <- [pos .. end] ]
    go !pos !acc pending@((runPos, runLen) : rest)
      -- A run at or past 'end' used to make the output overshoot 'end'
      -- before failing with a misleading "impossible" error.
      | runPos >= end =
          error "computeStarts: run position must be strictly below the end argument"
      | pos < runPos = (pos, acc) : go (succ pos) acc pending
      | pos == runPos = (pos, acc) : go (succ pos) (acc + runLen) rest
      | otherwise =
          error "bad position in prefix stream for computeStarts go, literally unpossible "