{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleContexts, FlexibleInstances, UndecidableInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE DeriveGeneric #-}

#ifndef MIN_VERSION_hashable
#define MIN_VERSION_hashable(x,y,z) 1
#endif

#ifndef MIN_VERSION_reflection
#define MIN_VERSION_reflection(x,y,z) 1
#endif

#ifndef MIN_VERSION_transformers
#define MIN_VERSION_transformers(x,y,z) 1
#endif

#ifndef MIN_VERSION_base
#define MIN_VERSION_base(x,y,z) 1
#endif

-- The lens version macro is tested further down, so it needs the same
-- fallback as the other packages when Cabal does not provide it.
#ifndef MIN_VERSION_lens
#define MIN_VERSION_lens(x,y,z) 1
#endif

-----------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2012-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett
-- Stability   :  experimental
-- Portability :  non-portable
--
-- n-D Vectors
----------------------------------------------------------------------------
module Linear.V
  ( V(V,toVector)
#ifdef MIN_VERSION_template_haskell
  , int
#endif
  , dim
  , Dim(..)
  , reifyDim
  , reifyVector
  , reifyDimNat
  , reifyVectorNat
  , fromVector
  , Finite(..)
  , _V, _V'
  ) where

import Control.Applicative
import Control.DeepSeq (NFData)
import Control.Monad
import Control.Monad.Fix
import Control.Monad.Trans.State
import Control.Monad.Zip
import Control.Lens as Lens
import Data.Binary as Binary
import Data.Bytes.Serial
import Data.Complex
import Data.Data
import Data.Distributive
import Data.Foldable as Foldable
import qualified Data.Foldable.WithIndex as WithIndex
import Data.Functor.Bind
import Data.Functor.Classes
import Data.Functor.Rep as Rep
import qualified Data.Functor.WithIndex as WithIndex
import Data.Hashable
import Data.Hashable.Lifted
import Data.Kind
import Data.Reflection as R
import Data.Serialize as Cereal
import qualified Data.Traversable.WithIndex as WithIndex
import qualified Data.Vector as V
import Data.Vector (Vector)
import Data.Vector.Fusion.Util (Box(..))
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Generic.Mutable as M
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
import GHC.Generics (Generic, Generic1)
#if !(MIN_VERSION_reflection(1,3,0)) && defined(MIN_VERSION_template_haskell)
import Language.Haskell.TH
#endif
import Linear.Epsilon
import Linear.Metric
import Linear.Vector
import Prelude as P
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup
#endif
import System.Random (Random(..))

-- | Types that determine a dimension, recoverable as an 'Int' from the type
-- alone.
class Dim n where
  -- | The dimension associated with @n@. The instances in this module only
  -- look at the type of the proxy, never at its value.
  reflectDim :: p n -> Int

type role V nominal representational

-- | Containers whose element count is fixed by their type ('Size') and which
-- can be converted to and from a 'V' of that size.
class Finite v where
  type Size (v :: Type -> Type) :: Nat -- this should allow kind k, for Reifies k Int
  -- | Convert to a 'V'. The default lists the elements in 'Foldable' order.
  toV :: v a -> V (Size v) a
  default toV :: Foldable v => v a -> V (Size v) a
  toV = V . V.fromList . Foldable.toList
  -- | Convert back from a 'V' of the matching size.
  fromV :: V (Size v) a -> v a

-- | A complex number is a 2-vector: real part first, imaginary part second.
instance Finite Complex where
  type Size Complex = 2
  toV (a :+ b) = V (V.fromListN 2 [a, b])
  fromV (V v) = (v V.! 0) :+ (v V.! 1)

-- | Isomorphism between a 'V' and a 'Finite' container. Both the container
-- type and the element type may change.
_V :: (Finite u, Finite v) => Iso (V (Size u) a) (V (Size v) b) (u a) (v b)
_V = iso fromV toV

-- | Like '_V', but the container type stays the same on both sides.
_V' :: Finite v => Iso (V (Size v) a) (V (Size v) b) (v a) (v b)
_V' = iso fromV toV

instance Finite (V (n :: Nat)) where
  type Size (V n) = n
  toV = id
  fromV = id

-- | A vector of @a@ tagged with its dimension @n@, stored as a boxed
-- 'Vector'.
--
-- The 'V' constructor is exported and does not check the length against
-- @n@; 'fromVector' is the checked way to build one.
newtype V n a = V { toVector :: V.Vector a }
  deriving (Eq,Ord,Show,Read,NFData
           ,Generic,Generic1
           )

-- | The dimension of a vector, read from its type. The argument is ignored.
dim :: forall n a. Dim n => V n a -> Int
dim _ = reflectDim (Proxy :: Proxy n)
{-# INLINE dim #-}

instance KnownNat n => Dim (n :: Nat) where
  reflectDim = fromInteger . natVal
  {-# INLINE reflectDim #-}

instance (Dim n, Random a) => Random (V n a) where
  random = runState (V <$> V.replicateM (reflectDim (Proxy :: Proxy n)) (state random))
  -- Each component is drawn from the corresponding pair of bounds.
  randomR (V ls,V hs) = runState (V <$> V.zipWithM (\l h -> state $ randomR (l,h)) ls hs)

-- | Phantom wrapper used by 'reifyDim': @ReifiedDim s@ has a 'Dim' instance
-- whenever @s@ reifies an 'Int'.
data ReifiedDim (s :: Type)

retagDim :: (Proxy s -> a) -> proxy (ReifiedDim s) -> a
retagDim f _ = f Proxy
{-# INLINE retagDim #-}

instance Reifies s Int => Dim (ReifiedDim s) where
  reflectDim = retagDim reflect
  {-# INLINE reflectDim #-}

-- | Reify a runtime 'Int' as a type-level 'Nat' and pass a proxy for it to
-- the continuation.
reifyDimNat :: Int -> (forall (n :: Nat). KnownNat n => Proxy n -> r) -> r
reifyDimNat i f = R.reifyNat (fromIntegral i) f
{-# INLINE reifyDimNat #-}

-- | Wrap a 'Vector' as a 'V' whose type-level 'Nat' dimension is the
-- vector's length.
reifyVectorNat :: forall a r. Vector a -> (forall (n :: Nat). KnownNat n => V n a -> r) -> r
reifyVectorNat v f = reifyNat (fromIntegral $ V.length v) $ \(Proxy :: Proxy n) -> f (V v :: V n a)
{-# INLINE reifyVectorNat #-}

-- | Reify a runtime 'Int' as a type @n@ with a 'Dim' instance, using
-- 'ReifiedDim' over a type produced by 'R.reify'.
reifyDim :: Int -> (forall (n :: Type). Dim n => Proxy n -> r) -> r
reifyDim i f = R.reify i (go f) where
  go :: (Proxy (ReifiedDim n) -> a) -> proxy n -> a
  go g _ = g Proxy
{-# INLINE reifyDim #-}

-- | Wrap a 'Vector' as a 'V' whose 'Dim' dimension is the vector's length.
reifyVector :: forall a r. Vector a -> (forall (n :: Type). Dim n => V n a -> r) -> r
reifyVector v f = reifyDim (V.length v) $ \(Proxy :: Proxy n) -> f (V v :: V n a)
{-# INLINE reifyVector #-}

instance Dim n => Dim (V n a) where
  reflectDim _ = reflectDim (Proxy :: Proxy n)
  {-# INLINE reflectDim #-}

instance (Dim n, Semigroup a) => Semigroup (V n a) where
 (<>) = liftA2 (<>)

instance (Dim n, Monoid a) => Monoid (V n a) where
  mempty = pure mempty
#if !(MIN_VERSION_base(4,11,0))
  mappend = liftA2 mappend
#endif

instance Functor (V n) where
  fmap f (V as) = V (fmap f as)
  {-# INLINE fmap #-}

instance WithIndex.FunctorWithIndex Int (V n) where
  imap f (V as) = V (Lens.imap f as)
  {-# INLINE imap #-}

-- Every method delegates to the matching operation on the underlying Vector.
instance Foldable (V n) where
  fold (V as) = fold as
  {-# INLINE fold #-}
  foldMap f (V as) = Foldable.foldMap f as
  {-# INLINE foldMap #-}
  foldr f z (V as) = V.foldr f z as
  {-# INLINE foldr #-}
  foldl f z (V as) = V.foldl f z as
  {-# INLINE foldl #-}
  foldr' f z (V as) = V.foldr' f z as
  {-# INLINE foldr' #-}
  foldl' f z (V as) = V.foldl' f z as
  {-# INLINE foldl' #-}
  foldr1 f (V as) = V.foldr1 f as
  {-# INLINE foldr1 #-}
  foldl1 f (V as) = V.foldl1 f as
  {-# INLINE foldl1 #-}
  length (V as) = V.length as
  {-# INLINE length #-}
  null (V as) = V.null as
  {-# INLINE null #-}
  toList (V as) = V.toList as
  {-# INLINE toList #-}
  elem a (V as) = V.elem a as
  {-# INLINE elem #-}
  maximum (V as) = V.maximum as
  {-# INLINE maximum #-}
  minimum (V as) = V.minimum as
  {-# INLINE minimum #-}
  sum (V as) = V.sum as
  {-# INLINE sum #-}
  product (V as) = V.product as
  {-# INLINE product #-}

instance WithIndex.FoldableWithIndex Int (V n) where
  ifoldMap f (V as) = ifoldMap f as
  {-# INLINE ifoldMap #-}

instance Traversable (V n) where
  traverse f (V as) = V <$> traverse f as
  {-# INLINE traverse #-}

instance WithIndex.TraversableWithIndex Int (V n) where
  itraverse f (V as) = V <$> itraverse f as
  {-# INLINE itraverse #-}

-- Older lens versions define their own indexed classes; forward them to the
-- indexed-traversable instances above.
#if !MIN_VERSION_lens(5,0,0)
instance Lens.FunctorWithIndex Int (V n) where imap = WithIndex.imap
instance Lens.FoldableWithIndex Int (V n) where ifoldMap = WithIndex.ifoldMap
instance Lens.TraversableWithIndex Int (V n) where itraverse = WithIndex.itraverse
#endif

instance Apply (V n) where
  V as <.> V bs = V (V.zipWith id as bs)
  {-# INLINE (<.>) #-}

instance Dim n => Applicative (V n) where
  pure = V . V.replicate (reflectDim (Proxy :: Proxy n))
  {-# INLINE pure #-}

  V as <*> V bs = V (V.zipWith id as bs)
  {-# INLINE (<*>) #-}

-- | Diagonal bind: element @i@ of the result is element @i@ of @f (as ! i)@.
-- The result length is the length of the input vector.
instance Bind (V n) where
  V as >>- f = V $ V.generate (V.length as) $ \i ->
    toVector (f (as `V.unsafeIndex` i)) `V.unsafeIndex` i
  {-# INLINE (>>-) #-}

-- | Diagonal bind, as for 'Bind', with the result length taken from @n@.
instance Dim n => Monad (V n) where
#if !(MIN_VERSION_base(4,11,0))
  return = V . V.replicate (reflectDim (Proxy :: Proxy n))
  {-# INLINE return #-}
#endif
  V as >>= f = V $ V.generate (reflectDim (Proxy :: Proxy n)) $ \i ->
    toVector (f (as `V.unsafeIndex` i)) `V.unsafeIndex` i
  {-# INLINE (>>=) #-}

instance Dim n => Additive (V n) where
  zero = pure 0
  {-# INLINE zero #-}

  liftU2 f (V as) (V bs) = V (V.zipWith f as bs)
  {-# INLINE liftU2 #-}

  liftI2 f (V as) (V bs) = V (V.zipWith f as bs)
  {-# INLINE liftI2 #-}

-- Arithmetic is component-wise; literals are replicated to every component.
instance (Dim n, Num a) => Num (V n a) where
  V as + V bs = V $ V.zipWith (+) as bs
  {-# INLINE (+) #-}
  V as - V bs = V $ V.zipWith (-) as bs
  {-# INLINE (-) #-}
  V as * V bs = V $ V.zipWith (*) as bs
  {-# INLINE (*) #-}
  negate = fmap negate
  {-# INLINE negate #-}
  abs = fmap abs
  {-# INLINE abs #-}
  signum = fmap signum
  {-# INLINE signum #-}
  fromInteger = pure . fromInteger
  {-# INLINE fromInteger #-}

instance (Dim n, Fractional a) => Fractional (V n a) where
  recip = fmap recip
  {-# INLINE recip #-}
  V as / V bs = V $ V.zipWith (/) as bs
  {-# INLINE (/) #-}
  fromRational = pure . fromRational
  {-# INLINE fromRational #-}

instance (Dim n, Floating a) => Floating (V n a) where
    pi = pure pi
    {-# INLINE pi #-}
    exp = fmap exp
    {-# INLINE exp #-}
    sqrt = fmap sqrt
    {-# INLINE sqrt #-}
    log = fmap log
    {-# INLINE log #-}
    V as ** V bs = V $ V.zipWith (**) as bs
    {-# INLINE (**) #-}
    logBase (V as) (V bs) = V $ V.zipWith logBase as bs
    {-# INLINE logBase #-}
    sin = fmap sin
    {-# INLINE sin #-}
    tan = fmap tan
    {-# INLINE tan #-}
    cos = fmap cos
    {-# INLINE cos #-}
    asin = fmap asin
    {-# INLINE asin #-}
    atan = fmap atan
    {-# INLINE atan #-}
    acos = fmap acos
    {-# INLINE acos #-}
    sinh = fmap sinh
    {-# INLINE sinh #-}
    tanh = fmap tanh
    {-# INLINE tanh #-}
    cosh = fmap cosh
    {-# INLINE cosh #-}
    asinh = fmap asinh
    {-# INLINE asinh #-}
    atanh = fmap atanh
    {-# INLINE atanh #-}
    acosh = fmap acosh
    {-# INLINE acosh #-}

instance Dim n => Distributive (V n) where
  distribute f = V $ V.generate (reflectDim (Proxy :: Proxy n)) $ \i -> fmap (\(V v) -> V.unsafeIndex v i) f
  {-# INLINE distribute #-}

-- Hashes every element left to right, then mixes in the length.
instance Hashable a => Hashable (V n a) where
  hashWithSalt s0 (V v) =
    V.foldl' hashWithSalt s0 v `hashWithSalt` V.length v

instance Dim n => Hashable1 (V n) where
  liftHashWithSalt h s0 (V v) =
    V.foldl' h s0 v `hashWithSalt` V.length v
  {-# INLINE liftHashWithSalt #-}

-- | Stored as @n@ consecutive elements of @a@, with the alignment of @a@.
instance (Dim n, Storable a) => Storable (V n a) where
  sizeOf _ = reflectDim (Proxy :: Proxy n) * sizeOf (undefined:: a)
  {-# INLINE sizeOf #-}
  alignment _ = alignment (undefined :: a)
  {-# INLINE alignment #-}
  -- Writes exactly n elements; assumes the vector holds at least n of them.
  poke ptr (V xs) = Foldable.forM_ [0..reflectDim (Proxy :: Proxy n)-1] $ \i ->
    pokeElemOff ptr' i (V.unsafeIndex xs i)
    where ptr' = castPtr ptr
  {-# INLINE poke #-}
  peek ptr = V <$> V.generateM (reflectDim (Proxy :: Proxy n)) (peekElemOff ptr')
    where ptr' = castPtr ptr
  {-# INLINE peek #-}

instance (Dim n, Epsilon a) => Epsilon (V n a) where
  nearZero = nearZero . quadrance
  {-# INLINE nearZero #-}

instance Dim n => Metric (V n) where
  dot (V a) (V b) = V.sum $ V.zipWith (*) a b
  {-# INLINE dot #-}

-- TODO: instance (Dim n, Ix a) => Ix (V n a)

-- | Checked construction: @Just@ the wrapped vector when its length equals
-- the dimension @n@, otherwise @Nothing@.
fromVector :: forall n a. Dim n => Vector a -> Maybe (V n a)
fromVector v
  | V.length v == reflectDim (Proxy :: Proxy n) = Just (V v)
  | otherwise                                   = Nothing

-- With reflection >= 1.3 the exported 'int' comes from "Data.Reflection";
-- the local binary encoding below is only needed for older versions.
#if !(MIN_VERSION_reflection(1,3,0)) && defined(MIN_VERSION_template_haskell)
data Z               -- 0
data D  (n :: Type)  -- 2n
data SD (n :: Type)  -- 2n+1
data PD (n :: Type)  -- 2n-1

instance Reifies Z Int where
  reflect _ = 0
  {-# INLINE reflect #-}

retagD :: (Proxy n -> a) -> proxy (D n) -> a
retagD f _ = f Proxy
{-# INLINE retagD #-}

retagSD :: (Proxy n -> a) -> proxy (SD n) -> a
retagSD f _ = f Proxy
{-# INLINE retagSD #-}

retagPD :: (Proxy n -> a) -> proxy (PD n) -> a
retagPD f _ = f Proxy
{-# INLINE retagPD #-}

instance Reifies n Int => Reifies (D n) Int where
  reflect = (\n -> n+n) <$> retagD reflect
  {-# INLINE reflect #-}

instance Reifies n Int => Reifies (SD n) Int where
  reflect = (\n -> n+n+1) <$> retagSD reflect
  {-# INLINE reflect #-}

instance Reifies n Int => Reifies (PD n) Int where
  reflect = (\n -> n+n-1) <$> retagPD reflect
  {-# INLINE reflect #-}

-- | This can be used to generate a template haskell splice for a type level
-- version of a given 'Int'.
--
-- This does not use GHC TypeLits. Instead it builds a numeric type by hand
-- (binary digits via 'D', 'SD' and 'PD' over 'Z'), similar to the ones used
-- in the \"Functional Pearl: Implicit Configurations\" paper by Oleg Kiselyov
-- and Chung-chieh Shan.
int :: Int -> TypeQ
int n = case quotRem n 2 of
  (0, 0) -> conT ''Z
  (q,-1) -> conT ''PD `appT` int q
  (q, 0) -> conT ''D  `appT` int q
  (q, 1) -> conT ''SD `appT` int q
  _     -> error "ghc is bad at math"
#endif

instance Dim n => Representable (V n) where
  type Rep (V n) = Int
  tabulate = V . V.generate (reflectDim (Proxy :: Proxy n))
  {-# INLINE tabulate #-}
  index (V xs) i = xs V.! i
  {-# INLINE index #-}

type instance Index (V n a) = Int
type instance IxValue (V n a) = a

-- Out-of-range indices leave the vector unchanged instead of failing.
instance Ixed (V n a) where
  ix i f v@(V as)
     | i < 0 || i >= V.length as = pure v
     | otherwise = vLens i f v
  {-# INLINE ix #-}

instance Dim n => MonadZip (V n) where
  mzip (V as) (V bs) = V $ V.zip as bs
  mzipWith f (V as) (V bs) = V $ V.zipWith f as bs

instance Dim n => MonadFix (V n) where
  mfix f = tabulate $ \r -> let a = Rep.index (f a) r in a

instance Each (V n a) (V n b) a b where
  each = traverse
  {-# INLINE each #-}

instance (Bounded a, Dim n) => Bounded (V n a) where
  minBound = pure minBound
  {-# INLINE minBound #-}
  maxBound = pure maxBound
  {-# INLINE maxBound #-}

vConstr :: Constr
vConstr = mkConstr vDataType "variadic" [] Prefix
{-# NOINLINE vConstr #-}

vDataType :: DataType
vDataType = mkDataType "Linear.V.V" [vConstr]
{-# NOINLINE vDataType #-}

-- 'Data' views a V as a single constructor applied to the list of its
-- elements.
instance (Typeable (V n), Typeable (V n a), Dim n, Data a) => Data (V n a) where
  gfoldl f z (V as) = z (V . V.fromList) `f` V.toList as
  toConstr _ = vConstr
  gunfold k z c = case constrIndex c of
    1 -> k (z (V . V.fromList))
    _ -> error "gunfold"
  dataTypeOf _ = vDataType
  dataCast1 f = gcast1 f

-- Serialisation writes the n elements in order, with no length prefix;
-- decoding reads exactly n elements.
instance Dim n => Serial1 (V n) where
  serializeWith = traverse_
  deserializeWith f = sequenceA $ pure f

instance (Dim n, Serial a) => Serial (V n a) where
  serialize = traverse_ serialize
  deserialize = sequenceA $ pure deserialize

instance (Dim n, Binary a) => Binary (V n a) where
  put = serializeWith Binary.put
  get = deserializeWith Binary.get

instance (Dim n, Serialize a) => Serialize (V n a) where
  put = serializeWith Cereal.put
  get = deserializeWith Cereal.get

instance Eq1 (V n) where
  liftEq f0 (V as0) (V bs0) = go f0 (V.toList as0) (V.toList bs0) where
    go _ [] [] = True
    go f (a:as) (b:bs) = f a b && go f as bs
    go _ _ _ = False

-- Lexicographic; on a common prefix the longer vector compares greater.
instance Ord1 (V n) where
  liftCompare f0 (V as0) (V bs0) = go f0 (V.toList as0) (V.toList bs0) where
    go f (a:as) (b:bs) = f a b `mappend` go f as bs
    go _ [] [] = EQ
    go _ _  [] = GT
    go _ [] _  = LT

instance Show1 (V n) where
  liftShowsPrec _ g d (V as) = showParen (d > 10) $ showString "V " . g (V.toList as)

-- Only parses lists whose length matches the dimension @n@.
instance Dim n => Read1 (V n) where
  liftReadsPrec _ g d = readParen (d > 10) $ \r ->
    [ (V (V.fromList as), r2)
    | ("V",r1) <- lex r
    , (as, r2) <- g r1
    , P.length as == reflectDim (Proxy :: Proxy n)
    ]

-- Unboxed vectors of V n a hold a count of vectors plus one flat unboxed
-- vector of components: vector i occupies flat slots d*i .. d*i+d-1.
data instance U.Vector    (V n a) =  V_VN {-# UNPACK #-} !Int !(U.Vector    a)
data instance U.MVector s (V n a) = MV_VN {-# UNPACK #-} !Int !(U.MVector s a)
instance (Dim n, U.Unbox a) => U.Unbox (V n a)

instance (Dim n, U.Unbox a) => M.MVector U.MVector (V n a) where
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicOverlaps #-}
  {-# INLINE basicUnsafeNew #-}
  {-# INLINE basicUnsafeRead #-}
  {-# INLINE basicUnsafeWrite #-}
  basicLength (MV_VN n _) = n
  basicUnsafeSlice m n (MV_VN _ v) = MV_VN n (M.basicUnsafeSlice (d*m) (d*n) v)
    where d = reflectDim (Proxy :: Proxy n)
  basicOverlaps (MV_VN _ v) (MV_VN _ u) = M.basicOverlaps v u
  basicUnsafeNew n = liftM (MV_VN n) (M.basicUnsafeNew (d*n))
    where d = reflectDim (Proxy :: Proxy n)
  basicUnsafeRead (MV_VN _ v) i =
    liftM V $ V.generateM d (\j -> M.basicUnsafeRead v (d*i+j))
    where d = reflectDim (Proxy :: Proxy n)
  -- The stride is taken from the written vector's own length, which matches
  -- the layout above only when that length equals the dimension n.
  basicUnsafeWrite (MV_VN _ v0) i (V vn0) = let d0 = V.length vn0 in go v0 vn0 d0 (d0*i) 0
   where
    go v vn d o j
      | j >= d = return ()
      | otherwise = do
        a <- liftBox $ G.basicUnsafeIndexM vn j
        M.basicUnsafeWrite v o a
        go v vn d (o+1) (j+1)
  basicInitialize (MV_VN _ v) = M.basicInitialize v
  {-# INLINE basicInitialize #-}

-- | Unwrap a 'Box' into any monad.
liftBox :: Monad m => Box a -> m a
liftBox (Box a) = return a
{-# INLINE liftBox #-}

instance (Dim n, U.Unbox a) => G.Vector U.Vector (V n a) where
  {-# INLINE basicUnsafeFreeze #-}
  {-# INLINE basicUnsafeThaw   #-}
  {-# INLINE basicLength       #-}
  {-# INLINE basicUnsafeSlice  #-}
  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeFreeze (MV_VN n v) = liftM ( V_VN n) (G.basicUnsafeFreeze v)
  basicUnsafeThaw   ( V_VN n v) = liftM (MV_VN n) (G.basicUnsafeThaw   v)
  basicLength       ( V_VN n _) = n
  basicUnsafeSlice m n (V_VN _ v) = V_VN n (G.basicUnsafeSlice (d*m) (d*n) v)
    where d = reflectDim (Proxy :: Proxy n)
  basicUnsafeIndexM (V_VN _ v) i =
    liftM V $ V.generateM d (\j -> G.basicUnsafeIndexM v (d*i+j))
    where d = reflectDim (Proxy :: Proxy n)

-- | Lens onto element @i@, updating it by copying the vector. Uses 'V.!', so
-- an out-of-range @i@ raises an error; callers ('ix', the Field instances)
-- are responsible for the bounds.
vLens :: Int -> Lens' (V n a) a
vLens i = \f (V v) -> f (v V.! i) <&> \a -> V (v V.// [(i, a)])
{-# INLINE vLens #-}

instance ( 1 <= n) => Field1  (V n a) (V n a) a a where _1  = vLens  0
instance ( 2 <= n) => Field2  (V n a) (V n a) a a where _2  = vLens  1
instance ( 3 <= n) => Field3  (V n a) (V n a) a a where _3  = vLens  2
instance ( 4 <= n) => Field4  (V n a) (V n a) a a where _4  = vLens  3
instance ( 5 <= n) => Field5  (V n a) (V n a) a a where _5  = vLens  4
instance ( 6 <= n) => Field6  (V n a) (V n a) a a where _6  = vLens  5
instance ( 7 <= n) => Field7  (V n a) (V n a) a a where _7  = vLens  6
instance ( 8 <= n) => Field8  (V n a) (V n a) a a where _8  = vLens  7
instance ( 9 <= n) => Field9  (V n a) (V n a) a a where _9  = vLens  8
instance (10 <= n) => Field10 (V n a) (V n a) a a where _10 = vLens  9
instance (11 <= n) => Field11 (V n a) (V n a) a a where _11 = vLens 10
instance (12 <= n) => Field12 (V n a) (V n a) a a where _12 = vLens 11
instance (13 <= n) => Field13 (V n a) (V n a) a a where _13 = vLens 12
instance (14 <= n) => Field14 (V n a) (V n a) a a where _14 = vLens 13
instance (15 <= n) => Field15 (V n a) (V n a) a a where _15 = vLens 14
instance (16 <= n) => Field16 (V n a) (V n a) a a where _16 = vLens 15
instance (17 <= n) => Field17 (V n a) (V n a) a a where _17 = vLens 16
instance (18 <= n) => Field18 (V n a) (V n a) a a where _18 = vLens 17
instance (19 <= n) => Field19 (V n a) (V n a) a a where _19 = vLens 18