{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Test suite with two groups: parser tests that check the output of
-- 'parseSqlExpr', and execution tests that run interpolated statements
-- against a database reachable with the connection string @host=localhost@.
module Main where

import Control.Exception
import Data.Int
import Data.Text (Text)
import GHC.Generics (Generic)
import qualified Hasql.Connection as Hasql
import Hasql.Decoders (column)
import Hasql.Interpolate
import Hasql.Interpolate.Internal.TH
import qualified Hasql.Session as Hasql
import Language.Haskell.TH
import Test.Tasty
import Test.Tasty.HUnit
import Prelude

-- | Entry point: hands the whole test tree to tasty.
main :: IO ()
main = defaultMain tests

-- | Root of the test tree.
tests :: TestTree
tests =
  testGroup
    "Tests"
    [ parserTests,
      executionTests
    ]

-- | Tests that only call 'parseSqlExpr' and compare the result.
parserTests :: TestTree
parserTests =
  testGroup
    "parser"
    [ testCase "quote" testParseQuotes,
      testCase "comment" testParseComment,
      testCase "param" testParseParam
    ]

-- | Tests that execute statements through 'withLocalTransaction'.
executionTests :: TestTree
executionTests =
  testGroup
    "execution"
    [ testCase "basic" testBasic,
      testCase "composite test" testComposite,
      testCase "row" testRow,
      testCase "row generic" testRowGeneric
    ]

-- | Interpolation markers that appear inside quoted strings, quoted
-- identifiers, dollar-quoted strings and @E'...'@ strings are expected to come
-- back as literal quote elements: the expected 'SqlExpr' has no parameters.
testParseQuotes :: IO ()
testParseQuotes = do
  let expected = SqlExpr expectedSqlExpr [] [] 0
      expectedSqlExpr =
        [ Sbe'Quote "#{bonk}",
          Sbe'Sql " ",
          Sbe'Quote "^{z''onk}",
          Sbe'Sql " ",
          Sbe'Ident "#{k\"\"onk}",
          Sbe'Sql " ",
          Sbe'DollarQuote "tag" "#{kiplonk}",
          Sbe'Sql " ",
          Sbe'Cquote "newline \\n escaped \\'string\\'"
        ]
  parseSqlExpr "'#{bonk}' '^{z''onk}' \"#{k\"\"onk}\" $tag$#{kiplonk}$tag$ E'newline \\n escaped \\'string\\''"
    @?= Right expected

-- | Line comments and (nested) block comments are expected to be absent from
-- the parsed SQL fragments, while the surrounding text is kept.
testParseComment :: IO ()
testParseComment = do
  let expected = SqlExpr expectedSqlExpr [] [] 0
      expectedSqlExpr =
        [ Sbe'Sql "content ",
          Sbe'Sql "\nhello ",
          Sbe'Sql " world\n",
          Sbe'Sql " end\n"
        ]
      inputStr =
        unlines
          [ "content -- trailing comment",
            "hello /* / comment * */ world",
            "/* comment",
            "blerg /* nested comment */",
            "*/ end"
          ]
  parseSqlExpr inputStr @?= Right expected

-- | Two @#{...}@ splices are expected to parse into two 'Sbe'Param' elements
-- whose Haskell expressions are the variable @x@ and the integer literal @2@.
testParseParam :: IO ()
testParseParam = do
  let expected =
        SqlExpr
          [Sbe'Param, Sbe'Sql " ", Sbe'Param]
          [Pe'Exp (VarE (mkName "x")), Pe'Exp (LitE (IntegerL 2))]
          []
          0
  parseSqlExpr "#{x} #{2}" @?= Right expected

-- | Creates a table, inserts the three rows of @relation@ through 'toTable',
-- checks the affected-row count, then selects the rows with @x > 0@ and
-- compares them with the same filter applied in Haskell.
testBasic :: IO ()
testBasic =
  withLocalTransaction \conn -> do
    let relation :: [(Int64, Bool, Int64)]
        relation =
          [ (0, True, 5),
            (1, True, 6),
            (2, False, 7)
          ]
    createRes <- run conn [sql| create table hasql_interpolate_test(x int8, y boolean, z int8) |]
    createRes @?= ()
    RowsAffected insertRes <- run conn [sql| insert into hasql_interpolate_test (x,y,z) select * from ^{toTable relation} |]
    insertRes @?= 3
    selectRes <- run conn [sql| select x, y, z from hasql_interpolate_test where x > #{0 :: Int64} order by x |]
    selectRes @?= filter (\(x, _, _) -> x > 0) relation

-- | Decodes a single-column result whose values are @row(..)@ composites into
-- 'Point' values wrapped in 'OneColumn'.
testComposite :: IO ()
testComposite =
  withLocalTransaction \conn -> do
    let expected = [Point 0 0, Point 1 1]
    res <- run conn [sql| select * from (values (row(0,0)), (row(1,1)) ) as t |]
    res @?= map OneColumn expected

-- | Three-column row type with a hand-written 'DecodeRow' instance.
data T = T Int64 Bool Text
  deriving stock (Eq, Show)

instance DecodeRow T where
  decodeRow =
    T
      <$> column decodeField
      <*> column decodeField
      <*> column decodeField

-- | Decodes three-column rows through the hand-written 'DecodeRow' instance
-- of 'T'.
testRow :: IO ()
testRow =
  withLocalTransaction \conn -> do
    let expected = [T 0 True "foo", T 1 False "bar"]
    res <- run conn [sql| select * from (values (0,true,'foo'), (1,false,'bar') ) as t |]
    res @?= expected

-- | Decodes two-column rows through the derived 'DecodeRow' instance of
-- 'Point'.
testRowGeneric :: IO ()
testRowGeneric =
  withLocalTransaction \conn -> do
    let expected = [Point 0 0, Point 1 1]
    res <- run conn [sql| select * from (values (0,0), (1,1) ) as t |]
    res @?= expected

-- | Acquire a connection using the connection string @host=localhost@, issue
-- @begin@, run the callback, then issue @rollback@ and release the connection.
--
-- Both the rollback and the release are registered as bracket finalizers, so
-- they also run when the callback throws. A failure to connect, or a failure
-- of @begin@ / @rollback@, is reported with 'fail' on the shown error.
withLocalTransaction :: (Hasql.Connection -> IO a) -> IO a
withLocalTransaction k =
  bracket acquireConn Hasql.release \conn ->
    -- Neither the setup nor the teardown produces a value the body needs, so
    -- 'bracket_' is used rather than threading @()@ through 'bracket'.
    bracket_
      (control conn [sql| begin |])
      (control conn [sql| rollback |])
      (k conn)
  where
    acquireConn :: IO Hasql.Connection
    acquireConn = either (fail . show) pure =<< Hasql.acquire "host=localhost"

    -- Run a transaction-control statement that returns nothing.
    control :: Hasql.Connection -> Sql -> IO ()
    control conn stmt =
      Hasql.run (Hasql.statement () (interp False stmt)) conn
        >>= either (fail . show) pure

-- | Run an interpolated statement on the given connection and decode its
-- result. A session error is turned into an HUnit assertion failure carrying
-- the shown error.
run :: DecodeResult a => Hasql.Connection -> Sql -> IO a
run conn stmt =
  Hasql.run (Hasql.statement () (interp False stmt)) conn >>= \case
    Left err -> assertFailure ("Hasql statement unexpectedly failed with error: " <> show err)
    Right x -> pure x

-- | Pair of 'Int64' values. 'DecodeValue' is derived via 'CompositeValue' (used
-- by 'testComposite') and 'DecodeRow' is derived with the @anyclass@ strategy
-- (used by 'testRowGeneric').
data Point = Point Int64 Int64
  deriving stock (Generic, Eq, Show)
  deriving (DecodeValue) via CompositeValue Point
  deriving anyclass (DecodeRow)