{-# LANGUAGE CPP, FlexibleContexts, FlexibleInstances, ScopedTypeVariables, TemplateHaskell #-}
module Language.Haskell.TH.TypeGraph.Free
    ( freeTypeVars
    , typeArity
    ) where

import Control.Lens hiding (Strict, cons)
import Control.Monad.State (MonadState, execStateT)
import Data.Set as Set (Set, delete, difference, empty, fromList, insert, member, union)
import Language.Haskell.TH
import Language.Haskell.TH.Desugar ({- instances -})
import Language.Haskell.TH.Syntax (Quasi(qReify))
import Language.Haskell.TH.TypeGraph.Prelude (pprint')

-- | State threaded through the 'ftv' traversal.
data St
    = St { _result :: Set Name
           -- ^ Type variables found to be free so far; this is the set
           -- that 'freeTypeVars' returns.
         , _stack :: Set Name
           -- ^ Names of type constructors that 'go_app' has already
           -- reified.  Nothing in this module removes entries, so each
           -- constructor is expanded at most once per traversal.
         } deriving Show

-- | The starting state: no free variables found, no constructors visited.
st0 :: St
st0 = St Set.empty Set.empty

-- Generate the lenses 'result' and 'stack' for the fields of 'St'.
$(makeLenses ''St)

-- | Compute the arity of a type - the number of type parameters that
-- must be applied to it in order to obtain a concrete type.  I'm not
-- quite sure I understand the relationship between this and 'freeTypeVars'.
--
-- Each 'AppT' lowers the arity of its head by one.  A named type
-- constructor is reified:
--
-- * data types and newtypes have an arity equal to their number of
--   declared parameters;
-- * a type synonym adds its parameter count to the arity of its
--   right-hand side;
-- * a primitive type constructor has the arity GHC reports for it.
--
-- Unrecognised types, 'Info' values or declarations cause an 'error'.
typeArity :: Quasi m => Type -> m Int
typeArity (ForallT _ _ typ) = typeArity typ -- Shouldn't a forall affect the arity?
typeArity (SigT typ _) = typeArity typ -- a kind signature leaves the arity unchanged
typeArity ListT = return 1
typeArity ArrowT = return 2 -- (->) takes an argument type and a result type
typeArity (TupleT n) = return n
-- NOTE(review): the kind of a type variable is not consulted; it is
-- simply assumed to take one argument.
typeArity (VarT _) = return 1
typeArity (AppT t _) = typeArity t >>= \ n -> return $ n - 1
typeArity (ConT name) = qReify name >>= infoArity
    where
      infoArity (TyConI dec) = decArity dec
      -- Primitive type constructors carry their own arity; use it rather
      -- than assuming they take no parameters.
      infoArity (PrimTyConI _ arity _) = return arity
      infoArity (FamilyI dec _) = decArity dec
      infoArity info = error $ "typeArity - unexpected: " ++ pprint' info
      decArity (DataD _ _ vs _ _) = return $ length vs
      decArity (NewtypeD _ _ vs _ _) = return $ length vs
      decArity (TySynD _ vs t) = typeArity t >>= \ n -> return $ n + length vs
      decArity (FamilyD _ _ vs _mk) = return $ {- not sure what to do with the kind mk here -} length vs
      decArity dec = error $ "decArity - unexpected: " ++ show dec
typeArity typ = error $ "typeArity - unexpected type: " ++ show typ

-- | Return the names of the type variables that are free in x.  I.e.,
-- type variables that appear in the type expression but are not bound
-- by an enclosing forall or by the type parameters of a Dec.
--
-- Runs 'ftv' starting from the empty state 'st0' and returns the set
-- accumulated in '_result'.  Type constructors met along the way are
-- looked up with 'qReify', which is why this runs in a 'Quasi' monad.
freeTypeVars :: (FreeTypeVars t, Quasi m) => t -> m (Set Name)
freeTypeVars x = do
  final <- execStateT (ftv x) st0
  return (final ^. result)

-- | This is based on the freeNamesOfTypes function from the
-- th-desugar package.
--
-- Instances traverse a piece of Template Haskell syntax and record the
-- type variables they consider free in the '_result' field of 'St'.
class FreeTypeVars t where
    -- | Traverse a value, updating the 'St' state.  'Quasi' is required
    -- because type constructors may be reified during the traversal.
    ftv :: (Quasi m, MonadState St m) => t -> m ()

-- | A list is traversed element by element, from left to right.
instance FreeTypeVars a => FreeTypeVars [a] where
    ftv = mapM_ ftv

instance FreeTypeVars Type where
    -- A forall binds its variables only within its own body and context.
    -- The body and context are therefore traversed into a fresh result
    -- set, and the bound variables are removed from that set alone.  The
    -- remainder is then merged back into the variables collected so far.
    -- Subtracting from the whole accumulated set would also discard a
    -- same-named variable found outside the forall, e.g. the first @a@ in
    -- @(a, forall a. a)@.
    ftv (ForallT tvbs cx ty) = do
      outer <- use result
      result .= Set.empty
      ftv ty
      mapM_ go_pred cx
      result %= (\inner -> Set.union outer (inner `Set.difference` bound))
        where
          bound = Set.fromList (map tvbName tvbs)
#if __GLASGOW_HASKELL__ >= 709
          go_pred typ =
              -- This looks wrong as the one below looks wrong.  Wronger maybe.
              ftv typ
#else
          go_pred (ClassP _ tys) = ftv tys
          go_pred (EqualP t1 t2) = do
            -- This looks wrong - we need to unify t1 and t2 and look
            -- at the free type variables in the resulting bindings
            ftv t1
            ftv t2
#endif
    -- A kind signature binds nothing; look through it.
    ftv (SigT ty _) = ftv ty
    -- A type variable reached here is recorded as free.
    ftv (VarT n) = result %= Set.insert n
    -- Applications and named constructors are handled by 'go_app', which
    -- reifies the head constructor to account for its parameters.
    ftv (AppT t1 t2) = go_app [t2] t1
    ftv typ@(ConT _) = go_app [] typ
    -- Any other type form is treated as contributing no variables.
    ftv _ = return ()


-- | Walk down the spine of a type application.
--
-- @params@ accumulates the arguments seen so far.  Each 'AppT' prepends
-- its argument while descending into the function position, so when the
-- head is reached @params@ is already in application order (first
-- argument first).  For @T x y@, the recursion ends at @go_app [x, y] T@.
--
-- A named constructor that has not been visited yet is marked in 'stack'
-- and reified, and its declaration is processed together with the
-- arguments.  A constructor that was already visited is not expanded
-- again, which keeps recursive types from looping.  Its arguments are
-- still traversed, though, so that variables occurring only there are not
-- lost: the @b@ in @Either (Maybe a) (Maybe b)@.  Any other head is
-- traversed together with its arguments.
go_app :: (Quasi m, MonadState St m) => [Type] -> Type -> m ()
go_app params (AppT t1 t2) = go_app (t2 : params) t1
go_app params (ConT n) = do
    seen <- use stack
    if Set.member n seen
      then mapM_ ftv params
      else do
        stack %= Set.insert n
        qReify n >>= go_info params
go_app params typ = mapM_ ftv (typ : params)
-- | Dispatch on the 'Info' obtained by reifying a type constructor.
-- Type constructor and family declarations are passed to 'go_dec' along
-- with the applied arguments.  Primitive type constructors contribute
-- nothing, and any other kind of 'Info' is an error.
go_info :: (Quasi m, MonadState St m) => [Type] -> Info -> m ()
go_info params info =
    case info of
      TyConI dec -> go_dec params dec
      FamilyI dec _insts -> go_dec params dec
      PrimTyConI {} -> return ()
      _ -> error $ "go_info - unexpected: " ++ pprint' info
-- | Process a reified type declaration applied to the arguments @params@.
-- The variables mentioned in the declaration's right-hand side are
-- collected first, then 'go_params' accounts for its type parameters.
go_dec :: (Quasi m, MonadState St m) => [Type] -> Dec -> m ()
go_dec params dec =
  case dec of
    -- A newtype is handled exactly like a data type with one constructor.
    NewtypeD cx tname tvs con supers ->
      go_dec params (DataD cx tname tvs [con] supers)
    DataD _ tname tvs _ _
      | length params > length tvs ->
          error $ "Too many arguments to " ++ show tname
    DataD _cx tname tvs cons _supers -> do
      -- Collect the variables used by the constructors.  Each type
      -- variable bound to a type parameter is then replaced by the free
      -- variables of that parameter.
      ftv cons
      go_params tname tvs params
    TySynD tname tvs typ -> do
      -- Add the free variables in the type, then subtract the ones that
      -- are bound here.
      ftv typ
      go_params tname tvs params
    -- I have a feeling this is utterly wrong.  Example, with this class:
    --
    -- class OrderKey k => OrderMap k where
    --    data Order k :: * -> *
    --    ...
    --
    -- the resulting declaration of Order is
    --
    --    FamilyD DataFam Language.Haskell.TH.Path.Order.Order [PlainTV k,PlainTV $a] (Just StarT)
    --    params=[ConT AbbrevPairID]
    --
    -- so the parameter is bound to k, and $a should be free.
    FamilyD _flavour tname tvs _mkind ->
      go_params tname tvs params
    _ ->
      error $ "go_dec - unexpected: " ++ pprint' dec ++ ", params=" ++ show params

-- | Account for the type parameters @tvs@ of the declaration @tname@
-- when it is applied to the arguments @params@, given in application
-- order.
--
-- The first @length params@ parameters are bound to arguments.  They are
-- removed from the result, and the free variables of every argument are
-- added.  All removals happen before any argument is traversed, so an
-- argument that mentions a variable with the same 'Name' as one of the
-- parameters keeps it.  Two cases where this happens:
--
-- * the field type @List a@ of @data List a = Nil | Cons a (List a)@,
--   where the type is applied to its own parameter;
-- * a declaration whose parameters are passed back in a different order.
--
-- Any remaining parameters are not applied and are handed to 'go_param'
-- with 'Nothing', which records them as free.
--
-- Calls 'error' if more arguments are supplied than there are parameters.
go_params :: (Quasi m, MonadState St m) => Name -> [TyVarBndr] -> [Type] -> m ()
go_params tname tvs params
    | length params > length tvs = error $ "Too many arguments to " ++ show tname
    | otherwise = do
        result %= (`Set.difference` (Set.fromList (map tvbName bound)))
        mapM_ ftv params
        mapM_ (`go_param` Nothing) unbound
    where
      (bound, unbound) = splitAt (length params) tvs

-- | Update the free variable set for a single type parameter.
--
-- * @Just param@: the parameter is applied to the type @param@.  The
--   variable bound here is removed first, and the free variables of
--   @param@ are added afterwards.  This order keeps a variable of
--   @param@ that has the same 'Name' as the parameter, as in the field
--   type @List a@ of @data List a = Nil | Cons a (List a)@.
--
-- * @Nothing@: the parameter is not applied to anything, so it is free.
go_param :: (Quasi m, MonadState St m) => TyVarBndr -> Maybe Type -> m ()
go_param tvb (Just param) = do
  result %= Set.delete (tvbName tvb)
  ftv param
  -- Commented-out alternative: substitute only when the variable occurs.
  -- let tv = tvbName tvb
  -- r <- use result
  -- when (Set.member tv r) (ftv param >> result %= Set.delete tv)
go_param tvb Nothing =
  -- A variable not bound to a type parameter is free.
  result %= Set.insert (tvbName tvb)

{-
instance FreeTypeVars Info where
    ftv (TyConI dec) = ftv dec

instance FreeTypeVars Dec where
    ftv dec@(DataD _ _ _ _ _ _) = ftv dec
#if __GLASGOW_HASKELL__ >= 709
    go_pred = go
#else
    go_pred (ClassP _ tys) = freeNamesOfTypes tys
    go_pred (EqualP t1 t2) = go t1 <> go t2
#endif
-}

instance FreeTypeVars Con where
    ftv (NormalC _name sts) = ftv sts
    ftv (RecC _name vsts) = ftv vsts
    ftv (InfixC st1 _ st2) = ftv [st1, st2]
    -- The variables listed by a 'ForallC' are bound by the constructor,
    -- e.g. the @a@ in @data T = forall a. Show a => T a@.  They must not
    -- be reported as free.  The fields are traversed into a fresh set, so
    -- only occurrences inside this constructor are removed.  The context
    -- is ignored.
    ftv (ForallC tvbs _cx con) = do
      outer <- use result
      result .= Set.empty
      ftv con
      result %= (\inner -> Set.union outer (inner `Set.difference` Set.fromList (map tvbName tvbs)))

-- | A constructor field: the strictness annotation is ignored.
instance FreeTypeVars (Strict, Type) where
    ftv = ftv . snd

-- | A record field: the field name and strictness annotation are ignored.
instance FreeTypeVars (Name, Strict, Type) where
    ftv (_fieldName, _strictness, fieldType) = ftv fieldType

-- | Extract the 'Name' bound by a 'TyVarBndr', ignoring any kind annotation.
tvbName :: TyVarBndr -> Name
tvbName tvb =
    case tvb of
      KindedTV name _kind -> name
      PlainTV name -> name