#if __GLASGOW_HASKELL__ >= 707
#define USE_TYPE_LITS 1
#endif
#ifndef MIN_VERSION_reflection
#define MIN_VERSION_reflection(x,y,z) 1
#endif
module Linear.V
( V(V,toVector)
#ifdef MIN_VERSION_template_haskell
, int
#endif
, dim
, Dim(..)
, reifyDim
, reifyVector
#if (MIN_VERSION_reflection(2,0,0)) && __GLASGOW_HASKELL__ >= 708
, reifyDimNat
, reifyVectorNat
#endif
, fromVector
) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Control.DeepSeq (NFData)
import Control.Monad
import Control.Monad.Fix
import Control.Monad.Zip
import Control.Lens as Lens
import Data.Binary as Binary
import Data.Bytes.Serial
import Data.Data
import Data.Distributive
import Data.Foldable as Foldable
import Data.Functor.Bind
import Data.Functor.Classes
import Data.Functor.Rep as Rep
#if __GLASGOW_HASKELL__ < 708
import Data.Proxy
#endif
import Data.Reflection as R
import Data.Serialize as Cereal
#if __GLASGOW_HASKELL__ < 710
import Data.Traversable (sequenceA)
#endif
import Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Generic.Mutable as M
import Foreign.Ptr
import Foreign.Storable
#ifdef USE_TYPE_LITS
import GHC.TypeLits
#endif
#if __GLASGOW_HASKELL__ >= 702
import GHC.Generics (Generic)
#endif
#if __GLASGOW_HASKELL__ >= 707
import GHC.Generics (Generic1)
#endif
#if !(MIN_VERSION_reflection(1,3,0)) && defined(MIN_VERSION_template_haskell)
import Language.Haskell.TH
#endif
import Linear.Epsilon
import Linear.Metric
import Linear.Vector
#ifdef HLINT
#endif
-- | Types that stand for a vector dimension which can be recovered as a
-- run-time @Int@.
class Dim n where
  -- | Produce the dimension from any proxy-like value whose last type
  -- argument is @n@. The proxy value itself is only a type carrier.
  reflectDim :: proxy n -> Int
#if __GLASGOW_HASKELL__ >= 707
-- The dimension tag is nominal so that it cannot be coerced to a different
-- tag; the element type stays representational.
type role V nominal representational
#endif

-- | A boxed vector of elements of type @a@, tagged with a phantom
-- dimension @n@. The wrapped vector is available through @toVector@.
newtype V n a = V { toVector :: V.Vector a }
  deriving
    ( Eq
    , Ord
    , Show
    , Read
    , Typeable
    , NFData
    , Generic
#if __GLASGOW_HASKELL__ >= 707
    , Generic1
#endif
    )
-- | The dimension recorded in the type of a vector. The argument is never
-- inspected; only its type index @n@ is used.
dim :: forall n a. Dim n => V n a -> Int
dim _ = reflectDim witness
  where
    witness :: Proxy n
    witness = Proxy
#ifdef USE_TYPE_LITS
-- | Type-level naturals are dimensions: the value comes from @natVal@,
-- narrowed from @Integer@ to @Int@.
instance KnownNat n => Dim (n :: Nat) where
  reflectDim proxy = fromInteger (natVal proxy)
#endif
-- | Uninhabited tag wrapping a reified name @s@, so that an @Int@ made
-- available through @Reifies s Int@ can serve as a @Dim@.
data ReifiedDim (s :: *)

-- | Adapt a function on a proxy for @s@ into one accepting any proxy for
-- @ReifiedDim s@. The incoming proxy is ignored.
retagDim :: (Proxy s -> a) -> proxy (ReifiedDim s) -> a
retagDim k = const (k Proxy)

-- | The dimension of @ReifiedDim s@ is the @Int@ reflected from @s@.
instance Reifies s Int => Dim (ReifiedDim s) where
  reflectDim tagged = retagDim reflect tagged
#if (MIN_VERSION_reflection(2,0,0)) && __GLASGOW_HASKELL__ >= 708
-- | Lift a run-time @Int@ to a type-level natural and hand a proxy for it
-- to the continuation.
reifyDimNat :: Int -> (forall (n :: Nat). KnownNat n => Proxy n -> r) -> r
reifyDimNat i f = R.reifyNat (toInteger i) f

-- | Wrap a plain vector as a @V@ whose type-level natural dimension is the
-- length of that vector, and pass it to the continuation.
reifyVectorNat :: forall a r. Vector a -> (forall (n :: Nat). KnownNat n => V n a -> r) -> r
reifyVectorNat v f = R.reifyNat (toInteger (V.length v)) withLen
  where
    -- The proxy only fixes the type index; the payload is always @v@.
    withLen :: forall (n :: Nat). KnownNat n => Proxy n -> r
    withLen _ = f (V v :: V n a)
