{-# LANGUAGE CPP             #-}
{-# LANGUAGE LambdaCase      #-}
{-# LANGUAGE TemplateHaskell #-}

module Data.InvertibleGrammar.TH where

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Data.Foldable (toList)
import Data.InvertibleGrammar.Base
import Data.Maybe
import Data.Text (pack)
import Language.Haskell.TH as TH
import Data.Set (Set)
import qualified Data.Set as S
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup ((<>))
#endif


{- | Build a prism and the corresponding grammar that will match on the
     given constructor and convert its fields to a stack built with ':-',
     with the fields in reverse order.

     E.g. consider a data type:

     > data FooBar a b c = Foo a b c | Bar

     For constructor Foo

     > fooGrammar = $(grammarFor 'Foo)

     will expand into (roughly)

     > fooGrammar = PartialIso
     >   (\(c :- b :- a :- t) -> Foo a b c :- t)
     >   (\case { Foo a b c :- t -> Right (c :- b :- a :- t)
     >          ; _ -> Left (expected "constructor Foo") })

     The fallback branch is omitted when the data type has a single
     constructor.

     Note the order of elements on the stack:

     > ghci> :t fooGrammar
     > fooGrammar :: Grammar p (c :- (b :- (a :- t))) (FooBar a b c :- t)
-}

grammarFor :: Name -> ExpQ
grammarFor constructorName = do
  -- Resolve the name to a data constructor and remember its parent type.
  -- Anything else is reported through 'fail', so the splice aborts with a
  -- readable message instead of an opaque pattern-match failure.
  conInfo <- reify constructorName
  (realConstructorName, parentName) <- case conInfo of
#if defined(__GLASGOW_HASKELL__)
# if __GLASGOW_HASKELL__ <= 710
    DataConI name _typ parent _fixity -> return (name, parent)
# else
    DataConI name _typ parent -> return (name, parent)
# endif
#endif
    _ -> fail $ "grammarFor: data constructor expected for name " ++ show constructorName

  parentInfo <- reify parentName
  dataDef <- case parentInfo of
    TyConI dec -> return dec
    _          -> fail $ "grammarFor: data definition expected for name " ++ show parentName

  -- Locate the constructor inside the parent declaration. 'single' tells
  -- whether the parent type has exactly one constructor, in which case the
  -- backward function needs no fallback branch.
  (single, constructorInfo) <-
    maybe
      (fail $ "grammarFor: cannot find constructor " ++ show realConstructorName
           ++ " in the definition of " ++ show parentName)
      return $ do
        (isSingle, allConstr) <- constructors dataDef
        constr <- findConstructor realConstructorName allConstr
        return (isSingle, constr)

  -- One fresh variable per constructor field, plus one for the stack tail.
  let ts = fieldTypes constructorInfo
  vs <- mapM (const $ newName "x") ts
  t <- newName "t"

  -- Forward direction: \(x1 :- x2 :- ... :- t) -> Con xn ... x1 :- t
  let fPat  = foldr (\_v _rest -> [p| $(varP _v) :- $_rest |]) (varP t) vs
      -- The fold applies the constructor to the variables in reverse order,
      -- because the last field sits on top of the stack.
      buildConstructor = foldr (\v acc -> appE acc (varE v)) (conE realConstructorName) vs
      fBody = [e| $buildConstructor :- $(varE t) |]
      fFunc = lamE [fPat] fBody

  -- Backward direction: match on the constructor and push its fields back
  -- onto the stack; any other constructor yields a 'Left' mismatch.
  let gPat  = [p| $_matchConstructor :- $(varP t) |]
        where
          _matchConstructor = conP realConstructorName (map varP (reverse vs))
      gBody = foldr (\v acc -> [e| $(varE v) :- $acc |]) (varE t) vs
      matchClause = TH.match gPat (normalB [e| Right ($gBody) |]) []
      fallbackClause = TH.match wildP (normalB [e| Left (expected $ "constructor " <> pack ( $(stringE (show constructorName))) ) |]) []
      -- A wildcard branch after the only constructor would be redundant.
      gFunc = lamCaseE (matchClause : [fallbackClause | not single])

  [e| PartialIso $fFunc $gFunc |]


{- | Build prisms and corresponding grammars for all data constructors of the
     given type. The result is a lambda taking one argument per constructor;
     each argument is applied to the grammar built for its constructor and
     the results are combined with '<>'.

     > $(match ''Maybe)

     Will expand into a lambda:

     > (\nothingG justG -> nothingG $(grammarFor 'Nothing) <>
     >                     justG    $(grammarFor 'Just))
-}
match :: Name -> ExpQ
match tyName = do
  names <- concatMap (toList . constructorNames) <$> (extractConstructors =<< reify tyName)
  -- One fresh lambda argument per constructor.
  args <- mapM (\_ -> newName "a") names
  let grammars = zipWith (\con arg -> [e| $(varE arg) $(grammarFor con) |]) names args
  case grammars of
    -- A type without constructors leaves nothing to combine; report it
    -- instead of letting 'foldr1' crash on the empty list.
    [] -> fail $ "Data type " ++ show tyName ++ " defines no constructors"
    _  -> lamE (map varP args) (foldr1 (\e1 e2 -> [e| $e1 <> $e2 |]) grammars)
  where
    -- Constructors of the reified declaration; fails unless the name refers
    -- to a declaration that 'constructors' recognises.
    extractConstructors :: Info -> Q [Con]
    extractConstructors (TyConI dataDef) =
      case constructors dataDef of
        Just (_, cs) -> pure cs
        Nothing      -> fail $ "Data definition expected for name " ++ show tyName
    extractConstructors _ =
      fail $ "Data definition expected for name " ++ show tyName

----------------------------------------------------------------------
-- Utils

-- | Constructors of a @data@ or @newtype@ declaration, paired with a flag
-- telling whether the declaration has exactly one constructor. Returns
-- 'Nothing' for any other kind of declaration.
constructors :: Dec -> Maybe (Bool, [Con])
#if defined(__GLASGOW_HASKELL__)
# if __GLASGOW_HASKELL__ <= 710
constructors (DataD _ _ _ cs _)     = Just (declaresSingleName cs, cs)
constructors (NewtypeD _ _ _ c _)   = Just (True, [c])
# else
constructors (DataD _ _ _ _ cs _)   = Just (declaresSingleName cs, cs)
constructors (NewtypeD _ _ _ _ c _) = Just (True, [c])
# endif
#endif
constructors _                      = Nothing

-- | True when the given 'Con' nodes declare exactly one constructor name.
-- Names are counted rather than nodes, because a single GADT-style node may
-- declare several constructors at once.
declaresSingleName :: [Con] -> Bool
declaresSingleName cs = sum (map (S.size . constructorNames) cs) == 1

-- | First constructor in the list that declares the given name, if any.
findConstructor :: Name -> [Con] -> Maybe Con
findConstructor wanted candidates =
  listToMaybe [ con | con <- candidates, wanted `S.member` constructorNames con ]

-- | All constructor names declared by a single 'Con' node. Quantified
-- constructors are unwrapped; GADT-style nodes contribute every name they
-- list.
constructorNames :: Con -> Set Name
constructorNames (NormalC conName _)   = S.singleton conName
constructorNames (RecC conName _)      = S.singleton conName
constructorNames (InfixC _ conName _)  = S.singleton conName
constructorNames (ForallC _ _ inner)   = constructorNames inner
#if MIN_VERSION_template_haskell(2, 11, 0)
constructorNames (GadtC conNames _ _)    = S.fromList conNames
constructorNames (RecGadtC conNames _ _) = S.fromList conNames
#endif

-- | Types of the fields of a constructor, in declaration order. Strictness
-- annotations and record field names are dropped; quantified constructors
-- are unwrapped.
fieldTypes :: Con -> [Type]
fieldTypes con = case con of
  NormalC _ args                    -> map snd args
  RecC _ fields                     -> map recFieldType fields
  InfixC (_, lhsTy) _op (_, rhsTy)  -> [lhsTy, rhsTy]
  ForallC _ _ inner                 -> fieldTypes inner
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
  GadtC _ args _                    -> map snd args
  RecGadtC _ fields _               -> map recFieldType fields
#endif
  where
    -- Record fields carry (name, strictness, type); keep only the type.
    recFieldType (_, _, ty) = ty