module Numeric.Netlib.CArray.Utility where

import qualified Data.Array.CArray as CArray
import Data.Array.IOCArray (IOCArray, withIOCArray)
import Data.Array.CArray (CArray, withCArray, Ix)

import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.Marshal.Array as Array
import qualified Foreign.Marshal.Alloc as Alloc
import qualified Foreign.C.String as CStr
import qualified Foreign.C.Types as C
import Foreign.Storable.Complex ()
import Foreign.Storable (Storable, peek)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT))
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative ((<$>))

import Data.Complex (Complex)


-- | Monad in which the marshalling helpers of this module run.
--
-- It is 'ContT' over 'IO'.  The helpers below wrap bracket-style
-- @with...@ / @alloca...@ functions in 'ContT', so a pointer obtained in
-- this monad stays valid until the final continuation handed to
-- 'runContT' returns.
type FortranIO r = ContT r IO

-- | Run the marshalling steps of a 'FortranIO' computation, then lift the
-- 'IO' action it produced into the same computation and execute it.
run :: FortranIO r (IO a) -> FortranIO r a
run act = do
   call <- act
   liftIO call

-- | Like 'run', but for actions that still expect a pointer to an
-- integer status out-argument.
--
-- Scratch storage for one 'C.CInt' is allocated (not initialised) and
-- passed as that remaining argument.  After the action has run, the
-- value it left there is read.  If that value is non-zero, an error
-- tagged with @name@ is raised.  Otherwise the action's result is
-- returned.
runChecked :: String -> FortranIO r (Ptr C.CInt -> IO a) -> FortranIO r a
runChecked name act = do
   infoPtr <- alloca
   result <- act >>= \call -> liftIO (call infoPtr)
   liftIO (check name (peek infoPtr))
   return result

-- | Execute an action that yields a status code.
--
-- A zero code is accepted silently.  Any other code raises an error
-- whose message is @msg@, a colon and the code.
check :: String -> IO C.CInt -> IO ()
check msg f =
   f >>= \code -> when (code /= 0) (error (msg ++ ": " ++ show code))

-- | Do nothing when @success@ holds.  Otherwise raise an error whose
-- message is @"assertion failed: "@ followed by @msg@.
assert :: String -> Bool -> IO ()
assert msg success =
   if success
      then return ()
      else error ("assertion failed: " ++ msg)

-- | Accept a message and an 'Int', discard both, and do nothing.
--
-- NOTE(review): presumably a stand-in where a dimension check is
-- deliberately skipped; confirm against the callers.
ignore :: String -> Int -> IO ()
ignore = \_ _ -> return ()


-- | Create a 'CArray' with the given bounds.
--
-- The fill action passed to 'CArray.createCArray' does nothing, so the
-- elements are not initialised here.  This is presumably meant for
-- output buffers that a foreign routine overwrites; confirm at call
-- sites.
newArray :: (Ix i, Storable e) => (i, i) -> IO (CArray i e)
newArray bnds = CArray.createCArray bnds (const (return ()))

-- | One-dimensional array of @m@ elements, indexed from 0, created via
-- 'newArray' (elements not initialised).
newArray1 :: (Storable e) => Int -> IO (CArray Int e)
newArray1 m = newArray bnds
   where
      bnds = (0, m - 1)

-- | @m@-by-@n@ array with both indices starting at 0, created via
-- 'newArray' (elements not initialised).
newArray2 :: (Storable e) => Int -> Int -> IO (CArray (Int,Int) e)
newArray2 m n = newArray (origin, (lastRow, lastCol))
   where
      origin = (0, 0)
      lastRow = m - 1
      lastCol = n - 1

-- | @m@-by-@n@-by-@k@ array with all indices starting at 0, created via
-- 'newArray' (elements not initialised).
newArray3 :: (Storable e) => Int -> Int -> Int -> IO (CArray (Int,Int,Int) e)
newArray3 m n k = newArray (origin, (lastI, lastJ, lastK))
   where
      origin = (0, 0, 0)
      lastI = m - 1
      lastJ = n - 1
      lastK = k - 1


-- | Number of indices in a one-dimensional bounds pair.
sizes1 :: (Ix i) => (i,i) -> Int
sizes1 bnds = CArray.rangeSize bnds

-- | Extent along each axis of two-dimensional bounds, returned as
-- (first-index count, second-index count).
sizes2 :: (Ix i, Ix j) => ((i,j),(i,j)) -> (Int,Int)
sizes2 bnds =
   case bnds of
      ((rowLo, colLo), (rowHi, colHi)) ->
         (CArray.rangeSize (rowLo, rowHi), CArray.rangeSize (colLo, colHi))

-- | Extent along each axis of three-dimensional bounds, in index order.
sizes3 :: (Ix i, Ix j, Ix k) => ((i,j,k),(i,j,k)) -> (Int,Int,Int)
sizes3 bnds =
   case bnds of
      ((lo1, lo2, lo3), (hi1, hi2, hi3)) ->
         ( CArray.rangeSize (lo1, hi1)
         , CArray.rangeSize (lo2, hi2)
         , CArray.rangeSize (lo3, hi3)
         )