#endif
-- | Lift a run-time @Int@ to a fresh type that is an instance of @Dim@ and
-- hand a proxy for that type to the continuation.
reifyDim :: Int -> (forall (n :: *). Dim n => Proxy n -> r) -> r
reifyDim i f =
  -- @R.reify@ supplies a proxy for a fresh name @s@; the continuation is
  -- run at the wrapped tag @ReifiedDim s@ instead.
  R.reify i (\(_ :: Proxy s) -> f (Proxy :: Proxy (ReifiedDim s)))
-- | Wrap a plain vector as a @V@ whose dimension is the length of that
-- vector, and pass it to the continuation.
reifyVector :: forall a r. Vector a -> (forall (n :: *). Dim n => V n a -> r) -> r
reifyVector v f = reifyDim (V.length v) withDim
  where
    -- The proxy only fixes the type index; the payload is always @v@.
    withDim :: forall (n :: *). Dim n => Proxy n -> r
    withDim _ = f (V v :: V n a)
-- | A vector type can itself be used as a dimension tag; it reports the
-- dimension of its index @n@.
instance Dim n => Dim (V n a) where
  reflectDim = const (reflectDim (Proxy :: Proxy n))
-- | Maps over the components of the wrapped vector.
instance Functor (V n) where
  fmap f = V . fmap f . toVector
-- | Indexed map; the index is the position in the wrapped vector.
instance FunctorWithIndex Int (V n) where
  imap f = V . Lens.imap f . toVector
-- | Every fold is delegated to the corresponding operation on the wrapped
-- vector.
instance Foldable (V n) where
  fold = fold . toVector
  foldMap f = foldMap f . toVector
  foldr f z = V.foldr f z . toVector
  foldl f z = V.foldl f z . toVector
#if __GLASGOW_HASKELL__ >= 706
  foldr' f z = V.foldr' f z . toVector
  foldl' f z = V.foldl' f z . toVector
#endif
  foldr1 f = V.foldr1 f . toVector
  foldl1 f = V.foldl1 f . toVector
#if __GLASGOW_HASKELL__ >= 710
  length = V.length . toVector
  null = V.null . toVector
  toList = V.toList . toVector
  elem a = V.elem a . toVector
  maximum = V.maximum . toVector
  minimum = V.minimum . toVector
  sum = V.sum . toVector
  product = V.product . toVector
#endif
-- | Indexed fold over the wrapped vector.
instance FoldableWithIndex Int (V n) where
  ifoldMap f = ifoldMap f . toVector
-- | Traverses the wrapped vector and re-wraps the result.
instance Traversable (V n) where
  traverse f = fmap V . traverse f . toVector
-- | Indexed traversal over the wrapped vector, re-wrapping the result.
instance TraversableWithIndex Int (V n) where
  itraverse f = fmap V . itraverse f . toVector
-- | Applies functions to arguments position by position.
instance Apply (V n) where
  fs <.> xs = V (V.zipWith ($) (toVector fs) (toVector xs))
-- | Zippy applicative: @pure@ repeats a value @n@ times and @<*>@ applies
-- position by position.
instance Dim n => Applicative (V n) where
  pure a = V (V.replicate size a)
    where
      size = reflectDim (Proxy :: Proxy n)
  fs <*> xs = V (V.zipWith ($) (toVector fs) (toVector xs))
-- | Diagonal bind: position @i@ of the result is position @i@ of the vector
-- produced from the element at position @i@.
instance Bind (V n) where
  V as >>- f = V (V.imap diagonal as)
    where
      -- Unchecked indexing, as in the pointwise formulation: the vector
      -- returned by @f@ is read at the same position the element came from.
      diagonal i a = toVector (f a) `unsafeIndex` i
-- | Diagonal monad. The result of @>>=@ always has @reflectDim@ entries,
-- read with unchecked indexing from the input and from each produced vector.
instance Dim n => Monad (V n) where
  return a = V (V.replicate size a)
    where
      size = reflectDim (Proxy :: Proxy n)
  V as >>= f = V (V.generate size diagonal)
    where
      size = reflectDim (Proxy :: Proxy n)
      diagonal i = toVector (f (unsafeIndex as i)) `unsafeIndex` i
-- | Additive structure: the zero vector repeats 0, and both lifting
-- combinators pair components with @V.zipWith@.
instance Dim n => Additive (V n) where
  zero = pure 0
  liftU2 f u v = V (V.zipWith f (toVector u) (toVector v))
  liftI2 f u v = V (V.zipWith f (toVector u) (toVector v))
