{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Database.TDS.Connection
( newConnection
) where
import Database.TDS.Types
import Database.TDS.Proto
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception ( SomeException, IOException
, Exception, throwIO, bracket
, catch, finally, onException )
import Control.Exception (mask)
import Control.Monad.IO.Class
import Control.Monad.Identity
import Control.Monad.Trans
import Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as IBS
import qualified Data.ByteString.Streaming as SBS
import Data.Char
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Data.Proxy
import Data.Text (Text)
import Data.Word
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Ptr
import Foreign.Storable
import qualified Network.Socket as BSD
import qualified Streaming as S
import qualified Streaming.Prelude as S
-- | Type-level projection of the cancelability flag carried by a response
-- description.  For @'ExpectsResponse ('ResponseType r a)@ the result is the
-- flag @r@; every other 'ResponseInfo' maps to @'False@.
type family IsCancelable (resp :: ResponseInfo a) :: Bool where
IsCancelable ('ExpectsResponse ('ResponseType r a)) = r
IsCancelable _ = 'False
-- | A sent packet with its sender, response and payload types hidden,
-- paired with the 'CancelInfo' matching the packet's cancelability.
-- This is the value stored in 'connectionCurReq' while a request is in flight.
data SomeSentPacket where
SomeSentPacket :: Show d
=> Packet sender resp d Identity
-> CancelInfo (IsCancelable resp)
-> SomeSentPacket
-- | Exception that 'sendPacket' catches (for cancelable requests only) in
-- order to run its cancellation routine before re-throwing it.
-- NOTE(review): nothing in this module throws 'TDSCanceled'; presumably it is
-- delivered from outside (e.g. by the caller) -- confirm.
data TDSCanceledException
= TDSCanceled
deriving Show
instance Exception TDSCanceledException
-- | Exception thrown by 'dataStream' when it observes, before reading the
-- next chunk, that the cancel flag of a cancelable request has been set.
data TDSWasCanceledException
= TDSWasCanceled
deriving Show
instance Exception TDSWasCanceledException
-- | The sending half of a connection: the socket plus an optional send buffer.
data WriteEnd
= WriteEnd !BSD.Socket
-- Optional pre-allocated send buffer and its size.  When it is 'Nothing',
-- the sender (doSend inside 'sendPacket') allocates a buffer itself.
!(Maybe (ForeignPtr (), CSize))
-- | The receiving half of a connection; wraps the socket that is read from.
data ReadEnd = ReadEnd !BSD.Socket
-- | Shared, mutable state of one connection.  'ConnectionQuit' replaces the
-- whole record once the connection has been shut down (see quitSTM').
data ConnectionState
= ConnectionQuit
| ConnectionState
-- Write end: 'Right' while available, 'Left' with the id of the thread that
-- currently holds it (set by 'waitUntilSendable').
{ connectionWriteSock :: TVar (Either ThreadId WriteEnd)
-- Read end, with the same 'Left'/'Right' ownership convention.
, connectionReadSock :: TVar (Either ThreadId ReadEnd)
-- The request currently in flight, or 'Nothing' when there is none.
, connectionCurReq :: TVar (Maybe SomeSentPacket)
-- Current protocol state of the client.
, connectionCurState :: TVar ClientState
-- Client-side view of the session environment; updated by 'handleToken'
-- when an 'EnvChange' token is received.
, connectionEnvironment :: !ClientEnv
}
-- | Release the write end.  This performs no work at all: the socket it
-- wraps is closed through 'closeReadEnd' instead.
closeWriteEnd :: WriteEnd -> IO ()
closeWriteEnd WriteEnd {} = return ()
-- | Release the read end by closing the socket it wraps.
closeReadEnd :: ReadEnd -> IO ()
closeReadEnd (ReadEnd sock) = BSD.close sock
-- | Open a TCP connection to the server named in the options and wrap it in
-- a 'Connection'.
--
-- The host and port are taken from the given options merged over
-- 'defaultOptions'; everything else uses the options exactly as given.
-- Only the first address returned by name resolution is tried.
--
-- Failures are reported as errors built with 'tdsErrorNoReq' in state
-- 'Connecting': 'TDSNoSuchHost' when resolution fails or yields no address,
-- 'TDSSocketError' when creating or connecting the socket fails.  The socket
-- is closed again if connecting raises an exception.
newConnection :: Options -> IO Connection
newConnection opts = do
  addrs <- connectingIO TDSNoSuchHost $
    BSD.getAddrInfo (Just hints) (_tdsConnHost connInfo)
                    (show <$> _tdsConnPort connInfo)
  addr <- case addrs of
    [] -> throwIO (tdsErrorNoReq
                     TDSNoSuchHost Connecting
                     "No addresses returned by getAddrInfo")
    firstAddr : _ -> pure firstAddr
  sock <- connectingIO TDSSocketError $
    BSD.socket (BSD.addrFamily addr)
               (BSD.addrSocketType addr)
               (BSD.addrProtocol addr)
  connectingIO TDSSocketError (BSD.connect sock (BSD.addrAddress addr))
    `onException` BSD.close sock
  -- Both ends start out available and wrap the same socket.
  writeV <- newTVarIO (Right (WriteEnd sock Nothing))
  readV <- newTVarIO (Right (ReadEnd sock))
  reqV <- newTVarIO Nothing
  stV <- newTVarIO Connecting
  connStV <- newTVarIO
    (ConnectionState writeV readV reqV stV (clientEnvFromOptions opts))
  pure (Connection (sendPacket opts connStV)
                   (cancel connStV)
                   (quit connStV)
                   stV opts)
  where
    connInfo = _tdsConnInfo (defaultOptions <> opts)
    hints = BSD.defaultHints { BSD.addrSocketType = BSD.Stream }
    -- Run an action, translating any IOException into a TDS error of the
    -- given kind, tagged with the 'Connecting' state.
    connectingIO errTy action =
      action `catch` \(e :: IOException) ->
        throwIO (tdsErrorNoReq errTy Connecting (show e))
-- | Debugging hook called with a buffer and a byte count around socket
-- reads and writes.  Both arguments are currently ignored.
debugPtr :: Ptr a -> CSize -> IO ()
debugPtr _ _ = return ()
-- | Transmit every fragment of a split packet over the socket, reusing one
-- buffer for all fragments.
--
-- Arguments: the socket to write to, the scratch buffer, the size used for
-- each non-final fragment, and the split packet to send.
--
-- Each fragment's payload is written after the header slot, then the header
-- is written at the start of the buffer and the fragment is sent.  The final
-- fragment additionally gets 'pktStatusEndOfMessage' set and its real length.
--
-- Every fragment is sent completely: 'BSD.sendBuf' may accept fewer bytes
-- than requested, so the send is repeated until nothing remains.  Exceptions
-- raised while sending propagate unchanged to the caller.
sendPackets :: BSD.Socket -> ForeignPtr () -> CSize
            -> SplitPacket 'Client resp d -> IO ()
sendPackets sock fPtr bufSz =
  withForeignPtr fPtr . go
  where
    go pkt ptr =
      case pktData pkt of
        LastPacket writePkt -> do
          -- The writer returns the payload size of the final fragment.
          pktSz <- writePkt (ptr `plusPtr` fromIntegral pktHdrSz)
          let totalSz = fromIntegral (pktSz + pktHdrSz)
          let hdr = pktHdr pkt
          writeHdr ptr (hdr { pktHdrStatus = pktHdrStatus hdr <>
                                             pktStatusEndOfMessage
                            , pktHdrLength = fromIntegral totalSz })
          debugPtr ptr totalSz
          sendAll (castPtr ptr) (fromIntegral totalSz)
        OnePacket writePkt -> do
          -- The writer returns the remainder of the packet still to be sent.
          pkt' <- writePkt (ptr `plusPtr` fromIntegral pktHdrSz)
          let hdr = pktHdr pkt
          writeHdr ptr (hdr { pktHdrLength = fromIntegral bufSz })
          debugPtr ptr bufSz
          sendAll (castPtr ptr) (fromIntegral bufSz)
          go pkt' ptr

    -- Keep sending until the requested number of bytes has gone out.
    sendAll :: Ptr Word8 -> Int -> IO ()
    sendAll p remaining
      | remaining <= 0 = pure ()
      | otherwise = do
          sent <- BSD.sendBuf sock p remaining
          -- Guard against spinning forever if the socket accepts nothing.
          when (sent <= 0) $
            throwIO (userError "sendPackets: socket accepted no data")
          sendAll (p `plusPtr` sent) (remaining - sent)
-- | Receive exactly the requested number of bytes from the socket into the
-- buffer, issuing as many reads as necessary.
--
-- Requesting zero bytes succeeds immediately without touching the socket.
-- If a read delivers no data (the peer closed the connection) an
-- 'IOException' is thrown instead of retrying forever.
recvExactly :: BSD.Socket -> Ptr a -> Word16 -> IO ()
recvExactly _ _ 0 = pure ()
recvExactly sock p sz = do
  recvd <- BSD.recvBuf sock (castPtr p) (fromIntegral sz)
  when (recvd <= 0) $
    throwIO (userError
      "recvExactly: connection closed before all bytes were received")
  -- recvd never exceeds the requested size, so the subtraction cannot wrap.
  recvExactly sock (p `plusPtr` recvd) (sz - fromIntegral recvd)
-- | Stream the payload bytes of one server message from the socket.
--
-- For each packet the 8-byte header is read and decoded, then the payload
-- (header length minus 8) is emitted in chunks as they arrive.  Packets are
-- read until one carries 'pktStatusEndOfMessage'.
--
-- For a 'Cancelable' request the cancel flag is checked before every chunk
-- ('TDSWasCanceled' is thrown if it is set) and every socket read is
-- bracketed by taking and re-filling the @sync@ 'TMVar'.
-- An undecodable header aborts the stream via 'fail'.
dataStream :: ReadEnd -> CancelInfo resp -> SBS.ByteString IO ()
dataStream (ReadEnd sock) cancel = do
-- Read and decode the fixed-size (8 byte) packet header.
hdr <-
liftIO . alloca $ \hdrP ->
recvExactly sock (hdrP :: Ptr Word64) 8 >>
readHdr (TabularResult :: PacketType 'Server 'NoResponse ()) hdrP
case hdr of
Nothing -> fail "Invalid header"
Just hdr' -> do
-- NOTE(review): Word16 arithmetic -- a header length below 8 would wrap
-- around here; confirm readHdr rejects such headers.
let pktLength = pktHdrLength hdr' - 8
-- NOTE(review): bufSz is not used anywhere in this function.
bufSz = 65536 - 8
-- Emit the remaining `len` payload bytes of the current packet.
getChunk :: Word16 -> SBS.ByteString IO ()
getChunk 0 = pure ()
getChunk len = do
-- Stop before touching the socket if cancellation was requested.
case cancel of
Cancelable isCanceled sync -> do
cancelSt <- lift (atomically (readTVar isCanceled))
when cancelSt (liftIO (throwIO TDSWasCanceled))
_ -> pure ()
-- One socket read, performed between startRead and endRead.
chunk <-
liftIO $ bracket startRead (\_ -> endRead) $ \_ -> do
fPtr <- mallocForeignPtrBytes (fromIntegral len)
actuallyRead <-
withForeignPtr fPtr $ \ptr -> do
-- NOTE(review): a read that returns 0 bytes (peer closed) would make
-- getChunk loop without progress -- confirm how recvBuf reports EOF.
bytesRead <- BSD.recvBuf sock ptr (fromIntegral len)
debugPtr ptr (fromIntegral bytesRead)
pure bytesRead
pure (IBS.fromForeignPtr fPtr 0 actuallyRead)
SBS.chunk chunk
-- A read may deliver fewer bytes than asked for; continue with the rest.
getChunk (len - fromIntegral (BS.length chunk))
-- For cancelable requests a read holds the sync TMVar for its duration;
-- otherwise both actions do nothing.
(startRead, endRead) =
case cancel of
Cancelable _ sync ->
(atomically $ takeTMVar sync,
atomically $ putTMVar sync ())
_ -> (pure (), pure ())
-- Fill sync before the first read so that startRead can take it.
-- NOTE(review): this also runs for every follow-up packet of the same
-- message; if sync is still full from the previous read's endRead, putTMVar
-- would block here -- verify against the original layout and mkCancelable.
liftIO endRead
getChunk pktLength
-- Continue with the next packet unless this one ended the message.
if pktHdrStatus hdr' `hasStatus` pktStatusEndOfMessage
then pure ()
else dataStream (ReadEnd sock) cancel
-- | Inspect the connection state inside a transaction: run the first action
-- if the connection has quit, otherwise pass the live state record to the
-- second.
withConnectionState :: TVar ConnectionState -> STM a
                    -> (ConnectionState -> STM a)
                    -> STM a
withConnectionState stV onQuit onState =
  readTVar stV >>= dispatch
  where
    dispatch ConnectionQuit = onQuit
    dispatch live@ConnectionState {} = onState live
-- | Claim both socket ends for the given thread once the connection is able
-- to accept a new request.
--
-- Arguments: the claiming thread, the connection state, and a predicate
-- telling whether the request may be sent in a given protocol state.
--
-- When the predicate accepts the current state, the transaction retries
-- until both ends are free and no request is in flight, then marks both ends
-- as held by the thread and returns them.  Otherwise:
--
-- * 'SentClientRequest' yields a 'TDSServerBusy' error,
-- * 'SentAttention' retries,
-- * 'Final' yields a 'TDSServerQuit' error,
-- * any other state yields 'TDSServerUninitialized', reported together with
--   the state that was actually observed.
--
-- A connection that has already quit makes the transaction retry.
waitUntilSendable :: ThreadId -> ConnectionState
                  -> (ClientState -> Bool)
                  -> STM (Either TDSError (WriteEnd, ReadEnd))
waitUntilSendable _ ConnectionQuit _ = retry
waitUntilSendable threadId
                  (ConnectionState { connectionWriteSock = writeEndV
                                   , connectionReadSock = readEndV
                                   , connectionCurState = stateV
                                   , connectionCurReq = reqV })
                  canSendInState =
  do state <- readTVar stateV
     case state of
       _ | canSendInState state ->
           do writeSock <- either (\_ -> retry) pure =<<
                           readTVar writeEndV
              readSock <- either (\_ -> retry) pure =<<
                          readTVar readEndV
              -- Wait for any request still in flight to be cleared.
              maybe (pure ()) (\_ -> retry) =<<
                readTVar reqV
              writeTVar writeEndV (Left threadId)
              writeTVar readEndV (Left threadId)
              pure (Right (writeSock, readSock))
       SentClientRequest ->
         tdsError TDSServerBusy SentClientRequest
                  "Can't send request while server is still processing"
       SentAttention -> retry
       Final ->
         tdsError TDSServerQuit Final "Connection is closing"
       -- Report the real state; it is not 'Final' in this branch.
       _ -> tdsError TDSServerUninitialized state "The connection is not yet ready"
  where
    tdsError ty st msg = pure (Left (tdsErrorNoReq ty st msg))
-- | Hand the write end back to the connection.  Nothing happens when the
-- connection has quit or when the write end is already marked available.
surrenderWrite :: TVar ConnectionState -> WriteEnd -> STM ()
surrenderWrite stV we =
  withConnectionState stV (pure ()) $ \connSt -> do
    let writeEndV = connectionWriteSock connSt
    holder <- readTVar writeEndV
    either (const (writeTVar writeEndV (Right we)))
           (const (pure ()))
           holder
-- | Send one client packet and return an action that receives its response.
--
-- The outer action waits (in STM) until the connection can accept this
-- packet type, claims both socket ends for the calling thread, records the
-- packet as the current request and transmits it.  Its result is a second
-- action which, when run, reads and decodes the server's response.
--
-- Errors: a 'TDSError' is thrown when the connection has already quit or
-- when 'waitUntilSendable' reports an error.  For cancelable requests a
-- 'TDSCanceledException' raised while sending runs the cancellation routine
-- and is then re-thrown.  Any exception escaping the send marks the
-- connection as quit and closes the socket.
sendPacket :: forall cancelable r d
. ( Payload d, Response r, MkCancelable cancelable
, KnownBool cancelable )
=> Options -> TVar ConnectionState
-> Packet 'Client ('ExpectsResponse ('ResponseType cancelable r)) d Identity
-> IO (IO (ResponseResult ('ResponseType cancelable r)))
sendPacket options stV pkt =
myThreadId >>= go
where
-- The claiming transaction runs with asynchronous exceptions masked; the
-- send chosen by that transaction runs under `unmask`.
go threadId =
mask $ \unmask ->
join . atomically .
withConnectionState stV
(pure (throwIO (tdsErrorNoReq TDSServerQuit Final
"Can't send request to closed connection"))) $ \st ->
do sock <- waitUntilSendable threadId st (sendableInState (pktHdrType (pktHdr pkt)))
case sock of
Left err -> pure (throwIO err)
Right (writeEnd, readEnd) -> do
cancel <- mkCancelable
-- Remember the in-flight request so that it can be cancelled later.
writeTVar (connectionCurReq st) (Just (SomeSentPacket pkt cancel))
pure (unmask $
let go = doSend writeEnd readEnd cancel
-- Only cancelable requests install the cancellation handler.
go' = if boolVal (Proxy :: Proxy cancelable)
then go `catch`
\(e :: TDSCanceledException) ->
cancelRequest st writeEnd >>
throwIO e
else go
-- Tear the connection down if sending fails for any reason.
internalError =
join . atomically $ quitSTM' stV writeEnd readEnd
in go' `onException` internalError)
encoding = packetEncoding pkt
-- Transmit the packet, advance the protocol state and return the action
-- that reads the response.
doSend we@(WriteEnd sock buf) readEnd cancel = do
-- Split the payload to fit the existing buffer, or the maximum packet
-- size when there is no buffer yet.
let (maxSz, splitEncoding) =
splitPacket (maybe maximumPayloadPacketSize snd buf - pktHdrSz) encoding
(buf', sz') <-
case buf of
Nothing ->
case maxSz of
Nothing -> fail "Don't know what size of buffer"
Just maxSz' ->
-- NOTE(review): this fresh buffer is not stored back into the WriteEnd
-- (the original `we` is surrendered below), so it is allocated again on
-- every send while the WriteEnd carries no buffer.
(, maxSz') <$> mallocForeignPtrBytes (fromIntegral (maxSz' + pktHdrSz))
Just (buf, sz) -> pure (buf, sz)
sendPackets sock buf' sz' splitEncoding
join . atomically . withConnectionState stV
(pure (throwIO (TDSError TDSInvalidStateTransition Final (Just pkt)
"Server quit before response received"))) $ \st ->
-- Release the write end and move the protocol state machine forward.
do surrenderWrite stV we
oldSt <- readTVar (connectionCurState st)
let st' = stateTransition (pktHdrType (pktHdr pkt)) oldSt
writeTVar (connectionCurState st) st'
disconnectOnFinal st' we readEnd
(throwIO (TDSError TDSInvalidStateTransition st' (Just pkt)
"The state transitioned to Final before a response could be received"))
(pure (getResult readEnd cancel))
-- Receive and decode the response with the decoder selected for `r`.
getResult readEnd@(ReadEnd sock) cancel =
case responseDecoder :: ResponseDecoder (ResponseStreaming r) r of
-- Batch responses must arrive in one packet and are decoded in one go.
DecodeBatchResponse decode -> do
hdr <- alloca $ \hdrP ->
do recvExactly sock (hdrP :: Ptr Word64) 8
readHdr (TabularResult :: PacketType 'Server 'NoResponse r) hdrP
case hdr of
Nothing -> invalidResponse readEnd
Just hdr'
| not (pktHdrStatus hdr' `hasStatus` pktStatusEndOfMessage) ->
fail "Cannot decode batch message split over multiple packets"
| otherwise ->
allocaBytes (fromIntegral $ pktHdrLength hdr') $ \pktBuf ->
do recvExactly sock pktBuf (pktHdrLength hdr' - 8)
debugPtr pktBuf (fromIntegral $ pktHdrLength hdr' - 8)
-- NOTE(review): decode is given the full header length although only
-- (length - 8) bytes were read into pktBuf -- confirm what decode expects.
resp <- decode pktBuf (pktHdrLength hdr')
case resp of
Nothing -> fail "Invalid response"
Just resp' -> do
-- Give the read end back and clear the in-flight request.
atomically . withConnectionState stV (pure ()) $ \st -> do
writeTVar (connectionReadSock st) (Right readEnd)
writeTVar (connectionCurReq st) Nothing
pure (ResponseResultReceived resp')
-- Streaming responses are parsed token by token.
DecodeTokenStream streamDecoder ->
do let tokenStream = parseTokenStream (dataStream readEnd cancel)
-- Run handleToken on every complete token and drop those it rejects.
validTokens s = do
res <- S.lift (S.inspect s)
case res of
Left a -> pure a
Right (OneToken tok next) ->
do includeToken <- S.lift (handleToken pkt options stV tok)
if includeToken
then S.wrap (OneToken tok (validTokens next))
else validTokens next
Right (ContParse tok next) ->
S.wrap (ContParse tok (validTokens . next))
-- Resets the state to LoggedIn, returns the read end and clears the
-- request; handed to the decoder (presumably run once the stream is
-- fully consumed -- confirm).
finishUp = atomically . withConnectionState stV (pure ()) $
\st -> do
writeTVar (connectionCurState st) LoggedIn
writeTVar (connectionReadSock st) (Right readEnd)
writeTVar (connectionCurReq st) Nothing
res <- streamDecoder finishUp (validTokens tokenStream)
pure (ResponseResultReceived res)
-- Report an undecodable response: move to Final, shut the connection down
-- and throw 'TDSInvalidResponse' tagged with the state seen beforehand.
invalidResponse :: ReadEnd -> IO (ResponseResult ('ResponseType cancelable r))
invalidResponse readEnd =
join . atomically . withConnectionState stV
(pure $ throwIO (TDSError TDSInvalidResponse Final
(Just pkt)
"Invalid response received, but we've already quit")) $
\st -> do
clientSt <- readTVar (connectionCurState st)
writeTVar (connectionCurState st) Final
writeTVar (connectionReadSock st) (Right readEnd)
-- NOTE(review): the close action returned by quitSTM is discarded here, so
-- the socket does not appear to be closed on this path -- confirm.
quitSTM stV
pure (throwIO (TDSError TDSInvalidResponse clientSt
(Just pkt)
"Invalid response received"))
-- When the new state is Final, quit the connection and return the failure
-- action; otherwise return the normal action unchanged.
disconnectOnFinal :: ClientState -> WriteEnd -> ReadEnd
-> IO a -> IO a
-> STM (IO a)
disconnectOnFinal Final we re failer _ = quitSTM' stV we re >> pure failer
disconnectOnFinal _ _ _ _ action = pure action
-- Cancel the in-flight request.  Only a cancelable SQLBatch request is
-- handled; anything else is left untouched.
cancelRequest st writeEnd@(WriteEnd sock buf) = do
curReq <- atomically (readTVar (connectionCurReq st))
case curReq of
Just (SomeSentPacket (Packet (PacketHeader { pktHdrType = SQLBatch }) _)
(Cancelable signal sync)) -> do
let bufSz = cancelPacketSize
-- This binding shadows the `buf` taken from the WriteEnd pattern.
buf <- mallocForeignPtrBytes (fromIntegral bufSz)
-- Raise the cancel flag, then take sync so that no read is in progress.
atomically (writeTVar signal True)
atomically (takeTMVar sync)
-- Send an Attention packet and consume the server's reply.
let (_, splitEncoding) = splitPacket bufSz
(Packet (mkPacketHeader Attention pktStatusEndOfMessage)
(PacketEncoding (encodePayload ())))
sendPackets sock buf bufSz splitEncoding
let tokenStream = parseTokenStream (dataStream (ReadEnd sock) NonCancelable)
readUntilDoneToken tokenStream
-- Return to LoggedIn, release the read end and clear the request.
atomically $ do
writeTVar (connectionCurState st) LoggedIn
writeTVar (connectionReadSock st) (Right (ReadEnd sock))
writeTVar (connectionCurReq st) Nothing
_ -> pure ()
-- | React to one token of a response stream and decide whether it should be
-- passed on to the decoder ('True') or dropped ('False').
--
-- * 'Info' and 'Error' tokens are handed to the message callback and dropped.
-- * 'EnvChange' updates the stored client environment (unless the
--   connection has quit, in which case nothing else happens), then invokes
--   the environment-change callback; the token is dropped.
-- * 'LoginAck' in reply to a 'Login7' packet moves any of the
--   "sent Login7" states to 'LoggedIn'; the token is kept.
-- * Every other token is kept.
handleToken :: Packet clientServer respType d f
            -> Options -> TVar ConnectionState
            -> Token -> IO Bool
handleToken pkt options stV token =
  case token of
    Info msg -> _tdsOnMessage options msg >> pure False
    Error msg -> _tdsOnMessage options msg >> pure False
    EnvChange chg ->
      join . atomically $
        withConnectionState stV (pure (pure False)) $ \connSt -> do
          let env' = updateEnv chg (connectionEnvironment connSt)
          writeTVar stV (connSt { connectionEnvironment = env' })
          pure (_tdsOnEnvChange options chg >> pure False)
    LoginAck {}
      | Login7 <- pktHdrType (pktHdr pkt) ->
          atomically . withConnectionState stV (pure True) $ \connSt -> do
            let protoStV = connectionCurState connSt
            protoSt <- readTVar protoStV
            when (awaitingLoginAck protoSt) (writeTVar protoStV LoggedIn)
            pure True
    _ -> pure True
  where
    -- The protocol states from which a login acknowledgement completes login.
    awaitingLoginAck SentLogin7WithCompleteAuthenticationToken = True
    awaitingLoginAck SentLogin7WithSPNEGO = True
    awaitingLoginAck SentLogin7WithFAIR = True
    awaitingLoginAck _ = False
-- | Placeholder for user-initiated cancellation: ignores the connection and
-- always fails with an 'IOError' carrying the message @Cancel@.
cancel :: TVar ConnectionState -> IO ()
cancel = const (ioError (userError "Cancel"))
-- | Transaction that shuts the connection down.  It retries until both
-- socket ends are available, then marks the connection as quit and returns
-- the IO action that closes them.  On a connection that has already quit it
-- returns an action that does nothing.
quitSTM :: TVar ConnectionState -> STM (IO ())
quitSTM stV =
  withConnectionState stV (pure (pure ())) $ \connSt -> do
    we <- awaitFree (connectionWriteSock connSt)
    re <- awaitFree (connectionReadSock connSt)
    quitSTM' stV we re
  where
    -- Retry for as long as another thread holds the end.
    awaitFree :: TVar (Either ThreadId e) -> STM e
    awaitFree endV = readTVar endV >>= either (const retry) pure
-- | Mark the connection as quit and return the IO action that closes the
-- given write and read ends.  If the connection has already quit, the
-- returned action does nothing.
quitSTM' :: TVar ConnectionState -> WriteEnd -> ReadEnd -> STM (IO ())
quitSTM' stV we re = do
  connSt <- readTVar stV
  case connSt of
    ConnectionQuit -> pure (pure ())
    ConnectionState {} -> do
      writeTVar stV ConnectionQuit
      pure (closeWriteEnd we >> closeReadEnd re)
-- | Shut the connection down: run the quit transaction, then execute the
-- close action it hands back.
quit :: TVar ConnectionState -> IO ()
quit stV = do
  closeSockets <- atomically (quitSTM stV)
  closeSockets
-- | The client's view of the session environment.
data ClientEnv
= ClientEnv
-- Current database; replaced by 'updateEnv' on a database change.
{ clientEnvDatabase :: Text
-- Session language; no code in this module changes it after creation.
, clientEnvLanguage :: Text
} deriving Show
-- | Initial client environment.  The options are currently ignored: the
-- database always starts as @master@ and the language as @us_english@.
clientEnvFromOptions :: Options -> ClientEnv
clientEnvFromOptions =
  const (ClientEnv { clientEnvDatabase = "master"
                   , clientEnvLanguage = "us_english" })
-- | Apply an environment change to the client environment.  Only a database
-- change has an effect (the new name is stored); every other change leaves
-- the environment as it was.
updateEnv :: EnvChange -> ClientEnv -> ClientEnv
updateEnv change env =
  case change of
    EnvChangeDatabase _ newDb -> env { clientEnvDatabase = newDb }
    _ -> env
-- | Consume a token stream, discarding its tokens.
--
-- Reading stops when the stream ends or when a partially parsed 'Done'
-- token is reached.  Every completely parsed token is skipped, whatever its
-- kind, so tokens still queued ahead of the final one no longer abort the
-- read with a pattern-match failure.  Any other partially parsed token is
-- not supported and fails.
readUntilDoneToken :: S.Stream TokenStream IO () -> IO ()
readUntilDoneToken s = do
  res <- S.inspect s
  case res of
    Left {} -> pure ()
    -- Complete tokens, 'Done' included, are skipped and reading continues.
    Right (OneToken _ s') -> readUntilDoneToken s'
    Right (ContParse Done {} _) -> pure ()
    Right (ContParse _ _) -> fail "Can't read (TODO)"