-- | Marshal an 'Int' into temporary 'C.CInt' storage and yield a pointer
-- to it.  The pointer is valid for the rest of the 'FortranIO'
-- computation.
--
-- A value outside the range of 'C.CInt' raises an error instead of being
-- silently wrapped by 'fromIntegral'.  Without this check, a wrapped
-- (possibly negative) integer would reach the foreign routine, for
-- example as an array dimension.
cint :: Int -> FortranIO r (Ptr C.CInt)
cint n
   | fits = ContT $ Marshal.with (fromIntegral n)
   | otherwise =
      error $ "cint: value " ++ show n ++ " does not fit into a C int"
   where
      -- Compare in 'Integer' so the check itself cannot overflow.
      wide = toInteger n
      fits =
         wide >= toInteger (minBound :: C.CInt) &&
         wide <= toInteger (maxBound :: C.CInt)

-- | Marshal the number of indices in an 'Int' bounds pair via 'cint'.
range :: (Int,Int) -> FortranIO r (Ptr C.CInt)
range bnds = cint (CArray.rangeSize bnds)

-- | Uninitialised storage for one 'Storable' value, valid for the rest
-- of the 'FortranIO' computation.
alloca :: (Storable a) => FortranIO r (Ptr a)
alloca = ContT $ \withPtr -> Alloc.alloca withPtr

-- | Uninitialised storage for @n@ consecutive 'Storable' values, valid
-- for the rest of the 'FortranIO' computation.
allocaArray :: (Storable a) => Int -> FortranIO r (Ptr a)
allocaArray n = ContT (Array.allocaArray n)

-- | Marshal a 'Bool' into temporary storage using its 'Storable'
-- instance and yield a pointer to it.
bool :: Bool -> FortranIO r (Ptr Bool)
bool b = ContT (Marshal.with b)

-- | Marshal a 'Char' as a single 'C.CChar' and yield a pointer to it.
-- 'CStr.castCharToCChar' is only faithful for the first 256 code points.
char :: Char -> FortranIO r (Ptr C.CChar)
char c = ContT (Marshal.with (CStr.castCharToCChar c))

-- | Marshal a 'String' as a temporary NUL-terminated C string via
-- 'CStr.withCString', valid for the rest of the computation.
string :: String -> FortranIO r (Ptr C.CChar)
string s = ContT (CStr.withCString s)

-- | Marshal a 'Float' into temporary storage and yield a pointer to it.
float :: Float -> FortranIO r (Ptr Float)
float x = ContT (Marshal.with x)

-- | Marshal a 'Double' into temporary storage and yield a pointer to it.
double :: Double -> FortranIO r (Ptr Double)
double x = ContT (Marshal.with x)

-- | Marshal a single-precision complex number into temporary storage.
-- The 'Storable' instance for 'Complex' comes from the instance-only
-- import of @Foreign.Storable.Complex@.
complexFloat :: Complex Float -> FortranIO r (Ptr (Complex Float))
complexFloat z = ContT (Marshal.with z)

-- | Marshal a double-precision complex number into temporary storage.
-- The 'Storable' instance for 'Complex' comes from the instance-only
-- import of @Foreign.Storable.Complex@.
complexDouble :: Complex Double -> FortranIO r (Ptr (Complex Double))
complexDouble z = ContT (Marshal.with z)


-- | Pointer to the element buffer of a 'CArray' (via 'withCArray'),
-- valid for the rest of the 'FortranIO' computation.
array :: (Storable a) => CArray i a -> FortranIO r (Ptr a)
array arr = ContT (withCArray arr)

-- | Like 'array', but also return the array's bounds alongside the
-- pointer.
arrayBounds :: (Storable a, Ix i) => CArray i a -> FortranIO r (Ptr a, (i,i))
arrayBounds v = (\ptr -> (ptr, CArray.bounds v)) <$> array v

-- | Pointer to the element buffer of a mutable 'IOCArray' (via
-- 'withIOCArray'), valid for the rest of the 'FortranIO' computation.
ioarray :: (Storable a) => IOCArray i a -> FortranIO r (Ptr a)
ioarray arr = ContT (withIOCArray arr)


-- | Split two-dimensional bounds into separate bounds for each axis:
-- the first-index range and the second-index range.
unzipBounds :: ((i,j),(i,j)) -> ((i,i), (j,j))
unzipBounds bnds =
   case bnds of
      ((iLo, jLo), (iHi, jHi)) -> ((iLo, iHi), (jLo, jHi))

-- | Non-negative integer power with the exponent's type fixed to 'Int'.
-- Plain '(^)' accepts any 'Integral' exponent.  Like '(^)', a negative
-- exponent is an error.
(^!) :: (Num a) => a -> Int -> a
(^!) = (^)