-- | Componentwise arithmetic. Binary operators pair components with
-- @V.zipWith@; literals are repeated across every component via @pure@.
instance (Dim n, Num a) => Num (V n a) where
  V as + V bs = V $ V.zipWith (+) as bs
  -- The subtraction operator was missing on both sides of this equation,
  -- which left the method undefined and the line unparseable.
  V as - V bs = V $ V.zipWith (-) as bs
  V as * V bs = V $ V.zipWith (*) as bs
  negate = fmap negate
  abs = fmap abs
  signum = fmap signum
  fromInteger = pure . fromInteger
-- | Componentwise division and reciprocal; rational literals are repeated
-- across every component.
instance (Dim n, Fractional a) => Fractional (V n a) where
  recip v = recip <$> v
  u / v = V (V.zipWith (/) (toVector u) (toVector v))
  fromRational r = pure (fromRational r)
-- | Componentwise transcendental functions.
instance (Dim n, Floating a) => Floating (V n a) where
  -- The constant is repeated across every component.
  pi = pure pi

  -- Binary operations pair components positionally.
  u ** v = V (V.zipWith (**) (toVector u) (toVector v))
  logBase u v = V (V.zipWith logBase (toVector u) (toVector v))

  -- Unary functions act on each component independently.
  exp v = exp <$> v
  sqrt v = sqrt <$> v
  log v = log <$> v
  sin v = sin <$> v
  tan v = tan <$> v
  cos v = cos <$> v
  asin v = asin <$> v
  atan v = atan <$> v
  acos v = acos <$> v
  sinh v = sinh <$> v
  tanh v = tanh <$> v
  cosh v = cosh <$> v
  asinh v = asinh <$> v
  atanh v = atanh <$> v
  acosh v = acosh <$> v
-- | Position @i@ of the result collects component @i@ of every vector in
-- the outer functor.
instance Dim n => Distributive (V n) where
  distribute f = V (V.generate size column)
    where
      size = reflectDim (Proxy :: Proxy n)
      -- Unchecked read of component @i@ from each inner vector.
      column i = fmap (\(V v) -> unsafeIndex v i) f
