{-# LANGUAGE ForeignFunctionInterface #-}
-- | Checks the median-of-three implementations from
-- "LLVM.DSL.Example.Median": they are compiled with 'Exec.compile' and the
-- resulting functions are called from Haskell and compared against
-- 'List.sort'.
module Test.LLVM.DSL.Example.Median where

import qualified LLVM.DSL.Example.Median as Median
import LLVM.DSL.Example.Median (MV)

import qualified LLVM.DSL.Execution as Exec
import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)

import qualified LLVM.Extra.Storable as Memory
import qualified LLVM.Extra.Multi.Vector as MVec
import qualified LLVM.Extra.Multi.Value as MV
import qualified LLVM.Core as LLVM

import Type.Data.Num.Decimal (D4)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.List as List
import Data.Int (Int32)

import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA3)

import Foreign (Ptr, peek, with, alloca)

import qualified Test.DocTest.Driver as DocTest

import System.IO.Unsafe (unsafePerformIO)


-- | Adapt a ternary 'Exp' function on 'Float' to a code generator over plain
-- LLVM values: each argument is wrapped with 'MV.Cons' before
-- 'Expr.unliftM3' is applied, and the result is unwrapped again.
unliftM3ExprFloat ::
   (Exp Float -> Exp Float -> Exp Float -> Exp Float) ->
   LLVM.Value Float -> LLVM.Value Float -> LLVM.Value Float ->
   LLVM.CodeGenFunction Float (LLVM.Value Float)
unliftM3ExprFloat f x y z =
   (\(MV.Cons res) -> res) <$>
      Expr.unliftM3 f (MV.Cons x) (MV.Cons y) (MV.Cons z)

-- | Same as 'unliftM3ExprFloat', but for 'Int32'.
unliftM3ExprInt32 ::
   (Exp Int32 -> Exp Int32 -> Exp Int32 -> Exp Int32) ->
   LLVM.Value Int32 -> LLVM.Value Int32 -> LLVM.Value Int32 ->
   LLVM.CodeGenFunction Int32 (LLVM.Value Int32)
unliftM3ExprInt32 f x y z =
   (\(MV.Cons res) -> res) <$>
      Expr.unliftM3 f (MV.Cons x) (MV.Cons y) (MV.Cons z)

-- | Adapt a ternary code generator on 'MV' 'Int32' to plain LLVM values by
-- wrapping the arguments with 'MV.Cons' and unwrapping the result.
unliftM3MVInt32 ::
   (MV Int32 -> MV Int32 -> MV Int32 -> LLVM.CodeGenFunction Int32 (MV Int32)) ->
   LLVM.Value Int32 -> LLVM.Value Int32 -> LLVM.Value Int32 ->
   LLVM.CodeGenFunction Int32 (LLVM.Value Int32)
unliftM3MVInt32 f x y z =
   (\(MV.Cons res) -> res) <$> f (MV.Cons x) (MV.Cons y) (MV.Cons z)

-- | LLVM value holding a pointer to a four-element 'Int32' vector.
type ValPtrV4Int32 = LLVM.Value (Ptr (LLVM.Vector D4 Int32))

-- | Adapt a ternary vector code generator to a pointer interface:
-- load the three inputs (in the order @aPtr@, @bPtr@, @cPtr@),
-- apply @f@, and store its result through @mPtr@.
unliftM3MVV4Int32 ::
   (MVec.T D4 Int32 -> MVec.T D4 Int32 -> MVec.T D4 Int32 ->
    LLVM.CodeGenFunction r (MVec.T D4 Int32)) ->
   ValPtrV4Int32 -> ValPtrV4Int32 -> ValPtrV4Int32 -> ValPtrV4Int32 ->
   LLVM.CodeGenFunction r ()
unliftM3MVV4Int32 f aPtr bPtr cPtr mPtr = do
   va <- Memory.load aPtr
   vb <- Memory.load bPtr
   vc <- Memory.load cPtr
   MVec.Cons vm <- f (MVec.Cons va) (MVec.Cons vb) (MVec.Cons vc)
   Memory.store vm mPtr


-- "dynamic" imports are the FFI mechanism for calling through a function
-- pointer; they serve as the 'Exec.Importer's for the compiled functions.

foreign import ccall safe "dynamic" derefMedian3Ptr ::
   Exec.Importer (Int32 -> Int32 -> Int32 -> IO Int32)

foreign import ccall safe "dynamic" derefMedian3V4Ptr ::
   Exec.Importer
      (Ptr (LLVM.Vector D4 Int32) -> Ptr (LLVM.Vector D4 Int32) ->
       Ptr (LLVM.Vector D4 Int32) -> Ptr (LLVM.Vector D4 Int32) -> IO ())

foreign import ccall safe "dynamic" derefMedian3FloatPtr ::
   Exec.Importer (Float -> Float -> Float -> IO Float)


-- | Compile all median variants into one module named @"median"@ and test
-- them:
--
-- * the 'Float' and every 'Int32' variant must map @3 1 4@ to @3@ and,
--   as a property, return the middle element of the sorted arguments;
--
-- * the vector variant is run on one fixed input and its result is only
--   printed, not checked.
run :: DocTest.T ()
run = do
   let int32Variants =
          ("median3IfThen",  unliftM3ExprInt32 Median.median3IfThen) :
          ("median3Select",  unliftM3ExprInt32 Median.median3Select) :
          ("median3SelectS", unliftM3ExprInt32 Median.median3SelectShared) :
          ("median3MinMax",  unliftM3ExprInt32 Median.median3MinMax) :
          ("median3Case",    unliftM3MVInt32 Median.median3Case) :
          ("median3CaseVec", unliftM3MVInt32 Median.median3CaseVec) :
          []
       int32Names = map fst int32Variants
       int32Gens =
          map
             (\(name, gen) -> Exec.createFunction derefMedian3Ptr name gen)
             int32Variants

   (medianFloat, medianVector, medianFuncs) <-
      liftIO $ Exec.compile "median" $
         liftA3 (,,)
            (Exec.createFunction derefMedian3FloatPtr "median3MinMaxFloat"
               (unliftM3ExprFloat Median.median3MinMax))
            (Exec.createFunction derefMedian3V4Ptr "median3MinMaxVector"
               (unliftM3MVV4Int32 Median.median3MinMaxVector))
            (Trav.sequenceA int32Gens)

   let check expected m = do
          DocTest.printPrefix (show m ++ " ")
          DocTest.property $ m == expected
       -- Fixed example first, then the labelled property.
       -- NOTE(review): the use of unsafePerformIO assumes that calling the
       -- compiled median code has no observable side effects.
       checkMedian name median = do
          check 3 =<< liftIO (median 3 1 4)
          DocTest.printPrefix (name ++ ": ")
          DocTest.property $ \a b c ->
             unsafePerformIO (median a b c) == List.sort [a, b, c] !! 1

   checkMedian "medianFloat" medianFloat

   -- If median3MinMaxVector computes element-wise medians, the printed
   -- vector should be 3, 7, 4, 2 (TODO confirm; it is not checked here).
   liftIO $
      alloca $ \outPtr ->
      with (LLVM.consVector 3 1 4 1) $ \xPtr ->
      with (LLVM.consVector 2 7 1 8) $ \yPtr ->
      with (LLVM.consVector 5 7 7 2) $ \zPtr -> do
         medianVector xPtr yPtr zPtr outPtr
         print =<< peek outPtr

   Fold.sequence_ $ zipWith checkMedian int32Names medianFuncs