{-# LANGUAGE CPP                        #-}
{-# LANGUAGE TemplateHaskell            #-}

module Data.API.Tools.CBOR
    ( cborTool
    ) where

import           Data.API.TH
import           Data.API.Tools.Combinators
import           Data.API.Tools.Datatypes
import           Data.API.Tools.Enum
import           Data.API.Types

import           Codec.Serialise.Class
import           Codec.Serialise.Decoding
import           Codec.Serialise.Encoding
import           Control.Applicative
import qualified Control.Monad.Fail             as Fail
import           Data.Binary.Serialise.CBOR.Extra
import           Data.List (foldl', sortBy)
import qualified Data.Map                       as Map
import           Data.Monoid
import           Data.Ord (comparing)
import qualified Data.Text                      as T
import           Language.Haskell.TH
import           Prelude

-- | Tool to generate 'Serialise' instances for types generated by
-- 'datatypesTool'. This depends on 'enumTool'.
--
-- For every API node two generators are run and their output combined:
-- one that dispatches on the kind of spec, and 'gen_pr', which handles
-- nodes that carry a conversion.
cborTool :: APITool
cborTool = apiNodeTool (perSpecInstances <> gen_pr)
  where
    -- Generators for newtypes, records, unions and enums, in that order.
    -- The last slot is left as 'mempty', i.e. nothing is generated for it
    -- (presumably the slot for type synonyms -- TODO confirm against
    -- 'apiSpecTool').
    perSpecInstances =
        apiSpecTool gen_sn_to gen_sr_to gen_su_to gen_se_to mempty

{-
instance Serialise JobId where
    encode = encode . _JobId
    decode = JobId <$> decode

In this version we don't check the @snFilter@, for simplicity and speed.
This is safe, since the CBOR code is used only internally as a data
representation format, not as a communication format with clients
that could potentially send faulty data.
-}

-- | Generate a 'Serialise' instance for a newtype node.
--
-- @encode@ unwraps the newtype and hands the payload to the primitive
-- encoder selected by the newtype's basic type; @decode@ runs the
-- matching primitive decoder and wraps the result with the constructor
-- expression supplied by 'nodeNewtypeConE'.
gen_sn_to :: Tool (APINode, SpecNewtype)
gen_sn_to = mkTool $ \ ts (an, sn) ->
    optionalInstanceD ts ''Serialise [nodeRepT an]
        [ simpleD 'encode (encoderE an sn)
        , simpleD 'decode (decoderE ts an sn)
        ]
  where
    -- primitive encoder composed after the newtype projection
    encoderE an sn =
        [e| $(primEncoderE (snType sn)) . $(newtypeProjectionE an) |]

    -- constructor mapped over the primitive decoder
    decoderE ts an sn =
        [e| $(nodeNewtypeConE ts an sn) <$> $(primDecoderE (snType sn)) |]

    -- Encoder for each basic type: strings, booleans and ints use the
    -- dedicated CBOR primitives, the rest go through 'Serialise'.
    primEncoderE BTstring = varE 'encodeString
    primEncoderE BTbinary = varE 'encode
    primEncoderE BTbool   = varE 'encodeBool
    primEncoderE BTint    = varE 'encodeInt
    primEncoderE BTutc    = varE 'encode

    -- Decoder for each basic type, mirroring 'primEncoderE'.
    primDecoderE BTstring = varE 'decodeString
    primDecoderE BTbinary = varE 'decode
    primDecoderE BTbool   = varE 'decodeBool
    primDecoderE BTint    = varE 'decodeInt
    primDecoderE BTutc    = varE 'decode



{-
instance Serialise JobSpecId where
     encode = \ x ->
        encodeMapLen 4 <>
        encodeRecordFields
            [ encodeString "Id"         <> encode (jsiId         x)
            , encodeString "Input"      <> encode (jsiInput      x)
            , encodeString "Output"     <> encode (jsiOutput     x)
            , encodeString "PipelineId" <> encode (jsiPipelineId x)
            ]
     decode =
        decodeMapLen >>
        JobSpecId <$> (decodeString >> decode)
                  <*> (decodeString >> decode)
                  <*> (decodeString >> decode)
                  <*> (decodeString >> decode)

Note that fields are stored alphabetically ordered by field name, so
that we are insensitive to changes in field order in the schema.
When the schema order is not alphabetical, the constructor above is
replaced by a lambda that takes the fields in alphabetical order and
applies the constructor to them in schema order.
-}

-- | Generate a 'Serialise' instance for a record node.
--
-- The record is written as a CBOR map header followed by one
-- name\/value pair per field, with the fields taken in alphabetical
-- order of their names.  Decoding reads the map header, then for each
-- field skips the stored name and decodes the value, again assuming
-- alphabetical order.
gen_sr_to :: Tool (APINode, SpecRecord)
gen_sr_to = mkTool $ \ ts (an, sr) -> do
    recVar <- newName "x"
    optionalInstanceD ts ''Serialise [nodeRepT an]
        [ simpleD 'encode (encoderE an sr recVar)
        , simpleD 'decode (decoderE an sr)
        ]
  where
    -- the record's fields, ordered by field name
    orderedFields sr = sortBy (comparing fst) (srFields sr)

    -- \ x -> encodeMapLen <count> <> encodeRecordFields [...]
    encoderE an sr recVar =
        lamE [varP recVar] $ appsE [varE '(<>), mapHeaderE, pairsE]
      where
        fields = orderedFields sr

        -- Micro-optimization: the number of fields is known statically,
        -- so it is spliced in as a literal instead of creating a list of
        -- thunks from the argument of @encodeRecordFields@ and
        -- dynamically calculating its length, long before the list is
        -- fully forced.
        fieldCount = toInteger (length fields)

        -- @fromIntegral (<count> :: Integer)@, converting to the 'Word'
        -- that 'encodeMapLen' takes
        fieldCountE = varE 'fromIntegral
                        `appE` sigE (litE (integerL fieldCount))
                                    (conT ''Integer)

        mapHeaderE = varE 'encodeMapLen `appE` fieldCountE

        pairsE = varE 'encodeRecordFields `appE` listE (map pairE fields)

        -- field name as a CBOR string, followed by the encoded value
        pairE (fn, _fty) =
            [e| encodeString $(fieldNameE fn)
                <> encode ($(nodeFieldE an fn) $(varE recVar)) |]

    -- decodeMapLen >> (Con <$> field <*> field <*> ...)
    decoderE an sr =
        -- TODO (extra check): check len with srFields
        appsE [varE '(>>), varE 'decodeMapLen, fieldsDecoderE]
      where
        schemaOrder = map fst (srFields sr)
        storedOrder = map fst (orderedFields sr)

        fieldsDecoderE =
            applicativeE buildE (map (const fieldDecoderE) storedOrder)

        -- TODO (correctness): check that the field's name matches the
        -- decoded name and if not, use the default value, etc.
        fieldDecoderE = [e| decodeString >> decode |]

        -- If the schema already lists the fields alphabetically, just use
        -- the data constructor, but if not, generate a reordering
        -- function like
        --   \ _foo_a _foo_b -> Con _foo_b _foo_a
        buildE
          | storedOrder == schemaOrder = nodeConE an
          | otherwise =
              lamE (map (nodeFieldP an) storedOrder)
                   (foldl' appE (nodeConE an)
                                (map (nodeFieldE an) schemaOrder))

{-
instance Serialise Foo where
    encode (Bar x) = encodeUnion "x" (encode x)
    encode (Baz x) = encodeUnion "y" (encode x)
    decode = decodeUnion [ ("x", fmap Bar decode)
                         , ("y", fmap Baz decode) ]

-}

-- | Generate a 'Serialise' instance for a union node.
--
-- @encode@ gets one clause per alternative, each passing the
-- alternative's name and its encoded payload to 'encodeUnion'.
-- @decode@ hands 'decodeUnion' a list pairing every alternative's name
-- with a decoder that wraps the payload in the matching constructor.
gen_su_to :: Tool (APINode, SpecUnion)
gen_su_to = mkTool $ \ ts (an, su) ->
    optionalInstanceD ts ''Serialise [nodeRepT an]
        [ funD    'encode (map (encodeClause an) (suFields su))
        , simpleD 'decode (decoderE an su)
        ]
  where
    -- encode (Con x) = encodeUnion "<name>" (encode x)
    encodeClause an (fn, (_, _)) = do
        payload <- newName "x"
        clause [nodeAltConP an fn [varP payload]]
               (normalB [e| encodeUnion $(fieldNameE fn)
                                        (encode $(varE payload)) |])
               []

    -- decodeUnion [ ("<name>", fmap Con decode), ... ]
    decoderE an su =
        appE (varE 'decodeUnion) $
            listE [ decodeAltE an fn | (fn, _) <- suFields su ]

    decodeAltE an fn =
        [e| ( $(fieldNameE fn) , fmap $(nodeAltConE an fn) decode ) |]


{-
instance Serialise FrameRate where
    encode = encodeString . _text_FrameRate
    decode = decodeString >>= cborStrMap_p _map_FrameRate
-}

-- | Generate a 'Serialise' instance for an enumeration node.
--
-- Values are written as CBOR strings obtained through the function
-- named by 'text_enum_nm'; decoding reads a string and looks it up,
-- via 'cborStrMap_p', in the map named by 'map_enum_nm'.  The spec
-- itself is not consulted.
gen_se_to :: Tool (APINode, SpecEnum)
gen_se_to = mkTool $ \ ts (an, _se) ->
    let toTextE  = varE (text_enum_nm an)
        lookupE  = varE (map_enum_nm an)
    in optionalInstanceD ts ''Serialise [nodeRepT an]
           [ simpleD 'encode [e| encodeString . $toTextE |]
           , simpleD 'decode [e| decodeString >>= cborStrMap_p $lookupE |]
           ]

-- | Look up the text of a decoded enumeration value in the given map.
--
-- Runs in a monad, to 'Fail.fail' instead of crashing with @error@ when
-- the text is not a key of the map.  The constraint is 'Fail.MonadFail'
-- rather than plain 'Monad' because @fail@ is not a 'Monad' method in
-- recent versions of @base@; no constraint is needed on the value type
-- since only the 'T.Text' keys are compared.
cborStrMap_p :: Fail.MonadFail m => Map.Map T.Text a -> T.Text -> m a
cborStrMap_p mp t =
    maybe (Fail.fail "Unexpected enumeration key in CBOR") return
          (Map.lookup t mp)


-- | Generate a 'Serialise' instance for the user-facing type of a node
-- that declares a conversion ('anConvert').
--
-- Nodes without a conversion produce no declarations.  Otherwise
-- @encode@ applies the projection function and encodes its result, and
-- @decode@ decodes a value and feeds it to the injection function
-- through monadic bind.
gen_pr :: Tool APINode
gen_pr = mkTool $ \ ts an ->
    let viaConversion (inj_fn, prj_fn) =
            optionalInstanceD ts ''Serialise [nodeT an]
                [ simpleD 'encode [e| encode . $(fieldNameVarE prj_fn) |]
                , simpleD 'decode [e| decode >>= $(fieldNameVarE inj_fn) |]
                ]
    in maybe (return []) viaConversion (anConvert an)