{-# LANGUAGE ConstraintKinds, DataKinds, DeriveDataTypeable, GADTs,
             FlexibleContexts, FlexibleInstances, TypeOperators, PolyKinds,
             GeneralizedNewtypeDeriving, InstanceSigs, RoleAnnotations,
             MultiParamTypeClasses, NoImplicitPrelude, StandaloneDeriving,
             ScopedTypeVariables, TupleSections, TypeFamilies, RankNTypes,
             TypeSynonymInstances, UndecidableInstances,
             RebindableSyntax #-}

-- | The wrapper for a C implementation of the Tensor interface.

module Crypto.Lol.Cyclotomic.Tensor.CTensor
( CT
-- Exports below here are due solely to ticket #10338. See CycTests for more details
, CRNS
, Dispatch
) where

import Algebra.Additive as Additive (C)
import Algebra.Ring     as Ring (C)

import Control.Applicative
import Control.DeepSeq
import Control.Monad
import Control.Monad.Identity
import Control.Monad.Random
import Control.Monad.Trans (lift)

import Data.Coerce
import Data.Constraint
import Data.Foldable as F
import Data.Int
import Data.Maybe
import Data.Traversable as T
import Data.Typeable
import Data.Vector.Generic           as V (zip, unzip)
import Data.Vector.Storable          as SV (Vector, replicate, replicateM, thaw, convert, foldl',
                                            unsafeToForeignPtr0, unsafeSlice, mapM, fromList,
                                            generate, foldl1',
                                            unsafeWith, zipWith, map, length, unsafeFreeze, thaw)
import Data.Vector.Storable.Internal (getPtr)
import Data.Vector.Storable.Mutable  as SM hiding (replicate)

import           Foreign.ForeignPtr
import           Foreign.Marshal.Array
import           Foreign.Ptr
import           Foreign.Storable        (Storable (..))
import qualified Foreign.Storable.Record as Store
import           Foreign.Storable.Tuple  ()
import           System.IO.Unsafe
import           Test.QuickCheck         hiding (generate)
import           Unsafe.Coerce

import Crypto.Lol.CRTrans
import Crypto.Lol.LatticePrelude as LP hiding (replicate, unzip, zip, lift)
import Crypto.Lol.Reflects
import Crypto.Lol.Cyclotomic.Tensor

import Crypto.Lol.Types.IZipVector
import Crypto.Lol.Types.ZqBasic
import Crypto.Lol.GaussRandom

import Crypto.Lol.Cyclotomic.Tensor.CTensor.Extension

import Algebra.ZeroTestable   as ZeroTestable (C)


-- | An implementation of 'Tensor' backed by C code.
-- This is a newtype over a storable 'Vector'; the index @m@ occurs only
-- at the type level.
newtype CT' (m :: Factored) r = CT' { unCT :: Vector r } deriving (Show, Eq, NFData, Typeable)

-- the first argument, though phantom, affects representation
type role CT' representational nominal

-- GADT wrapper that distinguishes between 'Storable' and unrestricted
-- element types.

-- | A wrapper type to seamlessly convert between internal representations.
-- The 'CT' constructor carries a 'CT'' and a @Storable r@ dictionary; the
-- 'ZV' constructor carries an 'IZipVector' and places no constraint on @r@.
-- This type is an instance of 'Tensor'.
data CT (m :: Factored) r where 
  CT :: Storable r => CT' m r -> CT m r
  ZV :: IZipVector m r -> CT m r
  deriving (Typeable)

-- | Equality across representations: when the two operands use different
-- constructors, the 'ZV' operand is converted with 'toCT' (the @Storable@
-- dictionary comes from the matched 'CT' constructor) before comparing.
instance Eq r => Eq (CT m r) where
  ZV zl == ZV zr = zl == zr
  CT cl == CT cr = cl == cr
  lhs@(CT _) == rhs@(ZV _) = lhs == toCT rhs
  lhs@(ZV _) == rhs@(CT _) = rhs == toCT lhs

-- Standalone deriving is used because 'CT' is declared in GADT syntax
-- with a constrained constructor.
deriving instance Show r => Show (CT m r)

-- | Normalise to the 'CT' (storable vector) representation, converting a
-- 'ZV' payload with 'zvToCT''. A value that is already 'CT' is returned as is.
toCT :: (Storable r) => CT m r -> CT m r
toCT t = case t of
  CT _  -> t
  ZV zv -> CT (zvToCT' zv)

-- | Normalise to the 'ZV' ('IZipVector') representation. A value that is
-- already 'ZV' is returned as is. Calls 'error' if 'iZipVector' rejects the
-- converted vector.
toZV :: (Fact m) => CT m r -> CT m r
toZV t = case t of
  ZV _ -> t
  CT (CT' vec) ->
    let failure = error "toZV: internal error"
    in ZV (fromMaybe failure (iZipVector (convert vec)))

-- | Convert the vector inside an 'IZipVector' into a storable vector and
-- wrap it as a 'CT''.
zvToCT' :: forall m r . (Storable r) => IZipVector m r -> CT' m r
zvToCT' zv = CT' storable
  where storable = convert (unIZipVector zv) :: Vector r

-- | Lift a function on 'CT'' to one on 'CT'. A 'ZV' argument is converted
-- to storable form first; the result is always built with the 'CT'
-- constructor.
wrap :: (Storable r) => (CT' l r -> CT' m r) -> (CT l r -> CT m r)
wrap f t = case t of
  ZV zv -> CT . f $ zvToCT' zv
  CT v  -> CT . f $ v

-- | Monadic counterpart of 'wrap': lift a monadic function on 'CT'' to one
-- on 'CT', converting a 'ZV' argument to storable form first.
wrapM :: (Storable r, Monad mon) => (CT' l r -> mon (CT' m r))
         -> (CT l r -> mon (CT m r))
wrapM f t = case t of
  ZV zv -> CT `liftM` f (zvToCT' zv)
  CT v  -> CT `liftM` f v

-- Convert a CT' /twace/ signature (from index m' down to m) to a tagged
-- signature on the underlying vectors, tagged by the pair '(m,m').
type family Tw (r :: *) :: * where
  Tw (CT' m' r -> CT' m r) = Tagged '(m,m') (Vector r -> Vector r)
  Tw (Maybe (CT' m' r -> CT' m r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)

-- Convert a CT' /embed/ signature (from index m up to m') to a tagged
-- signature on the underlying vectors, tagged by the pair '(m,m').
type family Em r where
  Em (CT' m r -> CT' m' r) = Tagged '(m,m') (Vector r -> Vector r)
  Em (Maybe (CT' m r -> CT' m' r)) = TaggedT '(m,m') Maybe (Vector r -> Vector r)


---------- NUMERIC PRELUDE INSTANCES ----------
instance (Additive r, Storable r, CRNS r, Fact m)
  => Additive.C (CT m r) where
  -- the additive identity: a vector filled with the element-level zero
  zero = CT (repl zero)

  -- Both operands already storable: dispatch to the C routine 'dadd'.
  -- (The pure-vector equivalent would be SV.zipWith (+).)
  CT x + CT y = CT $ (zipWrapper $ untag $ cZipDispatch dadd) x y
  -- Otherwise normalise both operands and retry.
  x + y = toCT x + toCT y

  -- EAC: negation is done in Haskell; it probably should be converted to C code
  negate t = case t of
    CT (CT' v) -> CT . CT' $ SV.map negate v
    ZV _       -> negate (toCT t)

instance (Fact m, Ring r, Storable r, CRNS r)
  => Ring.C (CT m r) where
  -- an integer constant is the vector filled with that constant
  fromInteger n = CT (repl (fromInteger n))

  -- Both operands already storable: dispatch to the C routine 'dmul'.
  -- (The pure-vector equivalent would be SV.zipWith (*).)
  CT x * CT y = CT $ (zipWrapper $ untag $ cZipDispatch dmul) x y
  -- Otherwise normalise both operands and retry.
  x * y = toCT x * toCT y

instance (ZeroTestable r, Storable r, Fact m)
         => ZeroTestable.C (CT m r) where
  --{-# INLINABLE isZero #-}
  -- A tensor is zero when every entry is zero; the 'ZV' case defers to the
  -- 'IZipVector' instance.
  isZero t = case t of
    ZV zv      -> isZero zv
    CT (CT' v) -> SV.foldl' step True v
      where step acc x = acc && isZero x

---------- "Container" instances ----------

instance Fact m => Functor (CT m) where
  -- 'fmap' is defined through the 'Applicative' instance, as the
  -- Applicative laws require
  fmap f = (pure f <*>)

-- | Applicative structure is provided by the 'IZipVector' representation:
-- 'pure' builds a 'ZV', and '<*>' converts any 'CT' operand to 'ZV' via
-- 'toZV' before applying.
instance Fact m => Applicative (CT m) where
  pure = ZV . pure

  (ZV f) <*> (ZV a) = ZV (f <*> a)
  f@(ZV _) <*> v@(CT _) = f <*> toZV v
  -- A 'CT' of functions is possible whenever a @Storable (a -> b)@ instance
  -- exists. Convert it as well, so that '<*>' is total rather than failing
  -- with an incomplete pattern match.
  f@(CT _) <*> v = toZV f <*> v

instance Fact m => Foldable (CT m) where
  -- folding is derived from the 'Traversable' instance
  foldMap f = foldMapDefault f

instance Fact m => Traversable (CT m) where
  -- Traversal happens in the 'ZV' representation; a 'CT' value is
  -- converted with 'toZV' first, and the result is always a 'ZV'.
  traverse f t = case t of
    ZV zv -> ZV <$> T.traverse f zv
    CT _  -> T.traverse f (toZV t)

-- | The 'Tensor' instance. Most methods lift a 'CT''-level function through
-- 'wrap' or 'wrapM', which convert a 'ZV' argument to storable form and
-- return a 'CT'. Element-type specifics are routed through the 'CRNS'
-- wrappers and the 'Dispatch' class.
instance Tensor CT where

  type TElt CT r = (IntegralDomain r, ZeroTestable r, 
                    Eq r, Random r, NFData r,
                    Storable r, CRNS r)

  entailIndexT = tag $ Sub Dict
  entailFullT = tag $ Sub Dict

  scalarPow = CT . scalarPow' -- Vector code

  -- 'l' and 'lInv' call the C routines bound to 'dl' / 'dlinv'
  l = wrap $ lgWrapper $ untag $ lgDispatch dl
  lInv = wrap $ lgWrapper $ untag $ lgDispatch dlinv

  mulGPow = wrap mulGPow' -- mulGPow' already has lgWrapper
  mulGDec = wrap $ lgWrapper $ untag $ lgDispatch dmulgdec

  -- divGPow' applies 'checkDiv' after the C call; divGDec does not
  divGPow = wrapM $ divGPow'
  -- we divide by p in the C code (for divGDec only(?)), do NOT call checkDiv!
  divGDec = wrapM $ divGWrapper $ Just . (untag $ lgDispatch dginvdec)

  -- The five components are combined in the Maybe applicative, so the
  -- whole tuple is Nothing if any component is Nothing.
  crtFuncs = (,,,,) <$>
    Just (CT . repl) <*>
    (liftM wrap $ crtWrapper $ (untag $ cZipDispatch dmul) <$> untagT gCoeffsCRT) <*>
    (liftM wrap $ crtWrapper $ (untag $ cZipDispatch dmul) <$> untagT gInvCoeffsCRT) <*>
    (liftM wrap $ untagT $ crt') <*>
    (liftM wrap $ crtWrapper $ untagT ctCRTInv) 

  -- vector-level functions (primed names) coerced to CT' signatures;
  -- 'runIdentity' discharges the trivial monad used by coerceTw/coerceEm
  twacePowDec = wrap $ runIdentity $ coerceTw twacePowDec'
  embedPow = wrap $ runIdentity $ coerceEm embedPow'
  embedDec = wrap $ runIdentity $ coerceEm embedDec'

  tGaussianDec v = liftM CT $ gaussWrapper $ cDispatchGaussian v
  --tGaussianDec v = liftM CT $ coerceT' $ gaussianDec v

  crtExtFuncs = (,) <$> (liftM wrap $ coerceTw twaceCRT')
                    <*> (liftM wrap $ coerceEm embedCRT')

  coeffs = wrapM $ coerceCoeffs $ coeffs'

  powBasisPow = (CT <$>) <$> coerceBasis powBasisPow'

  crtSetDec = (CT <$>) <$> coerceBasis crtSetDec'

  -- element-wise maps operate on the storable representation
  fmapT f (CT v) = CT $ coerce (SV.map f) v
  fmapT f v@(ZV _) = fmapT f $ toCT v

  fmapTM f (CT (CT' arr)) = liftM (CT . CT') $ SV.mapM f arr
  fmapTM f v@(ZV _) = fmapTM f $ toCT v

-- | Untag a vector-level function tagged by @'(m,m')@ and view it as a
-- function from 'CT'' at index @m'@ to 'CT'' at index @m@.
coerceTw :: (Functor mon) => (TaggedT '(m, m') mon (Vector r -> Vector r)) -> mon (CT' m' r -> CT' m r)
coerceTw tagged = coerce <$> untagT tagged

-- | Untag a vector-level function tagged by @'(m,m')@ and view it as a
-- function from 'CT'' at index @m@ to 'CT'' at index @m'@.
coerceEm :: (Functor mon) => (TaggedT '(m, m') mon (Vector r -> Vector r)) -> mon (CT' m r -> CT' m' r)
coerceEm tagged = coerce <$> untagT tagged

-- | Useful coercion for defining @coeffs@ in the @Tensor@
-- interface. Using 'coerce' alone is insufficient for type inference,
-- so the source and target types are pinned by this signature.
coerceCoeffs :: (Fact m, Fact m') 
  => Tagged '(m,m') (Vector r -> [Vector r]) -> CT' m' r -> [CT' m r]
coerceCoeffs tagged = coerce (untag tagged)

-- | Useful coercion for defining @powBasisPow@ and @crtSetDec@ in the @Tensor@
-- interface. Using 'coerce' alone is insufficient for type inference,
-- so the source and target types are pinned by this signature.
coerceBasis :: 
  (Fact m, Fact m')
  => Tagged '(m,m') ([Vector r]) -> Tagged m [CT' m' r]
coerceBasis basis = tag (coerce (untag basis))

-- | Class to dispatch tuples to the C backend.  In a different life,
-- the library used product-ring representation at the 'Cyclotomic'
-- level, so 'FastCyc' called 'Tensor'-level functions on each
-- component of the product ring. This class emulates that behavior
-- because making C handle arbitrary product rings seems difficult.
--
-- Each method takes a function that is polymorphic in the element type
-- @a@ (subject to @TElt CT a@ and @Dispatch a@) and produces the
-- corresponding function at element type @r@.
class CRNS r where

  -- | Wrap a binary (zip-style) operation on 'CT''.
  zipWrapper :: (Fact m) => 
    (forall a . (TElt CT a, Dispatch a) => CT' m a -> CT' m a -> CT' m a)
    -> CT' m r -> CT' m r -> CT' m r

  -- | Wrap an optional unary operation that additionally needs 'CRTrans'.
  crtWrapper :: (Fact m, CRTrans r) => 
    (forall a . (TElt CT a, CRTrans a, Dispatch a) => Maybe (CT' m a -> CT' m a))
    -> Maybe (CT' m r -> CT' m r)

  -- | Wrap a total unary operation on 'CT''.
  lgWrapper :: (Fact m) => 
    (forall a . (TElt CT a, Dispatch a) => CT' m a -> CT' m a)
    -> CT' m r -> CT' m r

  -- | Wrap a unary operation that may fail ('Nothing').
  divGWrapper :: (Fact m) => 
    (forall a . (TElt CT a, Dispatch a) => CT' m a -> Maybe (CT' m a))
    -> CT' m r -> Maybe (CT' m r)

  -- | Wrap a randomized generator of 'CT'' values.
  gaussWrapper :: (Fact m, MonadRandom rnd) => 
    (forall a . (TElt CT a, Dispatch a, OrdFloat a, MonadRandom rnd) => rnd (CT' m a))
    -> rnd (CT' m r)

-- Base case: the polymorphic argument is simply instantiated at 'Double'.
instance CRNS Double where
  zipWrapper f = f
  crtWrapper f = f
  lgWrapper f = f
  divGWrapper f = f
  gaussWrapper f = f

-- Base case: the polymorphic argument is instantiated at 'Int64'.
-- Gaussian sampling is not available at this type and calls 'error'.
instance CRNS Int64 where
  zipWrapper f = f
  crtWrapper f = f
  lgWrapper f = f
  divGWrapper f = f
  gaussWrapper = error "Cannot call gaussianDec for Int64"

-- Base case: the polymorphic argument is instantiated at @Complex a@.
-- Gaussian sampling is not available at this type and calls 'error'.
instance (TElt CT (Complex a), Dispatch (Complex a)) => CRNS (Complex a) where
  zipWrapper f = f
  crtWrapper f = f
  lgWrapper f = f
  divGWrapper f = f
  gaussWrapper = error "Cannot call gaussianDec for Complex"

-- EAC: we need PolyKinds in particular for this instance
-- Base case: the polymorphic argument is instantiated at @ZqBasic q i@.
-- Gaussian sampling is not available at this type and calls 'error'.
instance (TElt CT (ZqBasic q i), Dispatch (ZqBasic q i)) => CRNS (ZqBasic q i) where
  zipWrapper f = f
  crtWrapper f = f
  lgWrapper f = f
  divGWrapper f = f
  gaussWrapper = error "Cannot call gaussianDec for ZqBasic"

-- Product case: a vector of pairs is split into two component vectors with
-- 'unzip', the wrapped operation is applied to each component through the
-- component's own 'CRNS' instance, and the results are re-paired with 'zip'.
-- The pattern type signatures bind @m@ so the components can be annotated.
instance (Storable a, Storable b, CRNS a, CRNS b, CRTrans a, CRTrans b) 
  => CRNS (a,b) where
  zipWrapper f (CT' x :: CT' m (a,b)) (CT' y) =
    let (a,b) = unzip x
        (c,d) = unzip y
        (CT' ac) = zipWrapper f (CT' a :: CT' m a) (CT' c)
        (CT' bd) = zipWrapper f (CT' b :: CT' m b) (CT' d)
    in CT' $ zip ac bd

  -- Nothing if the operation is unavailable for either component
  crtWrapper f = do
    fa <- crtWrapper f
    fb <- crtWrapper f
    return $ \ (CT' x :: CT' m (a,b)) -> 
      let (a,b) = unzip x
          (CT' a') = fa (CT' a :: CT' m a)
          (CT' b') = fb (CT' b :: CT' m b)
      in CT' $ zip a' b'

  lgWrapper f (CT' x :: CT' m (a,b)) = 
    let (a, b) = unzip x
        (CT' a') = lgWrapper f (CT' a :: CT' m a)
        (CT' b') = lgWrapper f (CT' b :: CT' m b)
    in CT' $ zip a' b'

  -- Nothing if the operation fails on either component
  divGWrapper f (CT' x :: CT' m (a,b)) = 
    let (a, b) = unzip x
    in do -- in Maybe
      (CT' a') <- divGWrapper f (CT' a :: CT' m a)
      (CT' b') <- divGWrapper f (CT' b :: CT' m b)
      return $ CT' $ zip a' b'

  -- the two components are generated by two separate runs of the generator
  gaussWrapper f = do
    (CT' a) <- gaussWrapper f
    (CT' b) <- gaussWrapper f
    return $ CT' $ zip a b

-- | 'CT''-level @mulGPow@: the C routine bound to 'dmulgpow', routed
-- through 'lgWrapper'.
mulGPow' :: (TElt CT r, Fact m) => CT' m r -> CT' m r
mulGPow' = lgWrapper (untag (lgDispatch dmulgpow))

-- | 'CT''-level @divGPow@: the C routine bound to 'dginvpow', followed by
-- the divisibility check of 'checkDiv', routed through 'divGWrapper'.
divGPow' :: forall m r . (TElt CT r, Fact m) => CT' m r -> Maybe (CT' m r)
divGPow' = divGWrapper (untag (checkDiv (lgDispatch dginvpow)))

-- | 'CT''-level CRT: 'ctCRT' lifted from raw vectors to 'CT'' and routed
-- through 'crtWrapper'. 'Nothing' whenever 'ctCRT' is 'Nothing'.
crt' :: forall m r . (TElt CT r, Fact m, CRTrans r) 
  => TaggedT m Maybe (CT' m r -> CT' m r)
crt' = tagT $ crtWrapper $
  (\f -> CT' . f . unCT) <$> proxyT ctCRT (Proxy::Proxy m)

--{-# INLINE lgDispatch #-}
-- | Turn an in-place C routine into a pure function on 'CT''.
-- The routine receives: a pointer to a mutable copy of the input, the
-- totient of @m@, a pointer to the marshalled prime-power factors, and
-- the number of factors. The copy is frozen and returned.
lgDispatch :: forall m r .
     (Storable r, Fact m, Additive r)
      => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
         -> Tagged m (CT' m r -> CT' m r)
lgDispatch cfun = do -- in Tagged m
  cpps <- marshalFactors `liftM` ppsFact
  phi  <- fromIntegral `liftM` totientFact
  let nFacs = fromIntegral (SV.length cpps)
      runC :: Vector r -> Vector r
      runC input = unsafePerformIO $ do -- in IO
        -- 'SV.thaw' copies, so the input vector itself is never mutated
        buf <- SV.thaw input
        SM.unsafeWith buf $ \bufPtr ->
          SV.unsafeWith cpps $ \cppPtr ->
            cfun bufPtr phi cppPtr nFacs
        unsafeFreeze buf
  return (coerce runC)

--{-# INLINE ctCRT #-}
-- | The CRT as a pure function on raw vectors, implemented by the C
-- routine 'dcrt'. In addition to the arguments described at 'lgDispatch',
-- the routine receives an array of pointers to the tables built by 'ru'.
-- 'Nothing' when 'ru' is 'Nothing'.
ctCRT :: forall m r .
         (Storable r, CRTrans r, Dispatch r,
          Fact m)
         => TaggedT m Maybe (Vector r -> Vector r)
ctCRT = do -- in TaggedT m Maybe
  roots <- ru
  cpps  <- pureT (marshalFactors `liftM` ppsFact)
  phi   <- pureT (fromIntegral `liftM` totientFact)
  let nFacs = fromIntegral (SV.length cpps)
      runC :: Vector r -> Vector r
      runC input = unsafePerformIO $ do -- in IO
        -- work on a copy; the input vector is left untouched
        buf <- SV.thaw input
        SM.unsafeWith buf $ \bufPtr ->
          SV.unsafeWith cpps $ \cppPtr ->
            withPtrArray roots $ \rootPtrs ->
              dcrt bufPtr phi cppPtr nFacs rootPtrs
        unsafeFreeze buf
  return runC

-- CTensor CRT^(-1) functions take inverse rus
--{-# INLINE ctCRTInv #-}
-- | The inverse CRT on 'CT'', implemented by the C routine 'dcrtinv'.
-- The routine is handed the tables from 'ruInv' and the second component
-- of 'crtInfoFact'. 'Nothing' when either of those is 'Nothing'.
ctCRTInv :: (Storable r, CRTrans r, Dispatch r,
          Fact m)
         => TaggedT m Maybe (CT' m r -> CT' m r)
ctCRTInv = do -- in TaggedT m Maybe
  scale    <- snd `liftM` crtInfoFact
  invRoots <- ruInv
  cpps     <- pureT (marshalFactors `liftM` ppsFact)
  phi      <- pureT (fromIntegral `liftM` totientFact)
  let nFacs = fromIntegral (SV.length cpps)
      runC input = unsafePerformIO $ do -- in IO
        -- work on a copy; the input vector is left untouched
        buf <- SV.thaw input
        SM.unsafeWith buf $ \bufPtr ->
          SV.unsafeWith cpps $ \cppPtr ->
            withPtrArray invRoots $ \rootPtrs ->
              dcrtinv bufPtr phi cppPtr nFacs rootPtrs scale
        unsafeFreeze buf
  -- EAC: can't use coerce here? (unwrap/rewrap explicitly instead)
  return (CT' . runC . unCT)

-- | Post-process a 'CT'' function: every entry of its output is divided by
-- the value of 'oddRadicalFact' using 'divIfDivis'. The whole result is
-- 'Nothing' if any entry is not exactly divisible.
checkDiv :: forall m r . 
  (IntegralDomain r, Storable r, ZeroTestable r, 
   Fact m)
    => Tagged m (CT' m r -> CT' m r) -> Tagged m (CT' m r -> Maybe (CT' m r))
checkDiv tf = do -- in Tagged m
  g <- tf
  oddRad <- fromIntegral `liftM` oddRadicalFact
  return $ \x -> CT' <$> SV.mapM (`divIfDivis` oddRad) (unCT (g x))

-- | Exact division: @Just@ the quotient when the remainder of
-- @num `divMod` den@ is zero, and 'Nothing' otherwise.
divIfDivis :: (IntegralDomain r, ZeroTestable r) => r -> r -> Maybe r
divIfDivis num den =
  case num `divMod` den of
    (quo, rm) -> if isZero rm then Just quo else Nothing

-- | Turn an in-place binary C routine into a pure function on 'CT''.
-- The routine receives a pointer to a mutable copy of the first operand,
-- a pointer to the second operand, and the totient of @m@. The copy is
-- frozen and returned.
cZipDispatch :: (Storable r, Fact m, Additive r)
  => (Ptr r -> Ptr r -> Int64 -> IO ())
     -> Tagged m (CT' m r -> CT' m r -> CT' m r)
cZipDispatch cfun = do -- in Tagged m
  phi <- fromIntegral `liftM` totientFact
  let zipC lhs rhs = unsafePerformIO $ do -- in IO
        -- only the copy of the left operand is written to
        buf <- SV.thaw lhs
        SM.unsafeWith buf $ \outPtr ->
          SV.unsafeWith rhs $ \inPtr ->
            cfun outPtr inPtr phi
        unsafeFreeze buf
  return (coerce zipC)

-- | Sample a 'CT'' in the monad @rnd@ using the C routine 'dgaussdec'.
-- The input handed to C is a vector of @totm@ samples drawn with
-- 'realGaussians', using the given variance parameter scaled by
-- @m `div` rad@. Calls 'error' if 'ruInv' yields 'Nothing'.
cDispatchGaussian :: forall m r var rnd .
         (Storable r, Transcendental r, Dispatch r, Ord r,
          Fact m, ToRational var, Random r, MonadRandom rnd)
         => var -> rnd (CT' m r)
cDispatchGaussian var = liftM CT' $ flip proxyT (Proxy::Proxy m) $ do -- in TaggedT m rnd
  -- get rus for (Complex r)
  ruinv' <- mapTaggedT (return . fromMaybe (error "complexGaussianRoots")) $ ruInv
  factors <- liftM marshalFactors $ pureT ppsFact
  totm <- pureT totientFact
  m <- pureT valueFact
  rad <- pureT radicalFact
  -- the random draw happens in the underlying monad @rnd@
  yin <- lift $ realGaussians (var * fromIntegral (m `div` rad)) totm
  let numFacts = fromIntegral $ SV.length factors
  return $ unsafePerformIO $ do -- in IO
    --let yin = create $ SM.new totm :: Vector r -- contents will be overwritten, so no need to initialize
    -- the C routine works in place on a mutable copy of the samples
    yout <- SV.thaw yin
    SM.unsafeWith yout (\pout ->
      SV.unsafeWith factors (\pfac ->
       withPtrArray ruinv' (\ruptr ->
        dgaussdec pout (fromIntegral totm) pfac numFacts ruptr)))
    unsafeFreeze yout

-- | Random test values: every entry is generated independently with the
-- element type's generator. No shrinking is performed.
instance (Arbitrary r, Fact m, Storable r) => Arbitrary (CT' m r) where
  shrink = shrinkNothing
  arbitrary = replM arbitrary

-- | Random test values are always produced in the 'CT' representation.
instance (Storable r, Arbitrary (CT' m r)) => Arbitrary (CT m r) where
  arbitrary = fmap CT arbitrary

-- | Uniformly random tensors: every entry is drawn with the element
-- type's 'random'. Ranges are not supported.
instance (Storable r, Random r, Fact m) => Random (CT' m r) where
  randomR = error "randomR nonsensical for CT'"

  --{-# INLINABLE random #-}
  random = runRand (replM (liftRand random))

-- | Random tensors are always produced in the 'CT' representation.
-- Ranges are not supported.
instance (Storable r, Random (CT' m r)) => Random (CT m r) where
  randomR = error "randomR nonsensical for CT"

  --{-# INLINABLE random #-}
  random = runRand (CT `liftM` liftRand random)

-- | Deep evaluation delegates to the payload of whichever constructor
-- is present.
instance (NFData r) => NFData (CT m r) where
  rnf t = case t of
    ZV zv -> rnf zv
    CT v  -> rnf v

-- | A 'CT'' whose entries are all the given value; its length is the
-- value of 'totientFact' at index @m@.
repl :: forall m r . (Fact m, Storable r) => r -> CT' m r
repl = CT' . SV.replicate phi
  where phi = proxy totientFact (Proxy::Proxy m)

-- | Monadic 'repl': run the given action once per entry, where the number
-- of entries is the value of 'totientFact' at index @m@.
replM :: forall m r mon . (Fact m, Storable r, Monad mon) 
         => mon r -> mon (CT' m r)
replM = liftM CT' . SV.replicateM phi
  where phi = proxy totientFact (Proxy::Proxy m)

--{-# INLINE scalarPow' #-}
-- | Embed a scalar: the result has the scalar in entry 0 and 'zero' in
-- every other entry; its length is the value of 'totientFact' at index @m@.
-- (The signature previously quantified over two type variables that were
-- never used; they have been removed.)
scalarPow' :: forall m r .
  (Fact m, Additive r, Storable r)
  => r -> CT' m r
-- constant-term coefficient is first entry wrt powerful basis
scalarPow' = 
  -- the length is bound outside the lambda so it can be shared across calls
  let n = proxy totientFact (Proxy::Proxy m)
  in \r -> CT' $ generate n (\i -> if i == 0 then r else zero)

-- | Lookup tables passed to the C CRT routines: one vector per prime-power
-- factor @(p,e)@ of @m@, of length @p^e@. With @pow = m `div` p^e@ and
-- @wPow@ the first component of 'crtInfoFact', entry @i@ is @wPow (i*pow)@
-- for 'ru' and @wPow (-i*pow)@ for 'ruInv'.
-- 'Nothing' when 'crtInfoFact' is 'Nothing'.
ru, ruInv :: forall r m . 
   (CRTrans r, Fact m, Storable r)
   => TaggedT m Maybe [Vector r]
--{-# INLINE ru #-}
ru = do -- in TaggedT m Maybe
  mval <- pureT valueFact
  wPow <- fst `liftM` crtInfoFact
  let tableFor (p,e) = generate pp (\i -> wPow (i*pow))
        where pp  = p^e
              pow = mval `div` pp
  LP.map tableFor `liftM` pureT ppsFact

--{-# INLINE ruInv #-}
ruInv = do -- in TaggedT m Maybe
  mval <- pureT valueFact
  wPow <- fst `liftM` crtInfoFact
  let tableFor (p,e) = generate pp (\i -> wPow (-i*pow))
        where pp  = p^e
              pow = mval `div` pp
  LP.map tableFor `liftM` pureT ppsFact

-- | The CRT ('crt'') of @mulGPow' 1@ and of @divGPow' 1@ respectively,
-- where @1@ is @scalarPow' LP.one@. 'Nothing' when 'crt'' is 'Nothing'.
gCoeffsCRT, gInvCoeffsCRT :: (TElt CT r, CRTrans r, Fact m)
  => TaggedT m Maybe (CT' m r)
gCoeffsCRT = crt' <*> return gPow
  where gPow = mulGPow' (scalarPow' LP.one)
-- 'fromJust' is deliberate here. Without it, sequencing the functions in
-- 'crtFuncs' would rely on 'divGPow' having an implementation in C, which
-- is not true for all types that have a C implementation of, e.g., 'crt'.
-- In particular, 'Complex Double' has C support for 'crt', but not for
-- 'divGPow'.
-- This really breaks the contract of Tensor, so it's probably a bad idea:
--   someone can get the "crt" and can even pull the function "divGCRT"
--   from Tensor, but it will fail when they try to apply it.
-- Implementation note, should this ever be fixed: the division by rad(m)
-- can be tricky for Double/Complex Doubles, so be careful! This is why we
-- have a custom Complex wrapper around NP.Complex.
gInvCoeffsCRT = (\crt -> crt gInv) <$> crt'
  where gInv = fromJust $ divGPow' $ scalarPow' LP.one

-- We can't put this in Extension with the rest of the twace/embed functions
-- because it needs access to the C backend.
-- | Twace in the CRT representation, from index @m'@ down to @m@, as a
-- function on raw vectors. 'Nothing' if any of the CRT-dependent
-- ingredients ('gCoeffsCRT', 'gInvCoeffsCRT', 'embedCRT'', 'crtInfoFact')
-- is 'Nothing'.
twaceCRT' :: forall m m' r .
             (TElt CT r, CRTrans r, m `Divides` m')
             => TaggedT '(m, m') Maybe (Vector r -> Vector r)
twaceCRT' = tagT $ do -- Maybe monad
  (CT' g') <- proxyT gCoeffsCRT (Proxy::Proxy m')
  (CT' gInv) <- proxyT gInvCoeffsCRT (Proxy::Proxy m)
  embed <- proxyT embedCRT' (Proxy::Proxy '(m,m'))
  indices <- pure $ proxy extIndicesCRT (Proxy::Proxy '(m,m'))
  (_, m'hatinv) <- proxyT crtInfoFact (Proxy::Proxy m')
  let phi = proxy totientFact (Proxy::Proxy m)
      phi' = proxy totientFact (Proxy::Proxy m')
      mhat = fromIntegral $ proxy valueHatFact (Proxy::Proxy m)
      hatRatioInv = m'hatinv * mhat
      -- number of consecutive entries summed per output entry
      reltot = phi' `div` phi
      -- tweak = mhat * g' / (m'hat * g)
      tweak = SV.map (* hatRatioInv) $ SV.zipWith (*) (embed gInv) g'
  return $ \ arr -> -- take true trace after mul-by-tweak
    -- permute the tweaked input, then sum each block of reltot entries
    let v = backpermute' indices (SV.zipWith (*) tweak arr)
    in generate phi $ \i -> foldl1' (+) $ SV.unsafeSlice (i*reltot) reltot v












-- C-backend support

-- | Pack a list of prime powers @(p,e)@ into a storable vector of 'CPP'
-- records so it can be handed to C.
marshalFactors :: [PP] -> Vector CPP
marshalFactors pps =
  SV.fromList [ CPP (fromIntegral p) (fromIntegral e) | (p,e) <- pps ]

-- http://stackoverflow.com/questions/6517387/vector-vector-foo-ptr-ptr-foo-io-a-io-a
-- | Run an action on a temporary C array holding one data pointer per
-- vector in the list. Every underlying foreign pointer is touched after
-- the action returns, so the vectors stay alive for its duration.
withPtrArray :: (Storable a) => [Vector a] -> (Ptr (Ptr a) -> IO b) -> IO b
withPtrArray v f = do
  result <- withArray (LP.map getPtr fptrs) f
  LP.mapM_ touchForeignPtr fptrs
  return result
  where fptrs = LP.map (fst . SV.unsafeToForeignPtr0) v

-- | A prime power as passed to C: two strict fields, a 32-bit @p'@ and a
-- 16-bit @e'@. 'marshalFactors' fills them from a @(p,e)@ pair.
data CPP = CPP {p' :: !Int32, e' :: !Int16}
-- stolen from http://hackage.haskell.org/packages/archive/numeric-prelude/0.4.0.3/doc/html/src/Number-Complex.html#T
-- the NumericPrelude Storable instance for complex numbers
-- All four methods are derived from the record dictionary 'store'.
instance Storable CPP where
   sizeOf    = Store.sizeOf store
   alignment = Store.alignment store
   peek      = Store.peek store
   poke      = Store.poke store

-- | Storable-record dictionary for 'CPP': field @p'@ followed by
-- field @e'@.
store :: Store.Dictionary CPP
store = Store.run $
   CPP <$> Store.element p'
       <*> Store.element e'

-- | Renders as @(p,e)@.
instance Show CPP where
    show (CPP prime expo) =
      '(' : show prime LP.++ (',' : show expo) LP.++ ")"

-- Raw C entry points. They are not called directly elsewhere in this
-- module; the 'Dispatch' instances below bind them to class methods.
-- The ZqBasic ("Rq") variants take one extra trailing Int64, which the
-- 'Dispatch' instance for ZqBasic fills with the reflected value of @q@.

-- bound to 'dl' / 'dlinv'
foreign import ccall unsafe "tensorLR" tensorLR ::                  Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLInvR" tensorLInvR ::            Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLRq" tensorLRq ::                Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorLInvRq" tensorLInvRq ::          Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorLDouble" tensorLDouble ::       Ptr Double -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLInvDouble" tensorLInvDouble :: Ptr Double -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLC" tensorLC ::       Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorLInvC" tensorLInvC :: Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16          -> IO ()

-- bound to 'dmulgpow' / 'dmulgdec' / 'dginvpow' / 'dginvdec'
foreign import ccall unsafe "tensorGPowR" tensorGPowR ::         Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorGPowRq" tensorGPowRq ::       Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorGDecR" tensorGDecR ::         Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorGDecRq" tensorGDecRq ::       Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorGInvPowR" tensorGInvPowR ::   Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorGInvPowRq" tensorGInvPowRq :: Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
foreign import ccall unsafe "tensorGInvDecR" tensorGInvDecR ::   Ptr Int64 -> Int64 -> Ptr CPP -> Int16          -> IO ()
foreign import ccall unsafe "tensorGInvDecRq" tensorGInvDecRq :: Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Int64 -> IO ()
-- currently disabled; 'dgcrt' / 'dginvcrt' call 'error' in every instance
--foreign import ccall unsafe "tensorGCRTRq" tensorGCRTRq ::       Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Int64   -> IO ()
--foreign import ccall unsafe "tensorGCRTC" tensorGCRTC ::         Ptr (Complex Double) ->   Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO ()
--foreign import ccall unsafe "tensorGInvCRTRq" tensorGInvCRTRq :: Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Int64   -> IO ()
--foreign import ccall unsafe "tensorGInvCRTC" tensorGInvCRTC ::   Ptr (Complex Double) ->   Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO ()

-- bound to 'dcrt' / 'dcrtinv'
foreign import ccall unsafe "tensorCRTRq" tensorCRTRq ::         Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Int64 -> IO ()
foreign import ccall unsafe "tensorCRTC" tensorCRTC ::           Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO ()
foreign import ccall unsafe "tensorCRTInvRq" tensorCRTInvRq ::   Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Int64 -> Int64 -> IO ()
foreign import ccall unsafe "tensorCRTInvC" tensorCRTInvC ::     Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> Double -> IO ()

-- bound to 'dgaussdec' (Double only)
foreign import ccall unsafe "tensorGaussianDec" tensorGaussianDec :: Ptr Double -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) ->  IO ()

-- bound to 'dmul'
foreign import ccall unsafe "mulRq" mulRq :: Ptr (ZqBasic q Int64) -> Ptr (ZqBasic q Int64) -> Int64 -> Int64 -> IO ()
foreign import ccall unsafe "mulC" mulC :: Ptr (Complex Double) -> Ptr (Complex Double) -> Int64 -> IO ()

-- bound to 'dadd'
foreign import ccall unsafe "addRq" addRq :: Ptr (ZqBasic q Int64) -> Ptr (ZqBasic q Int64) -> Int64 -> Int64 -> IO ()
foreign import ccall unsafe "addR" addR :: Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()
foreign import ccall unsafe "addC" addC :: Ptr (Complex Double) -> Ptr (Complex Double) -> Int64 -> IO ()
foreign import ccall unsafe "addD" addD :: Ptr Double -> Ptr Double -> Int64 -> IO ()

-- | Class to safely match Haskell types with the appropriate C function.
--
-- For the unary transforms, the arguments are, in order: pointer to the
-- data (modified in place), totient of the index, pointer to the
-- marshalled prime-power factors, and number of factors. The CRT-style
-- methods additionally take an array of pointers to lookup tables.
-- Instances use 'error' for operations with no C implementation at that
-- type.
class Dispatch r where
  -- | CRT; the extra argument is the table array built by 'ru'.
  dcrt :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr r) -> IO ()
  -- | Inverse CRT; takes the table array built by 'ruInv' and a scalar.
  dcrtinv :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr r) -> r -> IO ()
  -- | Used for the 'Tensor' method @l@.
  dl :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  -- | Used for the 'Tensor' method @lInv@.
  dlinv :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  -- | Used for @mulGPow@.
  dmulgpow :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  -- | Used for @mulGDec@.
  dmulgdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  -- | Used for @divGPow@.
  dginvpow :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  -- | Used for @divGDec@.
  dginvdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
  -- | In-place addition: first pointer is read and written, second is the
  -- other operand, third argument is the totient.
  dadd :: Ptr r -> Ptr r -> Int64 -> IO ()
  -- | In-place multiplication, same argument layout as 'dadd'.
  dmul :: Ptr r -> Ptr r -> Int64 -> IO ()
  -- | Not implemented by any instance in this module.
  dgcrt :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr r) -> IO ()
  -- | Not implemented by any instance in this module.
  dginvcrt :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr r) -> IO ()
  -- | Gaussian sampling; the table array is over @Complex r@.
  dgaussdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex r)) -> IO ()

-- | Every supported method forwards to the corresponding @...Rq@ C routine,
-- appending the reflected modulus @q@ as the final argument.
instance (Reflects q Int64) => Dispatch (ZqBasic q Int64) where
  dcrt pout totm pfac numFacts ruptr =
    tensorCRTRq pout totm pfac numFacts ruptr (proxy value (Proxy::Proxy q))
  --EAC: GHC doesn't like it if I change the type of minv to ZqBasic in the
  -- signature of tensorCRTInvRq, and the constructor of ZqBasic isn't exposed
  -- so using unsafeCoerce for now
  dcrtinv pout totm pfac numFacts ruptr minv =
    tensorCRTInvRq pout totm pfac numFacts ruptr (unsafeCoerce minv)
                   (proxy value (Proxy::Proxy q))
  dl pout totm pfac numFacts =
    tensorLRq pout totm pfac numFacts (proxy value (Proxy::Proxy q))
  dlinv pout totm pfac numFacts =
    tensorLInvRq pout totm pfac numFacts (proxy value (Proxy::Proxy q))
  dmulgpow pout totm pfac numFacts =
    tensorGPowRq pout totm pfac numFacts (proxy value (Proxy::Proxy q))
  dmulgdec pout totm pfac numFacts =
    tensorGDecRq pout totm pfac numFacts (proxy value (Proxy::Proxy q))
  dginvpow pout totm pfac numFacts =
    tensorGInvPowRq pout totm pfac numFacts (proxy value (Proxy::Proxy q))
  dginvdec pout totm pfac numFacts =
    tensorGInvDecRq pout totm pfac numFacts (proxy value (Proxy::Proxy q))
  dadd aout bout totm =
    addRq aout bout totm (proxy value (Proxy::Proxy q))
  dmul aout bout totm =
    mulRq aout bout totm (proxy value (Proxy::Proxy q))
  -- the C routines tensorGCRTRq / tensorGInvCRTRq are currently disabled
  dgcrt _ _ _ _ _ = error "dgcrt zq"
    --tensorGCRTRq pout totm pfac numFacts gcoeffs' q
  dginvcrt _ _ _ _ _ = error "dginvcrt zq"
    --tensorGInvCRTRq pout totm pfac numFacts gcoeffs' q
  dgaussdec = error "cannot call CT gaussianDec on type ZqBasic"

-- | C support at @Complex Double@ covers CRT, inverse CRT, L, L^(-1),
-- addition and multiplication. Every other method calls 'error'.
instance Dispatch (Complex Double) where
  dcrt = tensorCRTC
  -- the C routine takes the scalar as a plain Double, so only the real
  -- part of @minv@ is passed
  dcrtinv pout totm pfac numFacts ruptr minv = 
    tensorCRTInvC pout totm pfac numFacts ruptr (real minv)
  dl = tensorLC
  dlinv = tensorLInvC
  dmulgpow = error "cannot call CT mulGPow on type Complex Double"
  dmulgdec = error "cannot call CT mulGDec on type Complex Double"
  dginvpow = error "cannot call CT divGPow on type Complex Double"
  dginvdec = error "cannot call CT divGDec on type Complex Double"
  dadd = addC
  dmul = mulC
  dgcrt = error "tensorGCRTC"
  dginvcrt = error "tensorGInvCRTC"
  -- message previously misspelled the type name as "Comple Double"
  dgaussdec = error "cannot call CT gaussianDec on type Complex Double"

-- | C support at @Double@ covers L, L^(-1), addition and Gaussian
-- sampling. Every other method calls 'error'.
instance Dispatch Double where
  -- supported operations
  dl = tensorLDouble
  dlinv = tensorLInvDouble
  dadd = addD
  dgaussdec = tensorGaussianDec

  -- unsupported operations
  dcrt = error "cannot call CT Crt on type Double"
  dcrtinv = error "cannot call CT CrtInv on type Double"
  dmulgpow = error "cannot call CT mulGPow on type Double"
  dmulgdec = error "cannot call CT mulGDec on type Double"
  dginvpow = error "cannot call CT divGPow on type Double"
  dginvdec = error "cannot call CT divGDec on type Double"
  dmul = error "cannot call CT (*) on type Double"
  dgcrt = error "cannot call CT mulGCRT on type Double"
  dginvcrt = error "cannot call CT divGCRT on type Double"

-- | C support at @Int64@ covers L, L^(-1), the four mul/div-by-g
-- operations, and addition. Every other method calls 'error'.
instance Dispatch Int64 where
  -- supported operations
  dl = tensorLR
  dlinv = tensorLInvR
  dmulgpow = tensorGPowR
  dmulgdec = tensorGDecR
  dginvpow = tensorGInvPowR
  dginvdec = tensorGInvDecR
  dadd = addR

  -- unsupported operations
  dcrt = error "cannot call CT Crt on type Int64"
  dcrtinv = error "cannot call CT CrtInv on type Int64"
  dmul = error "cannot call CT (*) on type Int64"
  dgcrt = error "cannot call CT mulGCRT on type Int64"
  dginvcrt = error "cannot call CT divGCRT on type Int64"
  dgaussdec = error "cannot call CT gaussianDec on type Int64"