-- | Stores the components contiguously as @reflectDim@ consecutive elements
-- of type @a@.
instance (Dim n, Storable a) => Storable (V n a) where
  -- Total size is the element size multiplied by the dimension.
  sizeOf _ = reflectDim (Proxy :: Proxy n) * sizeOf (undefined :: a)
  alignment _ = alignment (undefined :: a)
  -- Writes components 0 .. dimension - 1. The "- 1" on the upper bound was
  -- missing, which made the range expression ill-formed.
  -- NOTE(review): components are read with unsafeIndex, so the wrapped
  -- vector is assumed to hold at least reflectDim elements.
  poke ptr (V xs) = Foldable.forM_ [0 .. reflectDim (Proxy :: Proxy n) - 1] $ \i ->
      pokeElemOff ptr' i (unsafeIndex xs i)
    where
      ptr' = castPtr ptr
  -- Reads reflectDim consecutive elements starting at the pointer.
  peek ptr = V <$> generateM (reflectDim (Proxy :: Proxy n)) (peekElemOff ptr')
    where
      ptr' = castPtr ptr
-- | A vector is near zero when its @quadrance@ is near zero.
instance (Dim n, Epsilon a) => Epsilon (V n a) where
  nearZero v = nearZero (quadrance v)
-- | The dot product sums the pairwise products of the components.
instance Dim n => Metric (V n) where
  dot u v = V.sum (V.zipWith (*) (toVector u) (toVector v))
-- | Wrap a plain vector as a @V n@, succeeding only when its length equals
-- the dimension @n@; otherwise the result is @Nothing@.
fromVector :: forall n a. Dim n => Vector a -> Maybe (V n a)
fromVector v =
  if V.length v == reflectDim (Proxy :: Proxy n)
    then Just (V v)
    else Nothing
#if !(MIN_VERSION_reflection(1,3,0)) && defined(MIN_VERSION_template_haskell)
-- | Type-level zero.
data Z

-- | Type-level doubling of @n@.
data D (n :: *)

-- | Type-level doubling of @n@, plus one.
data SD (n :: *)

-- | Type-level tag paired with a remainder of -1 by @int@; presumably the
-- doubling of @n@, minus one.
data PD (n :: *)

-- | @Z@ reflects to 0.
instance Reifies Z Int where
  reflect = const 0
-- | Run a function on a proxy for @n@ when given any proxy for @D n@.
retagD :: (Proxy n -> a) -> proxy (D n) -> a
retagD k = const (k Proxy)

-- | Run a function on a proxy for @n@ when given any proxy for @SD n@.
retagSD :: (Proxy n -> a) -> proxy (SD n) -> a
retagSD k = const (k Proxy)

-- | Run a function on a proxy for @n@ when given any proxy for @PD n@.
retagPD :: (Proxy n -> a) -> proxy (PD n) -> a
retagPD k = const (k Proxy)
-- | @D n@ reflects to twice the value of @n@.
instance Reifies n Int => Reifies (D n) Int where
  reflect = (\n -> n + n) <$> retagD reflect

-- | @SD n@ reflects to twice the value of @n@, plus one.
instance Reifies n Int => Reifies (SD n) Int where
  reflect = (\n -> n + n + 1) <$> retagSD reflect

-- | @PD n@ reflects to twice the value of @n@, minus one. The subtraction
-- operator was missing here, leaving a reference to an unbound name.
instance Reifies n Int => Reifies (PD n) Int where
  reflect = (\n -> n + n - 1) <$> retagPD reflect
-- | Build the type that encodes the given @Int@ out of @Z@, @D@, @SD@ and
-- @PD@, one binary digit at a time.
--
-- @quotRem@ truncates toward zero, so the remainder is -1 for negative odd
-- inputs. That case must select @PD@; it previously matched a remainder of
-- 1, which shadowed the @SD@ case and sent negative odd inputs to @error@.
int :: Int -> TypeQ
int n = case quotRem n 2 of
  (0, 0) -> conT ''Z
  (q, -1) -> conT ''PD `appT` int q
  (q, 0) -> conT ''D `appT` int q
  (q, 1) -> conT ''SD `appT` int q
  -- Unreachable: a remainder modulo 2 is always -1, 0 or 1.
  _ -> error "ghc is bad at math"
#endif
-- | Vectors are represented by functions from an @Int@ position.
instance Dim n => Representable (V n) where
  type Rep (V n) = Int
  -- Builds reflectDim entries by calling the function at 0, 1, ...
  tabulate f = V (V.generate (reflectDim (Proxy :: Proxy n)) f)
  -- Bounds-checked lookup in the wrapped vector.
  index v i = toVector v V.! i
type instance Index (V n a) = Int

type instance IxValue (V n a) = a

-- | Traversal of the element at a position. A position outside the wrapped
-- vector yields the vector unchanged without calling the function.
instance Ixed (V n a) where
  ix i f v@(V as)
    | 0 <= i && i < V.length as = (\a -> V (as V.// [(i, a)])) <$> f (as V.! i)
    | otherwise = pure v
-- | Zipping pairs components position by position.
instance Dim n => MonadZip (V n) where
  mzip u v = V (V.zip (toVector u) (toVector v))
  mzipWith f u v = V (V.zipWith f (toVector u) (toVector v))
-- | Each position ties its own knot: the value at position @r@ is
-- position @r@ of the vector produced from that same value.
instance Dim n => MonadFix (V n) where
  mfix f = tabulate knot
    where
      knot r = a
        where
          a = Rep.index (f a) r
-- | Visits every component, using the @Traversable@ instance.
instance Each (V n a) (V n b) a b where
  each f = traverse f
-- | The bounds repeat the element bounds across every component.
instance (Bounded a, Dim n) => Bounded (V n a) where
  maxBound = pure maxBound
  minBound = pure minBound
-- | The single constructor representation used by the @Data@ instance.
vConstr :: Constr
vConstr =
  mkConstr
    vDataType
    "variadic"
    [] -- no field names
    Prefix

-- | Data type representation listing @vConstr@ as its only constructor.
vDataType :: DataType
vDataType = mkDataType "Linear.V.V" [vConstr]
#if __GLASGOW_HASKELL__ >= 708
#define Typeable1 Typeable
#endif
-- | Generic programming support that presents a vector as a constructor
-- applied to the list of its components.
instance (Typeable1 (V n), Typeable (V n a), Dim n, Data a) => Data (V n a) where
  gfoldl f z v = z (V . fromList) `f` V.toList (toVector v)
  toConstr = const vConstr
  -- Only constructor index 1 exists; any other index is an error.
  gunfold k z c
    | constrIndex c == 1 = k (z (V . fromList))
    | otherwise = error "gunfold"
  dataTypeOf = const vDataType
  dataCast1 f = gcast1 f
-- | Writes the components in order; reads by running the element reader
-- once per position of a @pure@ vector.
instance Dim n => Serial1 (V n) where
  serializeWith putElem v = traverse_ putElem v
  deserializeWith getElem = sequenceA (pure getElem)
-- | Writes the components in order; reads by running the element reader
-- once per position of a @pure@ vector.
instance (Dim n, Serial a) => Serial (V n a) where
  serialize v = traverse_ serialize v
  deserialize = sequenceA (pure deserialize)
-- | @Binary@ encoding built from the @Serial1@ instance and the element
-- encoder and decoder.
instance (Dim n, Binary a) => Binary (V n a) where
  put v = serializeWith Binary.put v
  get = deserializeWith Binary.get
-- | @Serialize@ encoding built from the @Serial1@ instance and the element
-- encoder and decoder.
instance (Dim n, Serialize a) => Serialize (V n a) where
  put v = serializeWith Cereal.put v
  get = deserializeWith Cereal.get
-- | Lifted equality, delegating to the derived @Eq@ instance.
instance Dim n => Eq1 (V n) where
  eq1 u v = u == v

-- | Lifted ordering, delegating to the derived @Ord@ instance.
instance Dim n => Ord1 (V n) where
  compare1 u v = compare u v

-- | Lifted rendering, delegating to the derived @Show@ instance.
instance Dim n => Show1 (V n) where
  showsPrec1 d v = showsPrec d v

-- | Lifted parsing, delegating to the derived @Read@ instance.
instance Dim n => Read1 (V n) where
  readsPrec1 d = readsPrec d
-- | Immutable unboxed storage: the number of stored vectors together with
-- one flat unboxed vector holding all of their components.
data instance U.Vector (V n a)
  = V_VN !Int !(U.Vector a)

-- | Mutable unboxed storage with the same layout.
data instance U.MVector s (V n a)
  = MV_VN !Int !(U.MVector s a)

instance (Dim n, U.Unbox a) => U.Unbox (V n a)
-- | Mutable unboxed vectors of @V n a@. Element @i@ occupies the flat
-- positions @d*i .. d*i + d - 1@, where @d@ is @reflectDim@ for @n@.
instance (Dim n, U.Unbox a) => M.MVector U.MVector (V n a) where
  basicLength (MV_VN n _) = n
  -- Offsets and lengths are scaled by the dimension for the flat buffer.
  basicUnsafeSlice m n (MV_VN _ v) = MV_VN n (M.basicUnsafeSlice (d*m) (d*n) v)
    where d = reflectDim (Proxy :: Proxy n)
  basicOverlaps (MV_VN _ v) (MV_VN _ u) = M.basicOverlaps v u
  basicUnsafeNew n = liftM (MV_VN n) (M.basicUnsafeNew (d*n))
    where d = reflectDim (Proxy :: Proxy n)
  basicUnsafeRead (MV_VN _ v) i =
    liftM V $ V.generateM d (\j -> M.basicUnsafeRead v (d*i+j))
    where d = reflectDim (Proxy :: Proxy n)
  -- Copies every component of the written value into the flat buffer.
  -- The component counter must start at 0 and count up to the length of
  -- the value; it previously started at the element index and was never
  -- advanced correctly, so components were skipped or the loop ran away.
  basicUnsafeWrite (MV_VN _ v0) i0 (V vn0) = let d0 = V.length vn0 in go v0 vn0 d0 (d0*i0) 0
    where
      -- o is the flat write offset, i the component being copied.
      go v vn d o i
        | i >= d = return ()
        | otherwise = do
          a <- G.basicUnsafeIndexM vn i
          M.basicUnsafeWrite v o a
          go v vn d (o+1) (i+1)
#if MIN_VERSION_vector(0,11,0)
  basicInitialize (MV_VN _ v) = M.basicInitialize v
#endif
-- | Immutable unboxed vectors of @V n a@, laid out like the mutable ones:
-- element @i@ occupies @stride@ consecutive flat positions from @stride*i@.
instance (Dim n, U.Unbox a) => G.Vector U.Vector (V n a) where
  basicUnsafeFreeze (MV_VN count flat) = V_VN count `liftM` G.basicUnsafeFreeze flat
  basicUnsafeThaw (V_VN count flat) = MV_VN count `liftM` G.basicUnsafeThaw flat
  basicLength (V_VN count _) = count
  basicUnsafeSlice start count (V_VN _ flat) =
      V_VN count (G.basicUnsafeSlice (stride * start) (stride * count) flat)
    where
      stride = reflectDim (Proxy :: Proxy n)
  basicUnsafeIndexM (V_VN _ flat) i =
      liftM V (V.generateM stride component)
    where
      stride = reflectDim (Proxy :: Proxy n)
      -- Component j of element i in the flat buffer.
      component j = G.basicUnsafeIndexM flat (stride * i + j)