{-# LANGUAGE OverloadedStrings #-}
-- | This module contains many helper functions, as well as the code for
--   'Source', which is a pretty important structure
module Network.MiniHTTP.HTTPConnection
  ( -- * Sources, and related functions
    Source
  , SourceResult(..)
  , bsSource
  , hSource
  , nullSource
  , sourceToLBS
  , sourceToBS
  , connSource
  , connChunkedSource
  , connEOFSource
  , sourceDrain
  , streamSource
  , streamSourceChunked

  -- * Misc functions
  , readIG
  , sourceIG
  , maybeRead
  , sslToBaseConnection
  ) where

import           Control.Concurrent.STM
import qualified Control.Exception as Exception
import           Control.Monad (liftM)

import qualified Data.ByteString as B
import           Data.ByteString.Char8 ()
import           Data.ByteString.Internal (c2w, w2c)
import qualified Data.ByteString.Lazy.Internal as BL
import qualified Data.Binary.Strict.IncrementalGet as IG
import           Data.IORef
import           Data.Int (Int64)

import           System.IO

import           Text.Printf (printf)

import           System.IO.Unsafe (unsafeInterleaveIO)

import qualified Network.Connection as C
import           Network.Socket as Socket
import           Network.MiniHTTP.Marshal (parseChunkHeader)

import qualified OpenSSL.Session as SSL

-- | A source is a stream of data, like a lazy data structure, but without
--   some of the dangers that such structures entail. A source is simply an
--   'IO' action: each time you run it, it yields the next 'SourceResult'.
type Source = IO SourceResult

-- | The outcome of running a 'Source' once.
data SourceResult
  = SourceError
    -- ^ error - please don't read this source again
  | SourceEOF
    -- ^ end of data
  | SourceData B.ByteString
    -- ^ some data
  deriving (Show)

-- | Wrap a strict ByteString as a 'Source'. The first time the resulting
--   source is run it yields the whole string as a single 'SourceData' block;
--   every later run yields 'SourceEOF'.
bsSource :: B.ByteString -> IO Source
bsSource bs = fmap once (newIORef (SourceData bs))
  where
    -- Hand out whatever the cell currently holds and leave 'SourceEOF'
    -- behind for all subsequent runs.
    once cell = do
      current <- readIORef cell
      writeIORef cell SourceEOF
      return current

-- | Construct a source from a Handle. The handle is seeked to the first byte
--   of the range when the source is built; each run of the source then reads
--   up to 128 KiB more. An empty read yields 'SourceEOF' if the whole range
--   has been delivered and 'SourceError' otherwise. Any exception caught by
--   'Exception.catch' during a read is also reported as 'SourceError'.
hSource :: (Int64, Int64)  -- ^ the first and last byte to include
        -> Handle  -- ^ the handle to read from
        -> IO Source
hSource (from, to) handle = do
  -- Absolute offset of the next byte that has not been handed out yet.
  position <- newIORef (from :: Int64)
  hSeek handle AbsoluteSeek (fromIntegral from)
  return $
    Exception.catch
      (do done <- readIORef position
          -- Clamp to the block size while still in Int64 and only then
          -- narrow to Int, so that a huge remaining count cannot wrap
          -- around on platforms where Int is narrower than Int64.
          let wanted = fromIntegral $ min maxBlock ((to + 1) - done)
          bytes <- B.hGet handle wanted
          if B.null bytes
             then return $ if to + 1 == done then SourceEOF else SourceError
             else do -- Strict write: don't pile up (+) thunks in the ref.
                     writeIORef position $! done + fromIntegral (B.length bytes)
                     return $ SourceData bytes)
      (const $ return SourceError)
  where
    -- Largest block requested from the handle in one read.
    maxBlock = 128 * 1024 :: Int64

-- | A source with no data (e.g. @/dev/null@): every run immediately yields
--   'SourceEOF'.
nullSource :: Source
nullSource = return SourceEOF

-- | Build a source that hands out whatever 'C.read' returns from the given
--   'C.Connection' (asking for 1024 bytes per run) and yields 'SourceEOF'
--   whenever the read raises an exception handled by 'catch'.
connEOFSource :: C.Connection -> IO Source
connEOFSource conn = return readOrEOF
  where
    -- One read attempt; an exception is interpreted as end-of-file.
    readOrEOF = catch readChunk (\_ -> return SourceEOF)
    readChunk = SourceData `liftM` C.read conn 1024

-- | A source which reads from a 'C.Connection'. The source first yields the
--   given prefix, then reads from the connection in blocks of at most 32 KiB
--   until the requested total (prefix included) has been delivered, after
--   which it yields 'SourceEOF'. An empty read from the connection is
--   reported as 'SourceError'.
connSource :: Int64  -- ^ the number of bytes to read
           -> B.ByteString  -- ^ a string which is prepended to the output
           -> C.Connection  -- ^ the connection to read from
           -> IO Source
connSource n bs conn
  -- The prefix already is the whole body: no need to touch the connection.
  | prefixLen == n = bsSource bs
  | otherwise = do
      -- (has the prefix been handed out?, bytes delivered so far)
      ref <- newIORef (False, 0 :: Int64)
      return $ do
        (doneBS, n') <- readIORef ref
        case () of
          _ | not doneBS -> do
                writeIORef ref (True, prefixLen)
                return $ SourceData bs
            -- '>=' rather than '==': if the count ever overshoots @n@ we
            -- must stop instead of asking the connection for a negative
            -- number of bytes.
            | n' >= n -> return SourceEOF
            | otherwise -> do
                -- Clamp in Int64 first and narrow to Int afterwards, so a
                -- huge remaining count cannot wrap where Int is narrower
                -- than Int64.
                bytes <- C.read conn $ fromIntegral $ min maxBlock (n - n')
                if B.null bytes
                   then return SourceError
                   else do
                     writeIORef ref (doneBS, n' + fromIntegral (B.length bytes))
                     return $ SourceData bytes
  where
    prefixLen = fromIntegral (B.length bs) :: Int64
    -- Largest block requested from the connection in one read.
    maxBlock = 32 * 1024 :: Int64

-- | A source which reads an HTTP chunked reply from a 'C.Connection',
--   yielding the chunk payloads and 'SourceEOF' once the zero-length chunk
--   has been seen.
connChunkedSource :: C.Connection -> IO Source
connChunkedSource conn = do
  -- The reference holds the number of bytes still to be read from the
  -- current chunk. Zero means that a chunk header needs to be read next;
  -- -1 means that the final chunk has been seen. Whenever the last byte of a
  -- chunk is read, the two bytes following it (the \r\n) are consumed before
  -- returning, so that case never has to be considered on entry.
  remainingRef <- newIORef (0 :: Int64)
  let next = readIORef remainingRef >>= step

      -- At a chunk boundary: parse the next chunk header.
      step 0 = do
        header <- readIG conn 16 256 parseChunkHeader
        case header of
          Nothing -> return SourceError
          Just 0 -> do
            _ <- C.reada conn 2
            writeIORef remainingRef (-1)
            return SourceEOF
          Just size -> writeIORef remainingRef size >> next
      -- The final chunk was already consumed.
      step (-1) = return SourceEOF
      -- In the middle of a chunk: read at most 32 KiB of it.
      step remaining = do
        bytes <- C.read conn $ fromIntegral $ min remaining $ 32*1024
        if B.null bytes
           then return SourceError
           else do
             let left = remaining - fromIntegral (B.length bytes)
             if left == 0
                then C.reada conn 2 >> writeIORef remainingRef 0  -- read \r\n
                else writeIORef remainingRef left
             return $ SourceData bytes
  return next

-- | Keep running a source, throwing away its data, until it yields
--   'SourceEOF' or 'SourceError'.
sourceDrain :: Source -> IO ()
sourceDrain s = s >>= continue
  where
    continue (SourceData _) = sourceDrain s
    continue SourceEOF = return ()
    continue SourceError = return ()

-- | Turn a source into a lazy ByteString. The first block is read eagerly;
--   the remaining blocks are read on demand (via 'unsafeInterleaveIO') as the
--   result is consumed. A 'SourceError' makes the corresponding read 'fail'.
sourceToLBS :: Source -> IO BL.ByteString
sourceToLBS s = s >>= \result -> case result of
  SourceData bs -> liftM (BL.Chunk bs) (unsafeInterleaveIO (sourceToLBS s))
  SourceError -> fail "Error in reading from client"
  SourceEOF -> return BL.Empty

-- | Collect up to the first n bytes of a Source into one strict ByteString.
--   Yields Nothing if the source reports an error; reaching end-of-data
--   before n bytes were seen is fine and simply gives a shorter string.
sourceToBS :: Int -> Source -> IO (Maybe B.ByteString)
sourceToBS n source = collect 0
  where
    -- @taken@ is the number of bytes gathered from earlier blocks.
    collect taken = source >>= \result -> case result of
      SourceError -> return Nothing
      SourceEOF -> return (Just B.empty)
      SourceData chunk
        | taken + B.length chunk >= n ->
            return (Just (B.take (n - taken) chunk))
        | otherwise ->
            liftM (fmap (B.append chunk)) (collect (taken + B.length chunk))

-- | Copy every block of a source to a connection using
--   'C.writeAtLowWater', so that no more than lowWater bytes are enqueued in
--   the outbound queue (not inc the kernel buffer). The result is 'True' if
--   the source ended with 'SourceEOF' and 'False' if it ended with
--   'SourceError'.
streamSource :: Int -> C.Connection -> Source -> IO Bool
streamSource lowWater conn source = source >>= step
  where
    step SourceEOF = return True
    step SourceError = return False
    step (SourceData bs) = do
      atomically (C.writeAtLowWater lowWater conn bs)
      streamSource lowWater conn source

-- | Stream a source to a connection, with chunked encoding, while not
--   enqueuing more than lowWater bytes in the outbound queue (not inc the
--   kernel buffer).
--
--   Every non-empty 'SourceData' block is written as one chunk: its length
--   in hexadecimal, CRLF, the data, CRLF. On 'SourceEOF' the terminating
--   zero-length chunk is written and 'True' is returned. On 'SourceError'
--   nothing more is written and 'False' is returned.
streamSourceChunked :: Int -> C.Connection -> Source -> IO Bool
streamSourceChunked lowWater conn source = do
  next <- source
  case next of
       SourceEOF -> do
         send "0\r\n\r\n"
         return True
       SourceError -> return False
       SourceData bs
         -- A zero-length chunk is the end-of-body marker in chunked
         -- encoding, so an empty block must not be written as a chunk.
         | B.null bs -> streamSourceChunked lowWater conn source
         | otherwise -> do
             -- The chunk size is hexadecimal and is followed by exactly
             -- one CRLF.
             send $ B.pack $ map c2w $ printf "%x\r\n" $ B.length bs
             send bs
             send "\r\n"
             streamSourceChunked lowWater conn source
  where
    -- Enqueue one piece of output, respecting the low-water mark.
    send payload = atomically $ C.writeAtLowWater lowWater conn payload

-- | Feed an incremental parser with blocks read from a connection until it
--   finishes or fails. On success the unconsumed input is pushed back onto
--   the connection. 'Nothing' is returned if the parser fails or if the byte
--   limit is reached; the bytes of the very first block are not counted
--   toward that limit.
readIG :: C.Connection  -- ^ the connection to read from
       -> Int  -- ^ the block size to use
       -> Int  -- ^ maximum number of bytes to parse
       -> IG.Get a a -- ^ the parser
       -> IO (Maybe a)
readIG conn blockSize maxBytes parser =
  C.read conn blockSize >>= loop 0 . IG.runGet parser
  where
    -- @consumed@ counts the bytes fed to the parser after the first block.
    loop consumed state
      | consumed >= maxBytes = return Nothing
      | otherwise = case state of
          IG.Failed _ -> return Nothing
          IG.Finished leftover value -> do
            atomically (C.pushBack conn leftover)
            return (Just value)
          IG.Partial resume -> do
            block <- C.read conn blockSize
            loop (consumed + B.length block) (resume block)

-- | Feed an incremental parser with the blocks of a 'Source' until it
--   finishes or fails. 'Nothing' is returned if the parser fails, if the
--   source yields 'SourceEOF' or 'SourceError' while the parser still wants
--   input, or if the byte limit is reached. Input left over after the parser
--   finishes is discarded.
sourceIG :: Source  -- ^ the source to read from
         -> Int  -- ^ the maximum number of bytes to parse
         -> IG.Get a a  -- ^ the parser
         -> IO (Maybe a)
sourceIG source maxBytes parser = go 0 (IG.runGet parser B.empty)
  where
    -- @consumed@ counts the bytes fed to the parser so far.
    go consumed state
      | consumed >= maxBytes = return Nothing
      | otherwise = advance consumed state

    advance _ (IG.Failed _) = return Nothing
    advance _ (IG.Finished _ value) = return (Just value)
    advance consumed (IG.Partial resume) = do
      block <- source
      case block of
        SourceData bytes -> go (B.length bytes + consumed) (resume bytes)
        SourceEOF -> return Nothing
        SourceError -> return Nothing

-- | Parse a value from a ByteString with its 'Read' instance, treating each
--   byte as one character. Succeeds only if there is exactly one parse and
--   it consumes the entire input.
maybeRead :: Read a => B.ByteString -> Maybe a
maybeRead s = exactlyOne (reads chars)
  where
    chars = map w2c (B.unpack s)
    exactlyOne [(value, [])] = Just value
    exactlyOne _ = Nothing

-- | Wrap an SSL session as a 'C.BaseConnection' for Network.Connection. The
--   three actions supplied are, in order: reading via 'SSL.read', writing
--   via 'SSL.write' (reporting the full length of the string as written),
--   and closing, which shuts the SSL session down unidirectionally and then
--   closes the underlying socket.
sslToBaseConnection :: SSL.SSL -> C.BaseConnection
sslToBaseConnection ssl = C.BaseConnection readBytes writeBytes close
  where
    readBytes count = SSL.read ssl count
    writeBytes payload = do
      SSL.write ssl payload
      return (B.length payload)
    close = do
      SSL.shutdown ssl SSL.Unidirectional
      sClose (SSL.sslSocket ssl)