{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE CPP #-}

-- | This module is used to perform a dependency analysis of top-level
-- function definitions, i.e. to find out which definitions are
-- (mutually) recursive. To this end, this module also provides
-- functions to compute bound variables and variable occurrences.

module Rattus.Plugin.Dependency (dependency, HasBV (..)) where


import GhcPlugins
import Bag


#if __GLASGOW_HASKELL__ >= 810
import GHC.Hs.Extension
import GHC.Hs.Expr
import GHC.Hs.Pat
import GHC.Hs.Binds
import GHC.Hs.Types
#else 
import HsExtension
import HsExpr
import HsPat
import HsBinds
import HsTypes
#endif

import Data.Set (Set)
import qualified Data.Set as Set
import Data.Graph
import Data.Maybe
import Data.Either
import Prelude hiding ((<>))





-- | Compute the dependencies of a bag of bindings, returning a list
-- of the strongly-connected components.
--
-- Every element of a component pairs a binding with the set of
-- variables it binds (as computed by 'getBV'). Bindings that bind no
-- variable at all are returned first, each as its own acyclic
-- component; the remaining components are computed by
-- 'stronglyConnComp' on a graph whose nodes are keyed by the names of
-- the bound variables and whose edges come from 'getFV'.
dependency :: Bag (LHsBindLR GhcTc GhcTc) -> [SCC (LHsBindLR GhcTc GhcTc, Set Var)]
dependency binds = map AcyclicSCC noDeps ++ catMaybes (map filterJust (stronglyConnComp (concat deps)))
  where (deps,noDeps) = partitionEithers $ map mkDep $ bagToList binds
        -- Turn a binding into graph nodes. A binding with several
        -- binders gets one node per binder, but only the node of the
        -- first binder carries the binding itself (so that the binding
        -- appears at most once in the result). Each of the other nodes
        -- depends solely on that first node: this way any cycle that
        -- runs through one of the secondary binders necessarily also
        -- contains the node that carries the binding, and the binding
        -- is reported as part of that cycle.
        mkDep :: GenLocated l (HsBindLR GhcTc GhcTc) ->
                 Either [(Maybe (GenLocated l (HsBindLR GhcTc GhcTc), Set Var), Name, [Name])]
                 (GenLocated l (HsBindLR GhcTc GhcTc), Set Var)
        mkDep b =
          let dep = map varName $ Set.toList (getFV b)
              vars = getBV b in
          case Set.toList vars of
            (v:vs) -> Left ((Just (b,vars), varName v , dep) : map (\ v' -> (Nothing, varName v' , [varName v])) vs)
            [] -> Right (b,vars)
        -- Drop the payload-less nodes again. An acyclic component
        -- consisting of such a node arises for every secondary binder
        -- of a binding that is not part of a cycle.
        filterJust (AcyclicSCC Nothing) = Nothing
        filterJust (AcyclicSCC (Just b)) = Just (AcyclicSCC b)
        filterJust (CyclicSCC bs) = Just (CyclicSCC (catMaybes bs))


-- printBinds (AcyclicSCC bind) = liftIO (putStr "acyclic bind: ") >> printBind (fst bind) >> liftIO (putStrLn "") 
-- printBinds (CyclicSCC binds) = liftIO (putStr "cyclic binds: ") >> mapM_ (printBind . fst) binds >> liftIO (putStrLn "") 


-- printBind (L _ FunBind{fun_id = L _ name}) = 
--   liftIO $ putStr $ (getOccString name ++ " ")
-- printBind (L _ (AbsBinds {abs_exports = exp})) = 
--   mapM_ (\ e -> liftIO $ putStr $ ((getOccString $ abe_poly e)  ++ " ")) exp
-- printBind (L _ (VarBind {var_id = name})) =   liftIO $ putStr $ (getOccString name ++ " ")
-- printBind _ = return ()

-- | Syntax that may bind variables.

class HasBV a where
  -- | Computes the variables that are bound by a given piece of syntax.
  getBV :: a -> Set Var

