{-# LANGUAGE CPP #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
-- | @multipart/form-data@ client-side support for servant.
--   See servant-multipart-api for the API definitions.
module Servant.Multipart.Client
  ( genBoundary
  , ToMultipart(..)
  , multipartToBody
  ) where

import Servant.Multipart.API

import Control.Monad (replicateM)
import Data.Array (listArray, (!))
import Data.List (foldl')
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
import Data.Text.Encoding           (encodeUtf8)
import Data.Typeable
import Network.HTTP.Media.MediaType ((//), (/:))
import Servant.API
import Servant.Client.Core          (HasClient (..), RequestBody (RequestBodySource),
                                     setRequestBody)
import Servant.Types.SourceT        (SourceT (..), StepT (..), fromActionStep, source)
import System.IO                    (IOMode (ReadMode), withFile)
import System.Random                (getStdRandom, randomR)

import qualified Data.ByteString.Lazy as LBS

-- | Upon seeing @MultipartForm a :> ...@ in an API type,
--   servant-client will take a parameter of type @(LBS.ByteString, a)@,
--   where the bytestring is the boundary to use (see 'genBoundary'), and
--   replace the request body with the contents of the form.
instance (ToMultipart tag a, HasClient m api, MultipartClient tag)
      => HasClient m (MultipartForm' mods tag a :> api) where

  type Client m (MultipartForm' mods tag a :> api) =
    (LBS.ByteString, a) -> Client m api

  -- Serialise the form with the caller-supplied boundary, install it as the
  -- request body together with a @multipart/form-data@ media type carrying
  -- that same boundary, then continue with the client for the rest of the API.
  clientWithRoute pm _ req (boundary, param) =
      clientWithRoute pm (Proxy :: Proxy api) formRequest
    where
      formRequest = setRequestBody formBody formMedia req
      formBody    = multipartToBody boundary (toMultipart @tag param)
      formMedia   = "multipart" // "form-data" /: ("boundary", LBS.toStrict boundary)

  -- The (boundary, form) argument is plain data, so hoisting only has to
  -- forward it to the hoisted client of the remaining API.
  hoistClientMonad pm _ nt cl = \formArg ->
      hoistClientMonad pm (Proxy :: Proxy api) nt (cl formArg)

-- | Backends (identified by @tag@) that can turn the stored payload of a
-- file part into a byte stream for the outgoing request body.
class MultipartClient tag where
    -- | Produce the bytes of a file part's payload as a stream.
    loadFile
      :: Proxy tag
      -> MultipartResult tag
      -> SourceIO LBS.ByteString

instance MultipartClient Tmp where
    -- Stream the file from disk. The handle is obtained via 'withFile', so it
    -- stays open only while the continuation @consume@ runs. Reads happen in
    -- 4096-byte chunks and the stream stops at the first empty chunk.
    loadFile _ path = SourceT $ \consume ->
        withFile path ReadMode $ \h ->
            consume (fromActionStep LBS.null (LBS.hGet h 4096))

instance MultipartClient Mem where
    -- The payload is already held in memory: emit it as a one-chunk stream.
    loadFile _ contents = source [contents]

-- | Generates a boundary to be used to separate parts of the multipart.
-- Requires 'IO' because each character is drawn from the global random
-- generator.
genBoundary :: IO LBS.ByteString
genBoundary = do
    picks <- replicateM boundaryLength (getStdRandom (randomR (0, lastIx)))
    pure (LBS.pack [validChars ! i | i <- picks])
  where
    -- the standard allows up to 70 chars, but most implementations seem to be
    -- in the range of 40-60, so we pick 55
    boundaryLength = 55

    -- index of the last entry of 'validChars' (10 digits + 26 + 26 letters)
    lastIx = 61 :: Int

    -- Following Chromium on this one:
    -- > The RFC 2046 spec says the alphanumeric characters plus the
    -- > following characters are legal for boundaries:  '()+_,-./:=?
    -- > However the following characters, though legal, cause some sites
    -- > to fail: (),./:=+
    -- https://github.com/chromium/chromium/blob/6efa1184771ace08f3e2162b0255c93526d1750d/net/base/mime_util.cc#L662-L670
    validChars = listArray (0, lastIx) (digits ++ upper ++ lower)
    digits = [0x30 .. 0x39] -- 0-9
    upper  = [0x41 .. 0x5a] -- A-Z
    lower  = [0x61 .. 0x7a] -- a-z

-- | Given a bytestring for the boundary, turns a 'MultipartData' into
-- a 'RequestBody'.
--
-- The body is streamed in this order: every textual input, then every file
-- (whose bytes come from 'loadFile' for the @tag@ backend), then the closing
-- @--boundary--@ delimiter. Field names and file names are escaped with
-- 'escapeDispositionParam' before being placed in the quoted
-- @Content-Disposition@ parameters.
multipartToBody :: forall tag
                .  MultipartClient tag
                => LBS.ByteString
                -> MultipartData tag
                -> RequestBody
multipartToBody boundary mp =
    RequestBodySource $ files' `mappend'` source ["--", boundary, "--"]
  where
    -- Local append for 'SourceT'/'StepT'. Older servant releases provide no
    -- Semigroup or Monoid instance for these types, so every concatenation in
    -- this function goes through this helper rather than '<>'.
    (SourceT l) `mappend'` (SourceT r) = SourceT $ \k ->
        l $ \lstep ->
        r $ \rstep ->
        k (appendStep lstep rstep)
    appendStep Stop        r = r
    appendStep (Error err) _ = Error err
    appendStep (Skip s)    r = appendStep s r
    appendStep (Yield x s) r = Yield x (appendStep s r)
    appendStep (Effect ms) r = Effect (flip appendStep r <$> ms)
    mempty' = SourceT ($ Stop)

    crlf = "\r\n"
    lencode = LBS.fromStrict . encodeUtf8

    -- A plain text field: its value is sent as a single chunk.
    renderInput input = renderPart (lencode . iName $ input)
                                   "text/plain"
                                   ""
                                   (source . pure @[] . lencode . iValue $ input)
    inputs' = foldl' (\acc x -> acc `mappend'` renderInput x) mempty' (inputs mp)

    -- A file field: adds a quoted (escaped) filename parameter and streams
    -- the payload through the backend's 'loadFile'.
    renderFile :: FileData tag -> SourceIO LBS.ByteString
    renderFile file =
        renderPart (lencode . fdInputName $ file)
                   (lencode . fdFileCType $ file)
                   ("; filename=\""
                      <> escapeDispositionParam (lencode . fdFileName $ file)
                      <> "\"")
                   (loadFile (Proxy @tag) . fdPayload $ file)
    files' = foldl' (\acc x -> acc `mappend'` renderFile x) inputs' (files mp)

    -- One part: delimiter line, headers, blank line, payload, trailing CRLF.
    -- The name goes inside a quoted parameter, so it is escaped first.
    renderPart name contentType extraParams payload =
      source [ "--"
             , boundary
             , crlf
             , "Content-Disposition: form-data; name=\""
             , escapeDispositionParam name
             , "\""
             , extraParams
             , crlf
             , "Content-Type: "
             , contentType
             , crlf
             , crlf
             ] `mappend'` payload `mappend'` source [crlf]

-- | Percent-escape the bytes that would otherwise terminate a quoted
-- @Content-Disposition@ parameter (@"@) or start a new header line
-- (CR, LF): @"@ becomes @%22@, CR becomes @%0D@ and LF becomes @%0A@.
-- All other bytes are passed through unchanged. This matches the escaping
-- browsers apply to field names and file names when encoding
-- @multipart/form-data@.
escapeDispositionParam :: LBS.ByteString -> LBS.ByteString
escapeDispositionParam = LBS.concatMap escapeByte
  where
    escapeByte 0x22 = "%22"
    escapeByte 0x0D = "%0D"
    escapeByte 0x0A = "%0A"
    escapeByte w    = LBS.singleton w