-- Provides a more convenient way to use SQLCLI API from Haskell
module SQL.CLI.Utils where

import Prelude hiding (fail)

import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Fail (MonadFail, fail)

import System.IO (hPutStrLn, stderr)

import Foreign.C.String (withCStringLen, peekCString, peekCStringLen, CStringLen, CString)
import Foreign.Marshal.Alloc (alloca, allocaBytes)
import Foreign.Storable (Storable, peek, peekElemOff, sizeOf)
import Foreign.Ptr (nullPtr, castPtr, Ptr)

import Data.Maybe (catMaybes, maybe)

import Control.Monad.Trans.Cont (ContT (..))
import Control.Monad.Trans.Maybe (MaybeT, runMaybeT)
import Control.Monad.Trans.Reader (ReaderT, ask, asks)

import SQL.CLI (sqlallochandle,
                sqlfreehandle,
                sqlgetdiagfield,
                sqlgetdiagrec,
                sqlconnect,
                sqldescribecol,
                sqldisconnect,
                sqlexecdirect,
                sqlbindcol,
                sqlfetch,
                sqlgetdata,
                sqltables,
                sqlcolumns,
                sql_handle_env,
                sql_handle_dbc,
                sql_handle_stmt,
                sql_null_handle,
                sql_error,
                sql_diag_number,
                sql_success,
                sql_success_with_info,
                sql_invalid_handle,
                sql_no_data,
                sql_need_data,
                sql_max_message_length,
                sql_null_data,
                sql_char,
                sql_smallint,
                sql_integer,
                sql_no_nulls,
                SQLSMALLINT,
                SQLINTEGER,
                SQLHENV,
                SQLHDBC,
                SQLHSTMT,
                SQLPOINTER)

-- | Settings that depend on the SQL CLI implementation in use.
--
-- Every field is the column number (starting with 1) at which the named field
-- appears in the result set produced by the Columns API call. These numbers
-- are the ones 'collectColumnsInfo' binds its output buffers to.
data SQLConfig = SQLConfig
  { sql_cli_flds_table_cat         :: SQLSMALLINT -- ^ column number of TABLE_CAT
  , sql_cli_flds_table_schem       :: SQLSMALLINT -- ^ column number of TABLE_SCHEM
  , sql_cli_flds_table_name        :: SQLSMALLINT -- ^ column number of TABLE_NAME
  , sql_cli_flds_column_name       :: SQLSMALLINT -- ^ column number of COLUMN_NAME
  , sql_cli_flds_data_type         :: SQLSMALLINT -- ^ column number of DATA_TYPE
  , sql_cli_flds_type_name         :: SQLSMALLINT -- ^ column number of TYPE_NAME
  , sql_cli_flds_column_size       :: SQLSMALLINT -- ^ column number of COLUMN_SIZE
  , sql_cli_flds_buffer_length     :: SQLSMALLINT -- ^ column number of BUFFER_LENGTH
  , sql_cli_flds_decimal_digits    :: SQLSMALLINT -- ^ column number of DECIMAL_DIGITS
  , sql_cli_flds_num_prec_radix    :: SQLSMALLINT -- ^ column number of NUM_PREC_RADIX
  , sql_cli_flds_nullable          :: SQLSMALLINT -- ^ column number of NULLABLE
  , sql_cli_flds_remarks           :: SQLSMALLINT -- ^ column number of REMARKS
  , sql_cli_flds_column_def        :: SQLSMALLINT -- ^ column number of COLUMN_DEF
  , sql_cli_flds_datetime_code     :: SQLSMALLINT -- ^ column number of DATETIME_CODE
  , sql_cli_flds_char_octet_length :: SQLSMALLINT -- ^ column number of CHAR_OCTET_LENGTH
  , sql_cli_flds_ordinal_position  :: SQLSMALLINT -- ^ column number of ORDINAL_POSITION
  , sql_cli_flds_is_nullable       :: SQLSMALLINT -- ^ column number of IS_NULLABLE
  }


-- | Information about one column of a database table, built from one row of
-- the result set returned by the Columns API call. The meaning of each field
-- is detailed in the SQL CLI specification, in the documentation of the
-- Columns API call; each field mirrors the result set column of the same name.
--
-- The 'Maybe' fields are read together with a length/indicator value by
-- 'collectColumnsInfo'.
data ColumnInfo = ColumnInfo
  { ci_TableCat        :: Maybe String      -- ^ TABLE_CAT
  , ci_TableSchem      :: String            -- ^ TABLE_SCHEM
  , ci_TableName       :: String            -- ^ TABLE_NAME
  , ci_ColumnName      :: String            -- ^ COLUMN_NAME
  , ci_DataType        :: SQLSMALLINT       -- ^ DATA_TYPE
  , ci_TypeName        :: String            -- ^ TYPE_NAME
  , ci_ColumnSize      :: Maybe SQLINTEGER  -- ^ COLUMN_SIZE
  , ci_BufferLength    :: Maybe SQLINTEGER  -- ^ BUFFER_LENGTH
  , ci_DecimalDigits   :: Maybe SQLSMALLINT -- ^ DECIMAL_DIGITS
  , ci_NumPrecRadix    :: Maybe SQLSMALLINT -- ^ NUM_PREC_RADIX
  , ci_Nullable        :: SQLSMALLINT       -- ^ NULLABLE
  , ci_Remarks         :: Maybe String      -- ^ REMARKS
  , ci_ColumnDef       :: Maybe String      -- ^ COLUMN_DEF
  , ci_DatetimeCode    :: Maybe SQLINTEGER  -- ^ DATETIME_CODE
  , ci_CharOctetLength :: Maybe SQLINTEGER  -- ^ CHAR_OCTET_LENGTH
  , ci_OrdinalPosition :: SQLINTEGER        -- ^ ORDINAL_POSITION
  , ci_IsNullable      :: Maybe String      -- ^ IS_NULLABLE
  }
  deriving (Show)

