{-# LANGUAGE GADTs #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE OverloadedStrings #-} module Database.TDS.Query where import qualified Database.TDS.Proto as Proto import Database.TDS.Types import Control.Exception ( Exception, SomeException(..) , bracket, onException , throwIO, catch, mask ) import Data.Bifunctor import Data.Bits import qualified Data.ByteString.Streaming as SBS import qualified Data.ByteString.Internal as IBS import Data.Foldable import Data.Maybe import Data.Ratio import qualified Data.Text as T import qualified Data.Text.Encoding as TE import qualified Data.Text.IO as T import Data.Word import Debug.Trace import Foreign.ForeignPtr import Foreign.Ptr import Foreign.Storable import Numeric import System.IO import qualified Streaming as S import qualified Streaming.Prelude as S newtype MsSQLRuntimeError = MsSQLRuntimeError T.Text deriving Show instance Exception MsSQLRuntimeError withTransaction :: Connection -> IO a -> IO a withTransaction conn go = mask $ \unmask -> do beginTransaction conn res <- (unmask (fmap Right go <* commitTransaction conn) `catch` (\e@(MsSQLRuntimeError {}) -> pure (Left e))) `catch` (\e@(SomeException {}) -> hPutStrLn stderr ("Caught exception " ++ show e) >> rollbackTransaction conn >> throwIO e) case res of Right x -> pure x Left err -> throwIO err execNoRows :: Connection -> T.Text -> IO () execNoRows tds q = do let sql = Proto.mkPacket (Proto.mkPacketHeader Proto.SQLBatch mempty) q getRes <- tdsSendPacket tds sql ResponseResultReceived (Proto.RowResults rows) <- getRes res <- S.inspect rows case res of Left () -> pure () Right {} -> throwIO (MsSQLRuntimeError (T.pack ("Expected no rows for statement: " ++ T.unpack q))) beginTransaction, commitTransaction, rollbackTransaction :: Connection -> IO () beginTransaction tds = execNoRows tds "BEGIN TRANSACTION" commitTransaction tds = execNoRows tds "COMMIT TRANSACTION" rollbackTransaction tds = execNoRows tds "ROLLBACK TRANSACTION" query :: Connection -> T.Text -> IO () query conn sqlTxt = do let sqlBatch = Proto.mkPacket (Proto.mkPacketHeader Proto.SQLBatch mempty) sqlTxt getRes <- tdsSendPacket conn sqlBatch ResponseResultReceived (Proto.RowResults rows) <- getRes S.mapsM_ (\(S.Compose (cols S.:> rows')) -> do S.mapsM_ (\row -> S.mapsM_ (\(Proto.RawColumn columnType columnData next) -> do bs' <- S.liftIO (printColumn columnType columnData) pure (next bs')) row) rows') rows pure () take8 :: Monad m => SBS.ByteString m () -> m (Word8, SBS.ByteString m ()) take8 bs = do r <- SBS.uncons bs case r of Nothing -> fail "take8: no more bytes" Just (r, bs') -> pure (r, bs') take16LE :: S.MonadIO m => SBS.ByteString m () -> m (Word16, SBS.ByteString m ()) take16LE bs = do a S.:> bs' <- SBS.toStrict (SBS.splitAt 2 bs) let (fPtr, ofs, _) = IBS.toForeignPtr a x <- S.liftIO . withForeignPtr fPtr $ \ptr -> peek (ptr `plusPtr` ofs) pure (x, bs') -- (lo, bs') <- first fromIntegral <$> take8 bs -- (hi, bs'') <- first fromIntegral <$> take8 bs' -- pure ((hi `shiftL` 8) .|. lo, bs'') takeLength :: S.MonadIO m => Proto.TypeLen -> SBS.ByteString m () -> m (Word16, SBS.ByteString m ()) takeLength Proto.ShortLen = take16LE takeLength Proto.ByteLen = fmap (first fromIntegral) . take16LE takeLE :: Monad m => Int -> SBS.ByteString m () -> m (Integer, SBS.ByteString m ()) takeLE n bs = foldlM (\(!a, bs') shift -> do (x, bs'') <- take8 bs' pure (a .|. (fromIntegral x `shiftL` shift), bs'')) (0, bs) (fmap (*8) [0..n-1]) printNumeric :: String -> Bool -> Word8 -> Proto.PrecScale -> SBS.ByteString IO () -> IO (SBS.ByteString IO ()) printNumeric s True sz precScale d = do (realSz, d') <- take8 d if realSz == 0 then do putStrLn (s ++ "(" ++ show sz ++ ", " ++ show precScale ++ "): (NULL)") pure d' else printNumeric s False realSz precScale d' printNumeric s False sz precScale@(Proto.PrecScale p scale) d = do putStr (s ++ "(" ++ show sz ++ ", " ++ show precScale ++ "):") (sign, d') <- take8 d let intSz | p <= 9 = 4 | p <= 19 = 8 | p <= 28 = 12 | otherwise = 16 (num, d'') <- takeLE (fromIntegral sz - 1) d' let res = num % (10 ^ fromIntegral scale) res' :: Rational res' = if sign == 0 then negate res else res putStrLn (show res') pure d'' printColumn :: Proto.ColumnData -> SBS.ByteString IO () -> IO (SBS.ByteString IO ()) printColumn ty d = case Proto.cdBaseTypeInfo ty of Proto.VarcharType typeLen Proto.NationalChar len coll -> do -- putStr ("NVARCHAR(" ++ show len ++ ") COLLATION " ++ show coll ++ ": ") (len, d') <- takeLength typeLen d if len == 0xFFFF then do putStrLn "(NULL)" pure d' else do let d'' = SBS.splitAt (fromIntegral len) d' byteData S.:> d''' <- SBS.toStrict d'' T.putStrLn (TE.decodeUtf16LE byteData) pure d''' Proto.IntNType False 4 -> do (n, d') <-takeLE 4 d -- putStrLn ("INT: " ++ show n) pure d' Proto.IntNType False bytes -> do (n, d') <- takeLE (fromIntegral bytes) d -- putStrLn ("INT(" ++ show bytes ++ "): " ++ show n) pure d' Proto.IntNType True bytes -> do (realWidth, d') <- take8 d if realWidth == 0 then do putStrLn ("INT(" ++ show bytes ++ "): (NULL)") pure d' else do (n, d'') <- takeLE (fromIntegral bytes) d' putStrLn ("INT(" ++ show bytes ++ "): " ++ show n) pure d'' Proto.DecimalNType nullable sz precScale -> printNumeric "DECIMAL" nullable sz precScale d Proto.NumericNType nullable sz precScale -> printNumeric "NUMERIC" nullable sz precScale d _ -> fail ("Can't print data of type " ++ show ty)