-- | Binders of a typechecked binding: the identifier of a function or
-- variable binding, the 'abe_poly' identifiers of the exports of an
-- 'AbsBinds', and whatever the left-hand side pattern of a pattern
-- binding binds. Pattern synonym bindings and the extension
-- constructor yield the empty set.
instance HasBV (HsBindLR GhcTc GhcTc) where
  getBV bind = case bind of
    FunBind {fun_id = L _ f} -> Set.singleton f
    VarBind {var_id = x} -> Set.singleton x
    PatBind {pat_lhs = lhs} -> getBV lhs
    AbsBinds {abs_exports = exports} -> Set.fromList [abe_poly ex | ex <- exports]
    PatSynBind {} -> Set.empty
    XHsBindsLR {} -> Set.empty

-- | A located item binds exactly what the wrapped item binds; the
-- location is ignored.
instance HasBV a => HasBV (GenLocated b a) where
  getBV located = case located of
    L _ inner -> getBV inner

-- | A list binds the union of what its elements bind.
instance HasBV a => HasBV [a] where
  -- 'Set.unions' replaces a lazy 'foldl' so that no chain of
  -- unevaluated unions is accumulated over long lists.
  getBV ps = Set.unions (map getBV ps)



-- | Variables bound by the argument patterns of a constructor
-- pattern, for the prefix, infix and record forms of the arguments.
getConBV :: HsConPatDetails GhcTc -> Set Var
getConBV (PrefixCon ps) = getBV ps
getConBV (InfixCon p p') = getBV p `Set.union` getBV p'
getConBV (RecCon (HsRecFields {rec_flds = fs})) = Set.unions (map fieldBV fs)
      where -- binders of the pattern stored in a single record field
            fieldBV (L _ f) = getBV (hsRecFieldArg f)

-- | Variables bound by a typechecked pattern. The result is collected
-- structurally: patterns that carry an identifier contribute it, and
-- patterns with sub-patterns contribute what those sub-patterns bind.
instance HasBV (Pat GhcTc) where
  getBV pat = case pat of
    -- Patterns that bind nothing.
    WildPat {} -> Set.empty
    LitPat {} -> Set.empty
    NPat {} -> Set.empty

    -- Patterns that carry an identifier of their own.
    VarPat _ (L _ x) -> Set.singleton x
    NPlusKPat _ (L _ x) _ _ _ _ -> Set.singleton x
    AsPat _ (L _ x) inner -> Set.insert x (getBV inner)
    -- Note: the identifier stored in 'ConPatIn' is included in the result.
    ConPatIn (L _ c) args -> Set.insert c (getConBV args)

    -- Patterns that merely wrap a single sub-pattern.
    LazyPat _ inner -> getBV inner
    ParPat _ inner -> getBV inner
    BangPat _ inner -> getBV inner
    ViewPat _ _ inner -> getBV inner
    CoPat _ _ inner _ -> getBV inner
    SumPat _ inner _ _ -> getBV inner
    XPat inner -> getBV inner
#if __GLASGOW_HASKELL__ >= 808
    SigPat _ inner _ -> getBV inner
#else
    SigPat _ inner -> getBV inner
#endif

    -- Patterns with several sub-patterns.
    ListPat _ elems -> getBV elems
    TuplePat _ elems _ -> getBV elems
    ConPatOut {pat_args = args} -> getConBV args

    -- Splices: the identifiers stored in the splice are returned; an
    -- already spliced-in pattern is traversed.
    SplicePat _ splice -> case splice of
      HsTypedSplice _ _ name _ -> Set.singleton name
      HsUntypedSplice _ _ name _ -> Set.singleton name
      HsQuasiQuote _ name quoter _ _ -> Set.fromList [name, quoter]
      HsSpliced _ _ (HsSplicedPat inner) -> getBV inner
      _ -> Set.empty

-- The extension placeholder type binds no variables. Its name depends
-- on the GHC version.
#if __GLASGOW_HASKELL__ >= 810
instance HasBV NoExtCon where
#else
instance HasBV NoExt where
#endif
  getBV = const Set.empty


-- | Syntax that may contain variables.
class HasFV a where
  -- | Compute the set of variables occurring in the given piece of
  -- syntax.  The name falsely suggests that it returns free variables,
  -- but in fact it returns variable occurrences no matter whether
  -- they are free or bound.
  getFV :: a -> Set Var

-- | Variable occurrences of a located item are those of the wrapped
-- item; the location is ignored.
instance HasFV a => HasFV (GenLocated b a) where
  getFV located = case located of
    L _ inner -> getFV inner

-- | Union of the variable occurrences of all list elements.
instance HasFV a => HasFV [a] where
  getFV = Set.unions . map getFV

-- | Union of the variable occurrences of all elements of the bag.
instance HasFV a => HasFV (Bag a) where
  getFV bag = Set.unions (map getFV (bagToList bag))

-- | A variable is its own single occurrence.
instance HasFV Var where
  getFV = Set.singleton

-- | Variable occurrences in the alternatives of a match group.
instance HasFV a => HasFV (MatchGroup GhcTc a) where
  getFV group = case group of
    MG {mg_alts = alternatives} -> getFV alternatives
    XMatchGroup {} -> Set.empty

-- | Variable occurrences in the right-hand sides of a match. The
-- patterns of the match are not inspected.
instance HasFV a => HasFV (Match GhcTc a) where
  getFV match = case match of
    Match {m_grhss = rightHandSides} -> getFV rightHandSides
    XMatch {} -> Set.empty

-- | Variable occurrences in a tuple argument; only a present argument
-- contains an expression.
instance HasFV (HsTupArg GhcTc) where
  getFV arg = case arg of
    Present _ e -> getFV e
    _ -> Set.empty


-- | Variable occurrences in the guards and the body of a guarded
-- right-hand side.
instance HasFV a => HasFV (GRHS GhcTc a) where
  getFV grhs = case grhs of
    GRHS _ guards body -> Set.unions [getFV guards, getFV body]
    XGRHS {} -> Set.empty

-- | Variable occurrences in a group of guarded right-hand sides
-- together with their local bindings.
instance HasFV a => HasFV (GRHSs GhcTc a) where
  getFV grhss = case grhss of
    GRHSs {grhssGRHSs = alternatives, grhssLocalBinds = localBinds} ->
      Set.unions [getFV alternatives, getFV localBinds]
    _ -> Set.empty


-- | Variable occurrences in local bindings: value bindings and
-- implicit-parameter bindings are delegated to their own instances,
-- every other form yields the empty set.
instance HasFV (HsLocalBindsLR GhcTc GhcTc) where
  getFV localBinds = case localBinds of
    HsValBinds _ valBinds -> getFV valBinds
    HsIPBinds _ ipBinds -> getFV ipBinds
    _ -> Set.empty

-- | Variable occurrences in a group of value bindings.
--
-- Besides the 'ValBinds' form, the dependency-analysed form
-- 'NValBinds' (stored in the extension constructor 'XValBindsLR') is
-- traversed as well. GHC uses the latter form for value bindings
-- after renaming, so without this equation the bindings of
-- typechecked @let@ and @where@ blocks would contribute nothing.
instance HasFV (HsValBindsLR GhcTc GhcTc) where
  getFV (ValBinds _ b _) = getFV b
  -- each group is a pair of a recursion flag and a bag of bindings;
  -- only the bindings matter here
  getFV (XValBindsLR (NValBinds bs _)) = foldMap (getFV . snd) bs

-- | Variable occurrences in the defining part of a binding: the
-- matches of a function binding, the right-hand side of a pattern or
-- variable binding, and the inner bindings of an 'AbsBinds'. All other
-- bindings yield the empty set.
instance HasFV (HsBindLR GhcTc GhcTc) where
  getFV bind = case bind of
    FunBind {fun_matches = matches} -> getFV matches
    PatBind {pat_rhs = body} -> getFV body
    VarBind {var_rhs = body} -> getFV body
    AbsBinds {abs_binds = inner} -> getFV inner
    _ -> Set.empty


-- | Variable occurrences in the expression of an implicit-parameter
-- binding.
instance HasFV (IPBind GhcTc) where
  getFV ipBind = case ipBind of
    IPBind _ _ body -> getFV body
    _ -> Set.empty

-- | Variable occurrences in a group of implicit-parameter bindings.
instance HasFV (HsIPBinds GhcTc) where
  getFV ipBinds = case ipBinds of
    IPBinds _ bindings -> getFV bindings
    _ -> Set.empty

-- | Variable occurrences in an argument of an applicative statement:
-- either a single expression, or a list of statements together with
-- a final expression. The number of fields of 'ApplicativeArgOne'
-- depends on the GHC version.
instance HasFV (ApplicativeArg GhcTc) where
  getFV arg = case arg of
#if __GLASGOW_HASKELL__ >= 810
    ApplicativeArgOne _ _ rhs _ _ -> getFV rhs
#else
    ApplicativeArgOne _ _ rhs _ -> getFV rhs
#endif
    ApplicativeArgMany _ stmts final _ -> Set.unions [getFV stmts, getFV final]
    XApplicativeArg {} -> Set.empty

-- | Variable occurrences in the statements of one branch of a
-- parallel statement.
instance HasFV (ParStmtBlock GhcTc GhcTc) where
  getFV block = case block of
    ParStmtBlock _ stmts _ _ -> getFV stmts
    XParStmtBlock {} -> Set.empty

-- | Variable occurrences in a statement (of a @do@ block, a
-- comprehension or a guard). Patterns and the syntax expressions
-- attached to statements are not inspected.
instance HasFV a => HasFV (StmtLR GhcTc GhcTc a) where
  getFV (LastStmt _ e _ _) = getFV e
  getFV (BindStmt _ _ e _ _) = getFV e
  getFV (ApplicativeStmt _ args _) = foldMap (getFV . snd) args
  getFV (BodyStmt _ e _ _) = getFV e
  getFV (LetStmt _ bs) = getFV bs
  getFV (ParStmt _ stms e _) = getFV stms `Set.union` getFV e
  -- transform statement: the statements being transformed, the
  -- @using@ expression and the optional @by@ expression
  getFV TransStmt{trS_stmts = stms, trS_using = usingE, trS_by = byE} =
    getFV stms `Set.union` getFV usingE `Set.union` foldMap getFV byE
  -- recursive statement group: the statements of the group
  getFV RecStmt{recS_stmts = stms} = getFV stms
  getFV XStmtLR{} = Set.empty

-- | Variable occurrences in the fields of a record construction.
instance HasFV (HsRecordBinds GhcTc) where
  getFV binds = case binds of
    HsRecFields {rec_flds = fields} -> getFV fields

-- | Variable occurrences in the expression assigned to a record field.
instance HasFV (HsRecField' o (LHsExpr GhcTc)) where
  getFV field = case field of
    HsRecField {hsRecFieldArg = value} -> getFV value

-- | Variable occurrences in the bounds of an arithmetic sequence.
instance HasFV (ArithSeqInfo GhcTc) where
  getFV info = case info of
    From start -> getFV start
    FromThen start next -> Set.unions [getFV start, getFV next]
    FromTo start end -> Set.unions [getFV start, getFV end]
    FromThenTo start next end -> Set.unions [getFV start, getFV next, getFV end]

-- | Variable occurrences in a Template Haskell bracket: the expression
-- of an (untyped or typed) expression bracket and the identifier of a
-- variable bracket. All other brackets yield the empty set.
instance HasFV (HsBracket GhcTc) where
  getFV (ExpBr _ e) = getFV e
  -- a typed expression bracket carries an expression just like 'ExpBr'
  getFV (TExpBr _ e) = getFV e
  getFV (VarBr _ _ e) = getFV e
  getFV _ = Set.empty

-- | Variable occurrences in an arrow command, collected from all
-- expressions, commands, match groups and local bindings that the
-- command contains.
instance HasFV (HsCmd GhcTc) where
  getFV (HsCmdArrApp _ e1 e2 _ _) = getFV e1 `Set.union` getFV e2
  getFV (HsCmdArrForm _ e _ _ cmd) = getFV e `Set.union` getFV cmd
  getFV (HsCmdApp _ e1 e2) = getFV e1 `Set.union` getFV e2
  getFV (HsCmdLam _ l) = getFV l
  getFV (HsCmdPar _ cmd) = getFV cmd
  -- the scrutinee is traversed as well, as for 'HsCase'
  getFV (HsCmdCase _ e mg) = getFV e `Set.union` getFV mg
  getFV (HsCmdIf _ _ e1 e2 e3) = getFV e1 `Set.union` getFV e2 `Set.union` getFV e3
  -- the body is traversed as well, as for 'HsLet'
  getFV (HsCmdLet _ bs cmd) = getFV bs `Set.union` getFV cmd
  getFV (HsCmdDo _ cmd) = getFV cmd
  getFV (HsCmdWrap _ _ cmd) = getFV cmd
  getFV XCmd{} = Set.empty


-- | Variable occurrences in the command wrapped by a top-level
-- command.
instance HasFV (HsCmdTop GhcTc) where
  getFV top = case top of
    HsCmdTop _ cmd -> getFV cmd
    XCmdTop {} -> Set.empty

-- | Variable occurrences in a typechecked expression. The alternatives
-- are grouped by shape: the variable leaf, leaves that contribute
-- nothing, constructors with one traversed component, and
-- constructors with several traversed components. Constructors whose
-- arity or existence depends on the GHC version are selected via CPP.
instance HasFV (HsExpr GhcTc) where
  getFV expr = case expr of
    -- The variable leaf.
    HsVar _ v -> getFV v

    -- Constructors for which no occurrences are collected.
    HsUnboundVar {} -> Set.empty
    HsConLikeOut {} -> Set.empty
    HsRecFld {} -> Set.empty
    HsOverLabel {} -> Set.empty
    HsIPVar {} -> Set.empty
    HsOverLit {} -> Set.empty
    HsLit {} -> Set.empty
    HsRnBracketOut {} -> Set.empty
    HsTcBracketOut {} -> Set.empty
    HsSpliceE {} -> Set.empty
    XExpr {} -> Set.empty

    -- Constructors with a single traversed component.
    HsLam _ mg -> getFV mg
    HsLamCase _ mg -> getFV mg
#if __GLASGOW_HASKELL__ >= 808
    HsAppType _ e _ -> getFV e
    ExprWithTySig _ e _ -> getFV e
#else
    HsAppType _ e -> getFV e
    ExprWithTySig _ e -> getFV e
#endif
    NegApp _ e _ -> getFV e
    HsPar _ e -> getFV e
    ExplicitTuple _ args _ -> getFV args
    ExplicitSum _ _ _ e -> getFV e
    HsMultiIf _ alts -> getFV alts
    HsDo _ _ stmts -> getFV stmts
    ExplicitList _ _ elems -> getFV elems
    RecordCon {rcon_flds = fields} -> getFV fields
    ArithSeq _ _ info -> getFV info
    HsSCC _ _ _ e -> getFV e
    HsCoreAnn _ _ _ e -> getFV e
    HsBracket _ br -> getFV br
    HsProc _ _ cmd -> getFV cmd
    HsStatic _ e -> getFV e
    HsTick _ _ e -> getFV e
    HsBinTick _ _ _ e -> getFV e
    HsTickPragma _ _ _ _ e -> getFV e
    HsWrap _ _ e -> getFV e

    -- Constructors with several traversed components.
    HsApp _ fun arg -> Set.unions [getFV fun, getFV arg]
    OpApp _ e1 e2 e3 -> Set.unions [getFV e1, getFV e2, getFV e3]
    SectionL _ e1 e2 -> Set.unions [getFV e1, getFV e2]
    SectionR _ e1 e2 -> Set.unions [getFV e1, getFV e2]
    HsCase _ scrut mg -> Set.unions [getFV scrut, getFV mg]
    HsIf _ _ c t e -> Set.unions [getFV c, getFV t, getFV e]
    HsLet _ binds body -> Set.unions [getFV binds, getFV body]
    RecordUpd {rupd_expr = record, rupd_flds = fields} ->
      Set.unions [getFV record, getFV fields]

#if __GLASGOW_HASKELL__ < 810
    -- Constructors that are only handled for GHC versions before 8.10.
    HsArrApp _ e1 e2 _ _ -> Set.unions [getFV e1, getFV e2]
    HsArrForm _ e _ cmd -> Set.unions [getFV e, getFV cmd]
    EWildPat {} -> Set.empty
    EAsPat _ e1 e2 -> Set.unions [getFV e1, getFV e2]
    EViewPat _ e1 e2 -> Set.unions [getFV e1, getFV e2]
    ELazyPat _ e -> getFV e
#endif