{-# LANGUAGE TemplateHaskell #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Test.StateMachine.Types.HFunctor.TH
-- Copyright   :  (C) 2017, ATS Advanced Telematic Systems GmbH, Li-yao Xia
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Li-yao Xia <lysxia@gmail.com>
-- Stability   :  provisional
-- Portability :  non-portable (GHC extensions)
--
-- Template Haskell functions to derive higher-order structures.
--
-----------------------------------------------------------------------------

module Test.StateMachine.Types.HFunctor.TH
  ( deriveHClasses
  , deriveHTraversable
  , mkhtraverse
  , deriveHFoldable
  , mkhfoldMap
  , deriveHFunctor
  , mkhfmap
  ) where

import           Control.Applicative
                   (liftA3)
import           Control.Monad
                   (when, (>=>))
import           Data.Foldable
                   (foldl')
import           Data.Monoid
                   (mempty, (<>))
import qualified Data.Set                         as Set
import           Data.Traversable
                   (for)
import           Language.Haskell.TH
import           Language.Haskell.TH.Datatype

import           Test.StateMachine.Internal.Utils
                   (dropLast, nub, toLast)
import           Test.StateMachine.Types.HFunctor

-- | Derive 'HFunctor', 'HFoldable', 'HTraversable'.
deriveHClasses :: Name -> Q [Dec]
deriveHClasses =
  (liftA3 . liftA3) (\a b c -> a ++ b ++ c)
    deriveHFunctor
    deriveHFoldable
    deriveHTraversable

-- |
-- @
-- 'deriveHTraversable' ''Action
-- ===>
-- instance 'HTraversable' Action where ...
-- @
deriveHTraversable :: Name -> Q [Dec]
deriveHTraversable = reifying deriveIFor dictHTraversable

-- | Derive the body of 'htraverse'.
mkhtraverse :: Name -> Q Exp
mkhtraverse = reifying mkFFor dictHTraversable

-- |
-- @
-- 'deriveHFoldable' ''Action
-- ===>
-- instance 'HFoldable' Action where ...
-- @
deriveHFoldable :: Name -> Q [Dec]
deriveHFoldable = reifying deriveIFor dictHFoldable

-- | Derive the body of 'hfoldMap'.
mkhfoldMap :: Name -> Q Exp
mkhfoldMap = reifying mkFFor dictHFoldable

-- |
-- @
-- 'deriveHFunctor' ''Action
-- ===>
-- instance 'HFunctor' Action where ...
-- @
deriveHFunctor :: Name -> Q [Dec]
deriveHFunctor = reifying deriveIFor dictHFunctor

-- | Derive the body of 'hfmap'.
mkhfmap :: Name -> Q Exp
mkhfmap = reifying mkFFor dictHFunctor

data Dictionary = Dictionary
  { className :: Name
  , funName   :: Name
  , pureE     :: Exp -> Exp
  , apE       :: Exp -> Exp -> Exp
  }

dictHFunctor :: Dictionary
dictHFunctor = Dictionary
  { className = ''HFunctor
  , funName = 'hfmap
  , pureE = id
  , apE = AppE
  }

dictHFoldable :: Dictionary
dictHFoldable = Dictionary
  { className = ''HFoldable
  , funName = 'hfoldMap
  , pureE = const (VarE 'mempty)
  , apE = apE'
  } where
    -- mempty <> e = e
    -- e <> mempty = e
    apE' (VarE m) e | m == 'mempty = e
    apE' e (VarE m) | m == 'mempty = e
    apE' e1 e2      = infixE_ e1 '(<>) e2

dictHTraversable :: Dictionary
dictHTraversable = Dictionary
  { className = ''HTraversable
  , funName = 'htraverse
  , pureE = AppE (VarE 'pure)
  , apE = apE'
  } where
    -- pure f <*> v = f <$> v
    apE' (AppE (VarE pure_) f) v | pure_ == 'pure = infixE_ f '(<$>) v
    apE' u v                     = infixE_ u '(<*>) v

reifying :: (Dictionary -> DatatypeInfo -> Q r) -> Dictionary -> Name -> Q r
reifying derive dict = reifyDatatype >=> derive dict

deriveIFor :: Dictionary -> DatatypeInfo -> Q [Dec]
deriveIFor dict info = fmap (: []) $ do
  when (length (datatypeVars info) < 2)
    (fail $ "Type " ++ show (datatypeName info) ++ " should have arity >= 2")
  (cxt_, htraversalDec) <- htraversalWithCxtFor dict info
  let instanceHead = AppT
        (ConT (className dict))
        (foldl' AppT (ConT (datatypeName info)) (dropLast 2 (datatypeVars info)))
  return
    (InstanceD Nothing cxt_ instanceHead [htraversalDec])

mkFFor :: Dictionary -> DatatypeInfo -> Q Exp
mkFFor dict info =
  fmap mkF (htraversalBodyFor dict info)
  where
    mkF (_, pats, body) = LamE pats body

htraversalWithCxtFor :: Dictionary -> DatatypeInfo -> Q (Cxt, Dec)
htraversalWithCxtFor dict info =
  fmap mkFunD (htraversalBodyFor dict info)
  where
    mkFunD (cxt_, pats, body) =
      (cxt_, FunD (funName dict) [Clause pats (NormalB body) []])

htraversalBodyFor :: Dictionary -> DatatypeInfo -> Q (Cxt, [Pat], Exp)
htraversalBodyFor dict info = do
  fN <- newName "f"
  aN <- newName "a"
  let SigT v _ = toLast 1 (datatypeVars info)
  tucs <- traverse (htraversalMatchFor dict v (VarE fN)) (datatypeCons info)
  let (ts, usedF', matches) = unzip3 tucs
      usedF = or usedF'
      fP = if usedF then VarP fN else WildP
      pats = [fP, VarP aN]
      cxt_ = fmap (AppT (ConT (className dict))) (nub (concat ts))
  return (cxt_, pats, CaseE (VarE aN) matches)

htraversalMatchFor :: Dictionary -> Type -> Exp -> ConstructorInfo -> Q ([Type], Bool, Match)
htraversalMatchFor dict v f info = do
  xts <- for (constructorFields info) (\t ->  fmap (\x -> (x, t)) (newName "x"))
  cyfs <- for xts (uncurry (htraversalFieldFor dict v f))
  let conPattern = ConP (constructorName info) [mkVarP x | (x, _) <- xts]
      -- HFoldable instances may have unused fields, replaced with wildcards.
      mkVarP x | className dict == ''HFoldable && x `Set.member` ys = WildP
               | otherwise = VarP x
      c = ConE (constructorName info)
      (cnstrnts', ys', fields) = unzip3 cyfs
      -- f gets used if at least one field did not use pure
      usedF = any null ys'
      cnstrnts = concat cnstrnts'
      ys = Set.fromList (concat ys')
      body = foldl' (apE dict) (pureE dict c) fields
  return
    (cnstrnts, usedF, Match conPattern (NormalB body) [])

infixE_ :: Exp -> Name -> Exp -> Exp
infixE_ x (+.) y = InfixE (Just x) (VarE (+.)) (Just y)

htraversalFieldFor :: Dictionary -> Type -> Exp -> Name -> Type -> Q ([Type], [Name], Exp)
htraversalFieldFor dict v f x' t' = do
  let x = VarE x'
  t <- resolveTypeSynonyms t'
  return $ case t of
    AppT (AppT u v') _ | v == v' ->
      ( [u | variableHead u]
      , []
      , VarE (funName dict) `AppE` f `AppE` x)
    AppT v' _ | v == v' ->
      ([], [], f `AppE` x)
    _ ->
      ([], [x'], pureE dict x)

variableHead :: Type -> Bool
variableHead (AppT u _) = variableHead u
variableHead (VarT _)   = True
variableHead _          = False