{-# LANGUAGE OverloadedStrings #-}

-- | Implementation of an Sqlite-based 'Disk'.
module System.Mem.Disk.Sqlite where

import Control.Concurrent
    ( forkIO, killThread, ThreadId )
import Control.Exception
    ( bracket, finally )
import Control.Monad
    ( forever, when )
import Data.ByteString
    ( ByteString )
import Data.Int
    ( Int64 )
import Data.IORef

import qualified Control.Concurrent.STM as STM
import qualified Data.Text as T
import qualified Database.SQLite3 as Sql
import qualified System.IO.Error as Sys
import qualified System.Directory as Sys

import qualified System.Mem.Disk.DiskApi as DiskApi

{-----------------------------------------------------------------------------
    File system helpers
------------------------------------------------------------------------------}
-- | Run an action for a file path that must not exist yet, and delete the
-- file afterwards.
--
-- Throws an 'IOError' (see 'throwIfAlreadyExists') before running the
-- action if a file already exists at the path, so an existing file is
-- never taken over or deleted.
--
-- The file is expected to be created by the action itself. On exit the
-- file is removed. If it does not exist at that point (for example because
-- the action failed before creating it), removal is skipped silently, so
-- that the cleanup does not replace the exception raised by the action.
-- Any other error from 'Sys.removeFile' is rethrown.
withFile :: FilePath -> IO a -> IO a
withFile path action =
    bracket
        (throwIfAlreadyExists path)
        (\_ -> removeIfExists)
        (\_ -> action)
  where
    removeIfExists =
        Sys.removeFile path `Sys.catchIOError` \err ->
            if Sys.isDoesNotExistError err
                then pure ()
                else Sys.ioError err

-- | Throw an 'IOError' that indicates that the file already exists.
--
-- The error has type 'Sys.alreadyExistsErrorType', the location
-- @"Creating 'Disk'"@, and carries the offending path. Does nothing if no
-- file exists at the path.
throwIfAlreadyExists :: FilePath -> IO ()
throwIfAlreadyExists path = do
    exists <- Sys.doesFileExist path
    when exists $ Sys.ioError alreadyExists
  where
    alreadyExists =
        Sys.mkIOError
            Sys.alreadyExistsErrorType
            "Creating 'Disk'"
            Nothing
            (Just path)

{-----------------------------------------------------------------------------
    Disk
------------------------------------------------------------------------------}
-- | State of an Sqlite-backed disk.
data Disk = Disk
    { path    :: FilePath
      -- ^ Location of the database file.
    , chan    :: STM.TChan (Cmd STM.TMVar)
      -- ^ Commands waiting to be executed by the worker thread.
    , counter :: IORef Int64
      -- ^ The index that the next 'put_' will hand out.
    }

-- | Obtain the size of the 'Disk' in bytes.
-- Here, this is the size of the database file at 'path'.
getDiskSize_ :: Disk -> IO Integer
getDiskSize_ disk = Sys.getFileSize (path disk)

-- | Create a new file and use it for storing
-- t'System.Mem.Disk.DiskBytes'.
--
-- Throw an error if the file already exists,
-- delete the file after use.
--
-- Resources are acquired in this order and released in reverse order:
-- the file check/removal ('withFile'), the database connection, the
-- prepared statements ('withSql'), and the worker thread ('withThread').
withDiskSqlite :: FilePath -> (DiskApi.Disk -> IO a) -> IO a
withDiskSqlite file action =
    withFile file $
        bracket (Sql.open (T.pack file)) Sql.close $ \db ->
            withSql db $ \sql ->
                withThread sql $ \queue -> do
                    -- Indices handed out by 'put_' start at zero.
                    next <- newIORef 0
                    action (mkDiskApi (Disk file queue next))

-- | Expose the operations of an Sqlite 'Disk' through the generic
-- 'DiskApi.Disk' record.
mkDiskApi :: Disk -> DiskApi.Disk
mkDiskApi disk =
    DiskApi.Disk
        { DiskApi.put         = put_ disk
        , DiskApi.get         = get_ disk
        , DiskApi.delete      = delete_ disk
        , DiskApi.getDiskSize = getDiskSize_ disk
        }

{-----------------------------------------------------------------------------
    Disk operations
------------------------------------------------------------------------------}
-- | Operations to be performed on the database.
--
-- The parameter @cont@ is the container through which a result is handed
-- back to the caller; this module instantiates it with 'STM.TMVar'.
data Cmd cont
    = Put !Int64 ByteString
      -- ^ Store the bytes under the given index.
    | Get !Int64 (cont ByteString)
      -- ^ Read the bytes at the given index and deliver them via @cont@.
    | Delete !Int64
      -- ^ Remove the entry with the given index.

-- | Allocate a fresh index and queue the bytes for storage under it.
--
-- Returns the index as soon as the 'Put' command is queued. The worker
-- thread performs the insertion later, in queue order.
put_ :: Disk -> ByteString -> IO Int64
put_ disk bytes = do
    ix <- atomicModifyIORef' (counter disk) $ \n -> (n + 1, n)
    STM.atomically . STM.writeTChan (chan disk) $ Put ix bytes
    pure ix

-- | Fetch the bytes stored at an index.
--
-- Queues a 'Get' command carrying a fresh, empty reply slot and then
-- blocks until the worker thread fills that slot.
get_ :: Disk -> Int64 -> IO ByteString
get_ disk ix = do
    reply <- STM.newEmptyTMVarIO
    let request = Get ix reply
    STM.atomically (STM.writeTChan (chan disk) request)
    STM.atomically (STM.takeTMVar reply)

-- | Queue the removal of the entry at an index; returns without waiting
-- for the worker thread to perform it.
delete_ :: Disk -> Int64 -> IO ()
delete_ disk = STM.atomically . STM.writeTChan (chan disk) . Delete

{-----------------------------------------------------------------------------
    Worker Thread
------------------------------------------------------------------------------}
-- | Worker thread for sequencing SQL commands.
--
-- Starts the worker, passes its command channel to the action, and kills
-- the worker when the action finishes, whether normally or by exception.
withThread :: SqlCmds -> (STM.TChan (Cmd STM.TMVar) -> IO a) -> IO a
withThread sql action =
    bracket (mkDatabaseThread sql) stop (action . snd)
  where
    stop (worker, _) = killThread worker

-- | Fork a thread that reads commands from a new channel, one at a time,
-- and executes each with 'cmdSql'. Returns the thread and the channel.
--
-- No exception handler is installed, so an exception thrown by 'cmdSql'
-- ends the worker loop.
mkDatabaseThread :: SqlCmds -> IO (ThreadId, STM.TChan (Cmd STM.TMVar))
mkDatabaseThread sql = do
    queue <- STM.newTChanIO
    worker <- forkIO $ forever $ do
        next <- STM.atomically (STM.readTChan queue)
        cmdSql sql next
    pure (worker, queue)

{-----------------------------------------------------------------------------
    Sql
------------------------------------------------------------------------------}
-- | Prepared statements used by 'cmdSql', one per 'Cmd' constructor.
data SqlCmds = SqlCmds
    { sput_    :: Sql.Statement
      -- ^ Insert a row; parameters @?1@ (index) and @?2@ (bytes).
    , sget_    :: Sql.Statement
      -- ^ Select the bytes of the row whose index is @?1@.
    , sdelete_ :: Sql.Statement
      -- ^ Delete the row whose index is @?1@.
    }

-- | Create the table and prepare the statements ('initSql'), run the
-- action with them, and finalize them afterwards ('finalizeSql').
withSql :: Sql.Database -> (SqlCmds -> IO a) -> IO a
withSql db action = bracket (initSql db) finalizeSql action

-- | Create the table @db@ (integer primary key @ix@, blob @bytes@) and
-- prepare the insert, select and delete statements on it.
initSql :: Sql.Database -> IO SqlCmds
initSql db = do
    Sql.exec db "CREATE TABLE db ( ix INTEGER PRIMARY KEY, bytes BLOB );"
    -- The statements are prepared left to right: put, get, delete.
    SqlCmds
        <$> prepare "INSERT INTO db VALUES (?1,?2);"
        <*> prepare "SELECT bytes FROM db WHERE ix = ?1;"
        <*> prepare "DELETE FROM db WHERE ix = ?1;"
  where
    prepare = Sql.prepare db

-- | Finalize the prepared statements in the order put, get, delete.
finalizeSql :: SqlCmds -> IO ()
finalizeSql cmds = do
    Sql.finalize (sput_ cmds)
    Sql.finalize (sget_ cmds)
    Sql.finalize (sdelete_ cmds)

-- | Execute a single 'Cmd' with the prepared statements in 'SqlCmds'.
--
-- * 'Put' inserts the bytes under the given index.
-- * 'Get' selects the bytes at the given index and puts them into the
--   command's 'STM.TMVar'. This happens only after the statement has been
--   reset.
-- * 'Delete' removes the row with the given index.
--
-- Every statement is reset and has its bindings cleared after use. This
-- also happens when binding or stepping throws, so the statement is never
-- left mid-evaluation.
--
-- An unexpected step result (no row for 'Get', a result row for 'Delete')
-- raises an 'IOError' via 'userError' that names the command and index.
-- Exceptions are not caught here; they propagate to the caller.
cmdSql :: SqlCmds -> Cmd STM.TMVar -> IO ()
cmdSql sql cmd = case cmd of
    Put index bytes ->
        withStatement (sput_ sql) [Sql.SQLInteger index, Sql.SQLBlob bytes] $
            \s -> do
                _ <- Sql.stepNoCB s
                pure ()

    Get index k -> do
        bytes <- withStatement (sget_ sql) [Sql.SQLInteger index] $ \s -> do
            result <- Sql.stepNoCB s
            case result of
                Sql.Row  -> Sql.columnBlob s 0
                Sql.Done -> failCmd "Get" index "no entry with this index"
        STM.atomically $ STM.putTMVar k bytes

    Delete index ->
        withStatement (sdelete_ sql) [Sql.SQLInteger index] $ \s -> do
            result <- Sql.stepNoCB s
            case result of
                Sql.Done -> pure ()
                Sql.Row  -> failCmd "Delete" index "unexpected result row"
  where
    -- Bind the parameters, run the statement, and always reset it after.
    withStatement
        :: Sql.Statement -> [Sql.SQLData] -> (Sql.Statement -> IO b) -> IO b
    withStatement s params use =
        (Sql.bind s params >> use s) `finally` reset s

    reset :: Sql.Statement -> IO ()
    reset s = Sql.reset s >> Sql.clearBindings s

    failCmd :: String -> Int64 -> String -> IO b
    failCmd name index msg =
        ioError $ userError $
            "cmdSql: " ++ name ++ " " ++ show index ++ ": " ++ msg