{-# LANGUAGE ScopedTypeVariables, FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Database.SQLite.Simple.Function
    (
      Function
    , createFunction
    , deleteFunction
    ) where

import Control.Exception
import Data.Proxy
import Database.SQLite3 as Base hiding (createFunction,deleteFunction,funcArgText,funcResultText)
import qualified Database.SQLite3.Direct as Base
import Database.SQLite.Simple
import Database.SQLite.Simple.Internal (Field(..))
import Database.SQLite.Simple.ToField
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.Ok
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE

-- | Haskell values that can be registered as SQLite scalar functions.
--
-- Instances in this module cover values convertible with 'ToField',
-- 'IO' actions producing such values, and curried functions whose
-- arguments are decoded with 'FromField'.
class Function a where
  -- | Number of SQL arguments the registered function takes.
  argCount :: Proxy a -> Int

  -- | Flag handed to 'Base.createFunction' saying whether the function may
  -- be treated as deterministic.
  deterministicFn :: Proxy a -> Bool

  -- | Consume call arguments starting at the given index and store the
  -- final result in the function context.
  evalFunction
    :: Base.FuncContext  -- ^ context that receives the result
    -> Base.FuncArgs     -- ^ arguments of the current SQL call
    -> Int               -- ^ index of the next argument to consume
    -> a                 -- ^ the (possibly partially applied) value
    -> IO ()

-- | A value that takes no further arguments: it is converted with 'toField'
-- and stored as the SQL function's result. Pure values take no SQL
-- arguments and are reported as deterministic.
instance {-# OVERLAPPING #-} (ToField a) => Function a where
  argCount = const 0
  deterministicFn = const True
  evalFunction ctx _ _ a = case toField a of
    SQLInteger r -> Base.funcResultInt64 ctx r
    SQLFloat r   -> Base.funcResultDouble ctx r
    -- The Direct-API 'funcResultText' takes raw UTF-8 bytes wrapped in 'Utf8'.
    SQLText r    -> Base.funcResultText ctx $ Base.Utf8 $ TE.encodeUtf8 r
    SQLBlob r    -> Base.funcResultBlob ctx r
    SQLNull      -> Base.funcResultNull ctx

-- | An 'IO' action: run it, then evaluate whatever it returns. Because the
-- action may have side effects, the function is never reported as
-- deterministic.
--
-- The arity is that of the returned value. After running the action,
-- 'evalFunction' continues consuming arguments from the same index. A
-- hard-coded 0 here would register e.g. @IO (Int -> Int)@ with no SQL
-- arguments while still reading argument 0.
instance {-# OVERLAPPING #-} (Function a) => Function (IO a) where
  argCount = const $ argCount (Proxy :: Proxy a)
  deterministicFn = const False
  evalFunction ctx args ca action = action >>= evalFunction ctx args ca

-- | A curried function. Argument number @ca@ of the SQL call is decoded with
-- 'fromField', the function is applied to it, and evaluation continues with
-- argument @ca + 1@. The arity is one more than the result's arity, and
-- determinism is inherited from the result type.
instance {-# OVERLAPPING #-} forall f r . (Function r, FromField f) => Function (f -> r) where
  argCount = const $ argCount (Proxy :: Proxy r) + 1
  deterministicFn = const $ deterministicFn (Proxy :: Proxy r)
  evalFunction ctx args ca fn = do
    sqlv <- readFunctionArg args (Base.ArgCount ca)
    case fromField (Field sqlv ca) of
      Ok arg    -> evalFunction ctx args (ca + 1) (fn arg)
      -- Conversion failures are raised as exceptions. 'createFunction'
      -- wraps evaluation in a catch-all handler.
      Errors ex -> throwIO $ ManyErrors ex

-- | Read argument @ix@ of the current SQL function call as 'SQLData',
-- dispatching on the storage type SQLite reports for that argument.
readFunctionArg :: Base.FuncArgs -> Base.ArgCount -> IO SQLData
readFunctionArg args ix = do
  ct <- Base.funcArgType args ix
  case ct of
    Base.IntegerColumn -> SQLInteger <$> Base.funcArgInt64 args ix
    Base.FloatColumn   -> SQLFloat <$> Base.funcArgDouble args ix
    -- NOTE: 'TE.decodeUtf8' throws on malformed UTF-8 once the text is forced.
    Base.TextColumn    ->
      (\(Base.Utf8 b) -> SQLText (TE.decodeUtf8 b)) <$> Base.funcArgText args ix
    Base.BlobColumn    -> SQLBlob <$> Base.funcArgBlob args ix
    Base.NullColumn    -> pure SQLNull

-- | Register a Haskell value as an SQL scalar function named @fn@ on the
-- given connection.
--
-- The number of SQL arguments and the determinism flag are derived from the
-- type of @f@ via the 'Function' class. If evaluating the function throws
-- anything (for instance an argument that 'fromField' cannot convert), the
-- exception is discarded and the SQL call yields @NULL@.
createFunction :: forall f . Function f => Connection -> T.Text -> f -> IO (Either Base.Error ())
createFunction (Connection db) fn f =
  Base.createFunction db name (Just arity) deterministic callback
  where
    name = Base.Utf8 (TE.encodeUtf8 fn)
    arity = Base.ArgCount (argCount (Proxy :: Proxy f))
    deterministic = deterministicFn (Proxy :: Proxy f)
    -- Presumably every exception is caught so that none propagates into
    -- SQLite's C callback frame. NOTE(review): this also swallows
    -- asynchronous exceptions and hides the failure from the SQL caller.
    callback ctx args =
      evalFunction ctx args 0 f `catch` \(_ :: SomeException) ->
        Base.funcResultNull ctx

-- | Remove the SQL function named @fn@ from the connection.
--
-- NOTE(review): no argument count is passed ('Nothing'), which presumably
-- maps to nArg = -1 in @sqlite3_create_function@. SQLite keeps separate
-- registrations per argument count, and 'createFunction' registers with a
-- specific count. Verify that this call actually removes such functions.
deleteFunction :: Connection -> T.Text -> IO (Either Base.Error ())
deleteFunction (Connection db) fn =
  Base.deleteFunction db (Base.Utf8 (TE.encodeUtf8 fn)) Nothing