-- | Read the description of every column of a table on a database connection,
-- using the Columns API call.
--
-- The result is a 'ReaderT' computation: the implementation dependent column
-- numbers of the fields in the result set returned by Columns are taken from
-- a 'SQLConfig' value.
--
-- A statement handle is allocated on @hdbc@ for the query. It is released
-- before the computation returns, including when the Columns call, a column
-- binding or a fetch fails. Those failures print diagnostics on the standard
-- error, and the computation then fails with @"collectColumnsInfo failed"@.
--
-- NOTE(review): rows are accumulated by prepending them to the list, so the
-- result comes out in the reverse of the order in which rows were fetched.
-- Confirm whether callers rely on this before changing it.
collectColumnsInfo :: (MonadIO m, MonadFail m) => SQLHDBC -> String -> ReaderT SQLConfig m [ColumnInfo]
collectColumnsInfo hdbc tableName = do
  cfg   <- ask
  hstmt <- allocHandle sql_handle_stmt hdbc
  -- Every step that can fail once the handle exists runs in 'MaybeT' 'IO'.
  -- A failure therefore becomes 'Nothing' instead of skipping the
  -- 'freeHandle' call below (previously a failing Columns call leaked hstmt).
  cols  <- liftIO $ flip runContT return $ do
    -- Output buffers for the bound columns. 'ContT' flattens the nested
    -- alloca/allocaBytes scopes, and every buffer stays allocated until the
    -- final action below (binding and fetching all the rows) has finished.
    -- The byte sizes must match the buffer lengths given to 'bindVarcharCol'.
    p_table_cat             <- ContT $ allocaBytes 129
    p_table_cat_ind         <- ContT alloca
    p_table_schem           <- ContT $ allocaBytes 129
    p_table_name            <- ContT $ allocaBytes 129
    p_column_name           <- ContT $ allocaBytes 129
    p_data_type             <- ContT alloca
    p_type_name             <- ContT $ allocaBytes 129
    p_column_size           <- ContT alloca
    p_column_size_ind       <- ContT alloca
    p_buffer_length         <- ContT alloca
    p_buffer_length_ind     <- ContT alloca
    p_decimal_digits        <- ContT alloca
    p_decimal_digits_ind    <- ContT alloca
    p_num_prec_radix        <- ContT alloca
    p_num_prec_radix_ind    <- ContT alloca
    p_nullable              <- ContT alloca
    p_remarks               <- ContT $ allocaBytes 255
    p_remarks_ind           <- ContT alloca
    p_column_def            <- ContT $ allocaBytes 255
    p_column_def_ind        <- ContT alloca
    p_datetime_code         <- ContT alloca
    p_datetime_code_ind     <- ContT alloca
    p_char_octet_length     <- ContT alloca
    p_char_octet_length_ind <- ContT alloca
    p_ordinal_position      <- ContT alloca
    p_is_nullable           <- ContT $ allocaBytes 255
    p_is_nullable_ind       <- ContT alloca

    -- Build a 'ColumnInfo' from the buffers filled by the last fetch and
    -- prepend it to the accumulator.
    let readColumnInfo :: [ColumnInfo] -> MaybeT IO [ColumnInfo]
        readColumnInfo cols' = do
          col <- liftIO $ ColumnInfo
            <$> peekMaybeTextCol p_table_cat         p_table_cat_ind
            <*> peekCString      p_table_schem
            <*> peekCString      p_table_name
            <*> peekCString      p_column_name
            <*> peek             p_data_type
            <*> peekCString      p_type_name
            <*> peekMaybeCol     p_column_size       p_column_size_ind
            <*> peekMaybeCol     p_buffer_length     p_buffer_length_ind
            <*> peekMaybeCol     p_decimal_digits    p_decimal_digits_ind
            <*> peekMaybeCol     p_num_prec_radix    p_num_prec_radix_ind
            <*> peek             p_nullable
            <*> peekMaybeTextCol p_remarks           p_remarks_ind
            <*> peekMaybeTextCol p_column_def        p_column_def_ind
            <*> peekMaybeCol     p_datetime_code     p_datetime_code_ind
            <*> peekMaybeCol     p_char_octet_length p_char_octet_length_ind
            <*> peek             p_ordinal_position
            <*> peekMaybeTextCol p_is_nullable       p_is_nullable_ind
          return (col : cols')

    liftIO $ runMaybeT $ do
      columns hstmt Nothing Nothing (Just tableName) Nothing
      liftIO $ hPutStrLn stderr "Fetching columns info..."

      -- Columns bound with nullPtr as indicator are read as non-nullable.
      bindVarcharCol  hstmt (sql_cli_flds_table_cat cfg)         p_table_cat         129 p_table_cat_ind
      bindVarcharCol  hstmt (sql_cli_flds_table_schem cfg)       p_table_schem       129 nullPtr
      bindVarcharCol  hstmt (sql_cli_flds_table_name cfg)        p_table_name        129 nullPtr
      bindVarcharCol  hstmt (sql_cli_flds_column_name cfg)       p_column_name       129 nullPtr
      bindSmallIntCol hstmt (sql_cli_flds_data_type cfg)         p_data_type             nullPtr
      bindVarcharCol  hstmt (sql_cli_flds_type_name cfg)         p_type_name         129 nullPtr
      bindIntegerCol  hstmt (sql_cli_flds_column_size cfg)       p_column_size           p_column_size_ind
      bindIntegerCol  hstmt (sql_cli_flds_buffer_length cfg)     p_buffer_length         p_buffer_length_ind
      bindSmallIntCol hstmt (sql_cli_flds_decimal_digits cfg)    p_decimal_digits        p_decimal_digits_ind
      bindSmallIntCol hstmt (sql_cli_flds_num_prec_radix cfg)    p_num_prec_radix        p_num_prec_radix_ind
      bindSmallIntCol hstmt (sql_cli_flds_nullable cfg)          p_nullable              nullPtr
      bindVarcharCol  hstmt (sql_cli_flds_remarks cfg)           p_remarks           255 p_remarks_ind
      bindVarcharCol  hstmt (sql_cli_flds_column_def cfg)        p_column_def        255 p_column_def_ind
      bindIntegerCol  hstmt (sql_cli_flds_datetime_code cfg)     p_datetime_code         p_datetime_code_ind
      bindIntegerCol  hstmt (sql_cli_flds_char_octet_length cfg) p_char_octet_length     p_char_octet_length_ind
      bindIntegerCol  hstmt (sql_cli_flds_ordinal_position cfg)  p_ordinal_position      nullPtr
      bindVarcharCol  hstmt (sql_cli_flds_is_nullable cfg)       p_is_nullable       255 p_is_nullable_ind

      forAllRecords hstmt readColumnInfo []
  liftIO $ freeHandle sql_handle_stmt hstmt
  maybe (fail "collectColumnsInfo failed") return cols



-- | Check whether the Tables API call, run on the connection @hdbc@ with only
-- a table name filter (no catalog, schema or table type filter), returns at
-- least one row for @tableName@.
--
-- The statement handle used for the query is always released, including when
-- the Tables call or the fetch fails. In that case diagnostics are printed on
-- the standard error, and the computation fails with @"tableExists failed"@.
tableExists :: (MonadIO m, MonadFail m) => SQLHDBC -> String -> m Bool
tableExists hdbc tableName = do
  tables_stmt <- allocHandle sql_handle_stmt hdbc
  -- Run the steps that can fail in MaybeT IO, so that a failure cannot skip
  -- freeHandle below (previously a failing Tables/Fetch leaked the handle).
  exists <- liftIO $ runMaybeT $ do
    tables tables_stmt Nothing Nothing (Just tableName) Nothing
    fetch tables_stmt
  liftIO $ freeHandle sql_handle_stmt tables_stmt
  maybe (fail "tableExists failed") return exists
  

-- SQLCLI wrappers

-- | Concise information about a column of a result set, mapping the result of
-- the DescribeCol SQL CLI API call (see 'describeCol').
data ConciseColInfo = ConciseColInfo
  { cci_ColumnName    :: String      -- ^ name of the column
  , cci_DataType      :: SQLSMALLINT -- ^ data type code reported by DescribeCol
  , cci_ColumnSize    :: SQLINTEGER  -- ^ column size reported by DescribeCol
  , cci_DecimalDigits :: SQLSMALLINT -- ^ decimal digits reported by DescribeCol
  , cci_Nullable      :: Bool        -- ^ 'False' only when DescribeCol reported 'sql_no_nulls'
  }
  -- 'Show' is derived for consistency with 'ColumnInfo' and to ease debugging.
  deriving (Show)

-- | Wrapper for the DescribeCol SQL CLI API call: describes column @colnum@
-- (starting with 1) of the result set associated with @hstmt@.
--
-- Warnings ('sql_success_with_info') are printed with their diagnostics on
-- the standard error, and the description is still returned. Any other
-- non-success result prints an error (with diagnostics when the handle is
-- valid) on the standard error, and the computation fails.
--
-- Column names longer than the 255-byte name buffer (254 characters plus the
-- terminating null) are returned truncated.
describeCol :: (MonadIO m, MonadFail m) => SQLHSTMT -> SQLSMALLINT -> m ConciseColInfo
describeCol hstmt colnum = do
  info <- liftIO $ allocaBytes nameBufLen
    (\ p_columnName ->
        alloca
        (\ p_nameLength ->
           alloca
           (\ p_dataType ->
              alloca
              (\ p_columnSize ->
                 alloca
                 (\ p_decimalDigits ->
                    alloca
                    (\ p_nullable -> do
                        result <- sqldescribecol hstmt colnum p_columnName (fromIntegral nameBufLen) p_nameLength p_dataType p_columnSize p_decimalDigits p_nullable
                        let readInfo = Just <$> do
                              nameLength <- peek p_nameLength
                              nullable   <- peek p_nullable
                              -- NameLength reports the full length of the column
                              -- name, which can exceed what fitted into the buffer
                              -- (the name is then truncated to nameBufLen - 1 bytes
                              -- plus the null terminator). Clamp it so that we never
                              -- read past the end of the buffer.
                              let nameLen = max 0 (min (nameBufLen - 1) (fromIntegral nameLength))
                              ConciseColInfo
                                <$> peekCStringLen (castPtr p_columnName, nameLen)
                                <*> peek p_dataType
                                <*> peek p_columnSize
                                <*> peek p_decimalDigits
                                <*> return (nullable /= sql_no_nulls)
                        case result of
                          x | x == sql_success -> readInfo
                            | x == sql_success_with_info -> do
                                hPutStrLn stderr "More information returned by DescribeCol"
                                displayDiagInfo sql_handle_stmt hstmt
                                readInfo
                            | x == sql_error -> do
                                hPutStrLn stderr "Error calling DescribeCol"
                                displayDiagInfo sql_handle_stmt hstmt
                                return Nothing
                            | x == sql_invalid_handle -> do
                                hPutStrLn stderr "Invalid handle calling DescribeCol"
                                return Nothing
                            | otherwise -> do
                                hPutStrLn stderr $ "Unexpected result returned by the call to DescribeCol: " ++ (show x)
                                displayDiagInfo sql_handle_stmt hstmt
                                return Nothing))))))
  maybe (fail $ "describeCol " ++ (show colnum) ++ " failed") return info
  where
    -- size in bytes of the buffer receiving the column name, terminator included
    nameBufLen :: Int
    nameBufLen = 255


-- | Wrapper for the Columns SQL CLI API call on the statement handle @hstmt@.
--
-- The catalog, schema, table name and column name arguments are each passed
-- through 'withMaybeCStringLen'; presumably 'Nothing' means "no filter" (TODO:
-- confirm against that helper). On success the result set is left on @hstmt@
-- for fetching, and warnings are printed on the standard error. Any other
-- result prints an error on the standard error, and the computation fails.
columns :: (MonadIO m, MonadFail m) => SQLHSTMT -> Maybe String -> Maybe String -> Maybe String -> Maybe String -> m ()
columns hstmt catalogName schemaName tableName columnName =
  liftIO callColumns >>= checkResult
  where
    -- marshal every (possibly absent) argument, then invoke the CLI function
    callColumns =
      withMaybeCStringLen catalogName $ \ (p_cat, catLen) ->
      withMaybeCStringLen schemaName  $ \ (p_schem, schemLen) ->
      withMaybeCStringLen tableName   $ \ (p_tab, tabLen) ->
      withMaybeCStringLen columnName  $ \ (p_col, colLen) ->
        sqlcolumns hstmt
          (castPtr p_cat)   (fromIntegral catLen)
          (castPtr p_schem) (fromIntegral schemLen)
          (castPtr p_tab)   (fromIntegral tabLen)
          (castPtr p_col)   (fromIntegral colLen)

    -- map the CLI return code onto the monadic outcome
    checkResult rc
      | rc == sql_success = return ()
      | rc == sql_error = do
          liftIO $ do
            hPutStrLn stderr "Error calling Columns"
            displayDiagInfo sql_handle_stmt hstmt
          fail "Columns failed"
      | rc == sql_success_with_info =
          liftIO $ do
            hPutStrLn stderr "Columns returned more info"
            displayDiagInfo sql_handle_stmt hstmt
      | rc == sql_invalid_handle = do
          liftIO $ hPutStrLn stderr "Invalid statement handle passed to Columns call"
          fail "Columns failed"
      | otherwise = do
          liftIO $ do
            hPutStrLn stderr "Unexpected return code returned by call to Columns. Trying to display diagnostic info:"
            displayDiagInfo sql_handle_stmt hstmt
          fail "Columns failed"
          
-- | Wrapper for the Tables SQL CLI API call on the statement handle @hstmt@.
--
-- The catalog, schema, table name and table type arguments are each passed
-- through 'withMaybeCStringLen'; presumably 'Nothing' means "no filter" (TODO:
-- confirm against that helper). On success the result set is left on @hstmt@
-- for fetching, and warnings are printed on the standard error. Any other
-- result prints an error on the standard error, and the computation fails.
tables :: (MonadIO m, MonadFail m) => SQLHSTMT -> Maybe String -> Maybe String -> Maybe String -> Maybe String -> m ()
tables hstmt catalogName schemaName tableName tableType =
  liftIO callTables >>= checkResult
  where
    -- marshal every (possibly absent) argument, then invoke the CLI function
    callTables =
      withMaybeCStringLen catalogName $ \ (p_cat, catLen) ->
      withMaybeCStringLen schemaName  $ \ (p_schem, schemLen) ->
      withMaybeCStringLen tableName   $ \ (p_tab, tabLen) ->
      withMaybeCStringLen tableType   $ \ (p_type, typeLen) ->
        sqltables hstmt
          (castPtr p_cat)   (fromIntegral catLen)
          (castPtr p_schem) (fromIntegral schemLen)
          (castPtr p_tab)   (fromIntegral tabLen)
          (castPtr p_type)  (fromIntegral typeLen)

    -- map the CLI return code onto the monadic outcome
    checkResult rc
      | rc == sql_success = return ()
      | rc == sql_error = do
          liftIO $ do
            hPutStrLn stderr "Error calling Tables"
            displayDiagInfo sql_handle_stmt hstmt
          fail "Tables failed"
      | rc == sql_success_with_info =
          liftIO $ do
            hPutStrLn stderr "Tables returned more info"
            displayDiagInfo sql_handle_stmt hstmt
      | rc == sql_invalid_handle = do
          liftIO $ hPutStrLn stderr "Invalid handle calling Tables"
          fail "Tables failed"
      | otherwise = do
          liftIO $ do
            hPutStrLn stderr $ "Tables returned unexpected result: " ++ show rc
            displayDiagInfo sql_handle_stmt hstmt
          fail "Tables failed"

-- | Fold over every remaining record of an executed statement.
--
-- Each time a record is fetched, @f@ is applied to the current accumulator,
-- and the value it produces becomes the accumulator for the next record. Once
-- all records have been fetched, the last accumulator is returned. The
-- computation fails as soon as a fetch fails (see 'fetchAndRun') or @f@ fails.
forAllRecords :: (MonadIO m, MonadFail m) => SQLHSTMT -> (a -> m a) -> a -> m a
forAllRecords stmt f accum = fetchAndRun stmt processRecord (return accum)
  where
    -- the current record is fetched: update the accumulator, then continue
    processRecord = f accum >>= forAllRecords stmt f

-- | Read data from a column of the current record with the GetData API call.
--
-- The result is 'True' when part of the column value is still waiting to be
-- read (the buffer was too small), and 'False' once the whole value has been
-- read. If an error occurs, diagnostics are displayed on the standard error
-- and the computation fails (through 'MonadFail').
getData :: (MonadIO m, MonadFail m) => SQLHSTMT -> SQLSMALLINT -> SQLSMALLINT -> SQLPOINTER -> SQLINTEGER -> Ptr SQLINTEGER -> m Bool
getData hstmt colNum targetType p_buf bufLen p_lenOrInd =
  getDataAndRun hstmt colNum targetType p_buf bufLen p_lenOrInd
    (pure True)   -- more data is available for this column
    (pure False)  -- the column value has been read completely

-- | Call the GetData SQL CLI API for column @colNum@ of the current record,
-- storing the data in @p_buf@ (@bufLen@ bytes) and the length/indicator in
-- @p_lenOrInd@. It then continues with one of two monadic actions:
--
-- *   @more@, when GetData returned 'sql_success_with_info' together with a
--     diagnostic record whose SQLSTATE is @01004@ (string data, right
--     truncated). This is taken to mean that more data is left to read.
-- *   @end@, when GetData returned 'sql_success', or 'sql_success_with_info'
--     without such a diagnostic record.
--
-- Any other result, including 'sql_no_data', prints a message on the standard
-- error (with diagnostics when the handle is valid), and the computation fails.
getDataAndRun :: (MonadIO m, MonadFail m) => SQLHSTMT -> SQLSMALLINT -> SQLSMALLINT -> SQLPOINTER -> SQLINTEGER -> Ptr SQLINTEGER -> m a -> m a -> m a
getDataAndRun hstmt colNum targetType p_buf bufLen p_lenOrInd more end =
  liftIO (sqlgetdata hstmt colNum targetType p_buf bufLen p_lenOrInd) >>= onResult
  where
    onResult rc
      | rc == sql_success = end
      | rc == sql_invalid_handle = do
          liftIO $ hPutStrLn stderr "Invalid handle when calling GetData"
          fail "GetData failed"
      | rc == sql_error = do
          liftIO $ do
            hPutStrLn stderr "Error calling GetData"
            displayDiagInfo sql_handle_stmt hstmt
          fail "GetData failed"
      | rc == sql_no_data = do
          liftIO $ hPutStrLn stderr "GetData -> no data available"
          fail "GetData failed"
      | rc == sql_success_with_info = do
          liftIO $ do
            hPutStrLn stderr "GetData returned more info"
            displayDiagInfo sql_handle_stmt hstmt
          -- scan all diagnostic records for the truncation SQLSTATE
          recs  <- getCountOfDiagRecs sql_handle_stmt hstmt
          let readDiag i = runMaybeT $ getDiagRec sql_handle_stmt hstmt (fromIntegral i)
          diags <- liftIO $ catMaybes <$> mapM readDiag [1 .. recs]
          if any ((== "01004") . sqlstate) diags then more else end
      | otherwise = do
          liftIO $ do
            hPutStrLn stderr $ "GetData returned unexpected result: " ++ show rc
            displayDiagInfo sql_handle_stmt hstmt
          fail "GetData failed"

-- | Fetch the next record of an executed statement.
--
-- The result is 'True' when a record has been fetched, and 'False' when all
-- the records had already been read. If an error occurs, the computation
-- fails, displaying the error diagnostics on the standard error.
fetch :: (MonadIO m, MonadFail m) => SQLHSTMT -> m Bool
fetch hstmt =
  fetchAndRun hstmt
    (pure True)   -- a record is now available in the bound buffers
    (pure False)  -- no record left

-- | Fetch the next record of an executed statement, then run @action@ when a
-- record was fetched, or @end@ when all the records had already been read.
--
-- Warnings ('sql_success_with_info') are printed with their diagnostics on
-- the standard error, and @action@ still runs. Any other result prints an
-- error on the standard error (with diagnostics when the handle is valid),
-- and the computation fails.
fetchAndRun :: (MonadIO m, MonadFail m) => SQLHSTMT -> m a -> m a -> m a
fetchAndRun hstmt action end =
  liftIO (sqlfetch hstmt) >>= onResult
  where
    onResult rc
      | rc == sql_success = action
      | rc == sql_error = do
          liftIO $ do
            hPutStrLn stderr "Error fetching record"
            displayDiagInfo sql_handle_stmt hstmt
          fail "Fetch failed"
      | rc == sql_invalid_handle = do
          liftIO $ hPutStrLn stderr "Invalid handle when fetching record"
          fail "Fetch failed"
      | rc == sql_no_data = do
          liftIO $ hPutStrLn stderr "All records have been fetched"
          end
      | rc == sql_success_with_info = do
          liftIO $ do
            hPutStrLn stderr "More diagnostic info returned for record"
            displayDiagInfo sql_handle_stmt hstmt
          action
      | otherwise = do
          liftIO $ do
            hPutStrLn stderr $ "Fetch returned unexepected result: " ++ show rc
            displayDiagInfo sql_handle_stmt hstmt
          fail "Fetch failed"

-- | Bind a result column to a SMALLINT buffer (C type 'sql_smallint') via
-- 'bindCol'. The buffer length is the size of one 'SQLSMALLINT'.
bindSmallIntCol :: (MonadIO m, MonadFail m) =>
  SQLHSTMT                      -- ^ statement handle
  -> SQLSMALLINT                -- ^ column number (starting with 1)
  -> Ptr SQLSMALLINT            -- ^ buffer that receives the value
  -> Ptr SQLINTEGER             -- ^ buffer that receives the indicator or length; may be null
  -> m ()
bindSmallIntCol hstmt colNum p_buf p_ind =
  bindCol hstmt colNum sql_smallint (castPtr p_buf) valueSize p_ind
  where
    valueSize = fromIntegral (sizeOf (undefined :: SQLSMALLINT))

-- | Bind a result column to an INTEGER buffer (C type 'sql_integer') via
-- 'bindCol'. The buffer length is the size of one 'SQLINTEGER'.
bindIntegerCol :: (MonadIO m, MonadFail m) =>
  SQLHSTMT                      -- ^ statement handle
  -> SQLSMALLINT                -- ^ column number (starting with 1)
  -> Ptr SQLINTEGER             -- ^ buffer that receives the value
  -> Ptr SQLINTEGER             -- ^ buffer that receives the indicator or length; may be null
  -> m ()
bindIntegerCol hstmt colNum p_buf p_ind =
  bindCol hstmt colNum sql_integer (castPtr p_buf) valueSize p_ind
  where
    valueSize = fromIntegral (sizeOf (undefined :: SQLINTEGER))

-- | Bind a result column to a character buffer (C type 'sql_char') via
-- 'bindCol'. The buffer length must count the terminating NULL character of
-- the 'CString'.
bindVarcharCol :: (MonadIO m, MonadFail m) =>
  SQLHSTMT                      -- ^ statement handle
  -> SQLSMALLINT                -- ^ column number (starting with 1)
  -> CString                    -- ^ buffer that receives the null terminated text
  -> SQLINTEGER                 -- ^ buffer length in bytes, terminating null included
  -> Ptr SQLINTEGER             -- ^ pointer to indicator or length; may be null
  -> m ()
bindVarcharCol hstmt colNum p_buf buflen p_ind =
  bindCol hstmt colNum sql_char textBuffer buflen p_ind
  where
    textBuffer = castPtr p_buf

-- | Wrapper for the BindCol SQL CLI API call: binds column @colNum@ (starting
-- with 1) of the result set on @hstmt@ to the buffer @p_buf@ of @len_buf@
-- bytes, converted to C type @colType@, with length/indicator in @p_ind@.
--
-- Warnings are printed with their diagnostics on the standard error. Any
-- other non-success result prints an error on the standard error (with
-- diagnostics when the handle is valid), and the computation fails with
-- @"Binding column failed"@.
bindCol :: (MonadIO m, MonadFail m) => SQLHSTMT -> SQLSMALLINT -> SQLSMALLINT -> SQLPOINTER -> SQLINTEGER -> Ptr SQLINTEGER -> m ()
bindCol hstmt colNum colType p_buf len_buf p_ind = do
  result <- liftIO $ sqlbindcol hstmt colNum colType p_buf len_buf p_ind
  case result of
    x | x == sql_success -> return ()
      | x == sql_error -> do
          liftIO $ do
            hPutStrLn stderr $ "Error binding column " ++ show colNum
            displayDiagInfo sql_handle_stmt hstmt
          fail "Binding column failed"
      | x == sql_success_with_info -> do
          liftIO $ do
            hPutStrLn stderr $ "Binding col " ++ show colNum ++ " returned warnings:"
            displayDiagInfo sql_handle_stmt hstmt
      | x == sql_invalid_handle -> do
          liftIO $ hPutStrLn stderr $ "Invalid handle when binding column " ++ show colNum
          fail "Binding column failed"
      | otherwise -> do
          -- report the raw return code, like the other wrappers in this module
          liftIO $ do
            hPutStrLn stderr $ "Unexpected result when binding column " ++ show colNum ++ ": " ++ show x
            displayDiagInfo sql_handle_stmt hstmt
          fail "Binding column failed"

-- | Wrapper for the ExecDirect SQL CLI API call: executes the SQL text
-- @sqlstr@ on the statement handle @hstmt@.
--
-- Success and warnings are reported on the standard error (warnings with
-- their diagnostics). Any other result prints an error on the standard error
-- (with diagnostics when the handle is valid), and the computation fails.
execDirect :: (MonadIO m, MonadFail m) => SQLHSTMT -> String -> m ()
execDirect hstmt sqlstr = do
  result <- liftIO $ withCStringLen sqlstr
    (\(sql, sqlLen) -> sqlexecdirect hstmt (castPtr sql) (fromIntegral sqlLen))
  case result of
    x | x == sql_success -> liftIO $ hPutStrLn stderr "sql statement executed"
      | x == sql_success_with_info -> liftIO $ do
          hPutStrLn stderr "Execution of sql returned more info"
          displayDiagInfo sql_handle_stmt hstmt
      | x == sql_error -> do
          liftIO $ do
            hPutStrLn stderr "Execution of sql returned error"
            displayDiagInfo sql_handle_stmt hstmt
          fail "execute sql statement failed"
      | x == sql_invalid_handle -> do
          -- diagnostics cannot be read through an invalid handle, so (like the
          -- other wrappers in this module) none are requested here
          liftIO $ hPutStrLn stderr "Invalid statement handle"
          fail "execute statement failed"
      | x == sql_need_data -> do
          liftIO $ do
            hPutStrLn stderr "Unexpected NEED_DATA returned by statement execution"
            displayDiagInfo sql_handle_stmt hstmt
          fail "execute statement failed"
      | x == sql_no_data -> do
          -- NOTE(review): per the ODBC / SQL CLI documentation, ExecDirect also
          -- returns SQL_NO_DATA for a searched UPDATE or DELETE that affects no
          -- rows. This wrapper treats that as a failure; confirm it is intended.
          liftIO $ hPutStrLn stderr "Execution of statement returned no data"
          fail "execute statement failed"
      | otherwise -> do
          liftIO $ do
            hPutStrLn stderr $ "Execute statement returned unexpected result: " ++ show x
            displayDiagInfo sql_handle_stmt hstmt
          fail "Execute statement failed"

-- | Allocate a connection handle under the given environment handle and
-- connect it to @server@ using the supplied user name and password.
--
-- Returns the connection handle when the connect call succeeds (warnings,
-- if any, are printed with their diagnostic records).
--
-- Any other outcome is reported on standard error and ends in 'fail':
--
-- * on error the handle is freed;
-- * on an unexpected result code 'disconnect' is invoked to release it.
connect :: (MonadIO m, MonadFail m) => SQLHENV -> String -> String -> String -> m SQLHDBC
connect henv server user pass = do
  logLine $ "connect to server " ++ server
  hdbc <- allocHandle sql_handle_dbc henv
  rc <- liftIO $
    withCStringLen server $ \(serverPtr, serverLen) ->
      withCStringLen user $ \(userPtr, userLen) ->
        withCStringLen pass $ \(passPtr, passLen) ->
          sqlconnect hdbc
                     (castPtr serverPtr) (fromIntegral serverLen)
                     (castPtr userPtr)   (fromIntegral userLen)
                     (castPtr passPtr)   (fromIntegral passLen)
  onResult hdbc rc
  where
    -- write one line to standard error from any MonadIO
    logLine msg = liftIO (hPutStrLn stderr msg)

    -- translate the Connect result code into the computation's outcome
    onResult hdbc rc
      | rc == sql_success = return hdbc
      | rc == sql_success_with_info = do
          logLine $ "connect to server " ++ server ++ " returned warnings:"
          liftIO $ displayDiagInfo sql_handle_dbc hdbc
          return hdbc
      | rc == sql_error = do
          logLine $ "connection to server " ++ server ++ " failed:"
          liftIO $ do
            displayDiagInfo sql_handle_dbc hdbc
            freeHandle sql_handle_dbc hdbc
          fail $ "connection to server " ++ server ++ " failed"
      | rc == sql_invalid_handle = do
          let msg = "connection to server " ++ server ++ " failed because of invalid handle"
          logLine msg
          fail msg
      | otherwise = do
          let msg = "Unexpected response code got from connecting to server " ++ server ++ ": " ++ show rc
          liftIO $ do
            hPutStrLn stderr msg
            hPutStrLn stderr "Trying to extract diagnostic info:"
            displayDiagInfo sql_handle_dbc hdbc
            hPutStrLn stderr "Try call disconnect on the connection handle, to make sure we release all resources"
            disconnect hdbc
          fail msg

-- | Wrapper for the SQL CLI Disconnect API call.
--
-- Whatever Disconnect returns is reported on standard error, with
-- diagnostic records where applicable. The connection handle is then freed
-- unconditionally. This action never fails.
disconnect :: SQLHDBC -> IO ()
disconnect hdbc = do
  rc <- sqldisconnect hdbc
  report rc
  freeHandle sql_handle_dbc hdbc
  where
    say = hPutStrLn stderr
    diagnostics = displayDiagInfo sql_handle_dbc hdbc

    report rc
      | rc == sql_success = return ()
      | rc == sql_success_with_info = say "disconnect returned warnings:" >> diagnostics
      | rc == sql_error = say "disconnect failed:" >> diagnostics
      | rc == sql_invalid_handle = say "disconnect failed because of invalid handle"
      | otherwise = do
          say "Unexpected response code got from Disconnect function"
          say "Trying to extract diagnostic info:"
          diagnostics

-- | wrapper to SQL CLI AllocHandle API call; it displays diagnostics info
-- on the standard error and fails if the handle could not be allocated.
--
-- A @SUCCESS_WITH_INFO@ result means the handle was allocated. The warnings
-- are displayed and the new handle is returned, as other wrappers in this
-- module do for that code.
--
-- Where diagnostics are read from:
--
-- * normally, from the parent handle;
-- * when allocating an environment handle, from the value AllocHandle wrote
--   to the output handle. The parent of an environment is presumably the
--   null handle, which cannot carry diagnostics.
allocHandle :: (MonadIO m, MonadFail m) => SQLSMALLINT -> SQLINTEGER -> m SQLINTEGER
allocHandle handleType handleParent = do
  handle <- liftIO $ alloca
    (\p_handle -> do
        -- decide on the *handle type* being allocated (not the result code)
        -- and pass that handle type to displayDiagInfo
        let displayDiagnostic
              | handleType == sql_handle_env = peek p_handle >>= displayDiagInfo handleType
              | otherwise                    = displayDiagInfo handleType handleParent
        result <- sqlallochandle handleType handleParent p_handle
        case result of
          x | x == sql_success -> Just <$> peek p_handle
            | x == sql_success_with_info -> do
                hPutStrLn stderr "alloc handle returned warnings:"
                displayDiagnostic
                Just <$> peek p_handle
            | x == sql_invalid_handle -> do
                hPutStrLn stderr "alloc handle failed because of invalid handler"
                displayDiagnostic
                return Nothing
            | x == sql_error -> do
                hPutStrLn stderr "alloc handle failed with error"
                displayDiagnostic
                return Nothing
            | otherwise -> do
                hPutStrLn stderr $ "alloc handle returned unexpected result" ++ (show x)
                displayDiagnostic
                return Nothing)
  maybe (fail $ "AllocHandle failed for handle type " ++ (show handleType)) return handle
                       
-- | Wrapper for the SQL CLI FreeHandle API call.
--
-- Any non-success result is reported on standard error together with the
-- handle's diagnostic records. This action never fails.
freeHandle :: SQLSMALLINT -> SQLINTEGER -> IO ()
freeHandle handleType handle = sqlfreehandle handleType handle >>= report
  where
    showDiagnostics = displayDiagInfo handleType handle

    report rc
      | rc == sql_success = return ()
      | rc == sql_error = do
          hPutStrLn stderr $ "Error freeing handle of type " ++ show handleType
          showDiagnostics
      | rc == sql_invalid_handle = do
          hPutStrLn stderr "FreeHandle failed because of invalid handle"
          showDiagnostics
      | otherwise = do
          hPutStrLn stderr $ "FreeHandle returned unexpected result " ++ show rc
          hPutStrLn stderr "Trying to get diagnostic info on FreeHandle:"
          showDiagnostics

-- | Print the diagnostic records of a handle to standard error.
--
-- Runs 'displayDiagInfo'' inside 'MaybeT' and discards the outcome, so a
-- failure while reading the records does not propagate to the caller.
displayDiagInfo :: SQLSMALLINT -> SQLINTEGER -> IO ()
displayDiagInfo handleType handle = do
  _ <- runMaybeT (displayDiagInfo' handleType handle)
  return ()

-- | Print every diagnostic record attached to a handle on standard error.
--
-- Output:
--
-- * a summary line with the number of records;
-- * for each record, a header line followed by a line of the form
--   @n: SQLSTATE - native error - message@.
--
-- Records are read in order starting at 1. The action fails, and stops
-- there, as soon as the count or any record cannot be read.
displayDiagInfo' :: (MonadIO m, MonadFail m) => SQLSMALLINT -> SQLINTEGER -> m ()
displayDiagInfo' handleType handle = do
  count <- getCountOfDiagRecs handleType handle
  let plural = count /= 1
  liftIO $ hPutStrLn stderr $ concat
    [ "there "
    , if plural then "are " else "is "
    , show count
    , " diagnostic record"
    , if plural then "s" else ""
    ]
  mapM_ (printRecord . fromIntegral) [1 .. count]
  where
    printRecord recNo = do
      liftIO $ hPutStrLn stderr $ "Diagnostic record " ++ show recNo
      diag <- getDiagRec handleType handle recNo
      liftIO $ hPutStrLn stderr $
        show recNo ++ ": " ++ sqlstate diag ++ " - " ++ show (nativeError diag) ++ " - " ++ messageText diag

-- | create a monadic action to read the number of the diagnostic records for a given handle;
-- it fails if an error occurs and it displays diagnostics on standard error.
--
-- A @NO_DATA@ result from GetDiagField is treated as zero records.
getCountOfDiagRecs :: (MonadIO m, MonadFail m) => SQLSMALLINT -> SQLINTEGER -> m Int
getCountOfDiagRecs handleType handle = do
  recs <- liftIO $ alloca readCount
  maybe (fail "GetDiagField api call failed when reading number of diagnostic errors") return recs
  where
    -- The DIAG_NUMBER header field is an SQLINTEGER (per the SQL CLI / ODBC
    -- specs), so the output buffer is typed as one. Peeking it as a
    -- (possibly wider) 'Int' would read bytes the call never wrote.
    readCount :: Ptr SQLINTEGER -> IO (Maybe Int)
    readCount ptrRecs = do
      result <- sqlgetdiagfield handleType handle 0 sql_diag_number (castPtr ptrRecs) 0 nullPtr
      case result of
        x | x == sql_success        -> Just . fromIntegral <$> peek ptrRecs
          | x == sql_invalid_handle -> do
              hPutStrLn stderr "Count of diagnostic records could not be retrieved due to an invalid handle"
              return Nothing
          | x == sql_error          -> do
              hPutStrLn stderr "Count of diagnostic records could not be retrieved because wrong arguments were passed to GetDiagField function"
              return Nothing
          | x == sql_no_data        -> do
              hPutStrLn stderr "No diagnostic data available"
              return $ Just 0
          | otherwise               -> do
              hPutStrLn stderr $ "Getting the number of diagnostic records returned unexpected return code " ++ (show x)
              return Nothing
  
-- | One diagnostic record, as read by 'getDiagRec'.
data DiagRecord = DiagRecord
  { sqlstate    :: String      -- ^ the five-character SQLSTATE code
  , nativeError :: SQLINTEGER  -- ^ native error code reported by the CLI implementation
  , messageText :: String      -- ^ diagnostic message text
  }

-- | wrapper for SQL CLI GetDiagRec API call; the computation fails if an error
-- occurs and it displays diagnostics on standard error.
--
-- A @SUCCESS_WITH_INFO@ result, which the CLI uses when the message text did
-- not fit in the buffer, is accepted. The (truncated) text is returned.
getDiagRec :: (MonadIO m, MonadFail m) => SQLSMALLINT -> SQLINTEGER -> SQLSMALLINT -> m DiagRecord
getDiagRec handleType handle recnum = do
  -- the CLI writes the 5-character SQLSTATE followed by a terminating NUL,
  -- so the buffer needs one extra byte
  diagrecord <- liftIO $ allocaBytes (sqlstateLen + 1)
    (\p_sqlstate -> alloca
      (\p_nativeErr -> allocaBytes sql_max_message_length
        (\p_messageText -> alloca
          (\p_textLen -> do
              result <- sqlgetdiagrec handleType handle recnum p_sqlstate p_nativeErr p_messageText sql_max_message_length p_textLen
              case result of
                x | x == sql_success || x == sql_success_with_info -> do
                      l_sqlstate <- map (toEnum . fromIntegral) <$> mapM (peekElemOff p_sqlstate) [0 .. sqlstateLen - 1]
                      l_nativeErr <- peek p_nativeErr
                      textLen <- fromIntegral <$> peek p_textLen
                      -- the reported length excludes the terminating NUL and may
                      -- exceed what fitted in the buffer, so read only the
                      -- characters actually stored
                      let stored = max 0 (min textLen (sql_max_message_length - 1))
                      l_messageText <- map (toEnum . fromIntegral) <$> mapM (peekElemOff p_messageText) [0 .. stored - 1]
                      return $ Just $ DiagRecord l_sqlstate l_nativeErr l_messageText
                  | x == sql_error -> do
                      hPutStrLn stderr $ (show recnum) ++ ": Diagnostic information could not be retrieved because wrong arguments passed to GetDiagRec function"
                      return Nothing
                  | x == sql_invalid_handle -> do
                      hPutStrLn stderr $ (show recnum) ++ ": Diagnostic information could not be retrieved because of wrong handler"
                      return Nothing
                  | x == sql_no_data -> do
                      hPutStrLn stderr $ (show recnum) ++ ": No diagnostic data available"
                      return Nothing
                  | otherwise -> do
                      hPutStrLn stderr $ (show recnum) ++ ": Getting diagnostic information returned unexpected error code " ++ (show x)
                      return Nothing))))
  maybe (fail "GetDiagRec call failed") return diagrecord
  where
    -- number of characters in an SQLSTATE code
    sqlstateLen :: Int
    sqlstateLen = 5

-- | Run an action on an optional string marshalled as a 'CStringLen'.
--
-- @Nothing@ is passed to the action as @(nullPtr, 0)@. @Just s@ is
-- marshalled with 'withCStringLen', so the buffer is only valid while the
-- action runs.
withMaybeCStringLen :: Maybe String -> (CStringLen -> IO a) -> IO a
withMaybeCStringLen ms f = maybe (f (nullPtr, 0)) (`withCStringLen` f) ms

-- | Read a nullable bound column.
--
-- The indicator is checked first. When it equals @sql_null_data@ the
-- result is 'Nothing' and the column buffer is never touched; otherwise the
-- buffer's value is returned in 'Just'.
peekMaybeCol :: (Storable a) => Ptr a -> Ptr SQLINTEGER -> IO (Maybe a)
peekMaybeCol p_col p_ind = peek p_ind >>= readUnlessNull
  where
    readUnlessNull indicator
      | indicator == sql_null_data = return Nothing
      | otherwise                  = fmap Just (peek p_col)

-- | Read a nullable bound text column.
--
-- When the indicator equals @sql_null_data@ the result is 'Nothing'.
-- Otherwise the buffer is read with 'peekCString', which stops at the first
-- NUL byte.
peekMaybeTextCol :: CString -> Ptr SQLINTEGER -> IO (Maybe String)
peekMaybeTextCol p_col p_ind = peek p_ind >>= readUnlessNull
  where
    readUnlessNull indicator
      | indicator == sql_null_data = return Nothing
      | otherwise                  = fmap Just (peekCString p_col)