-- | Translation of GHC-compiled Haskell (via Core and STG) into SSTG syntax.
module SSTG.Core.Translation.Haskell
    ( mkCompileClosure
    , mkTargetBindings
    , mkIOStr
    ) where

import qualified SSTG.Core.Syntax.Language as SL

import Coercion
import CorePrep
import CoreSyn
import CoreToStg
import DataCon
import FastString
import GHC
import GHC.Paths
import HscTypes
import Literal
import Name
import Outputable
import Pair
import PrimOp
import StgSyn
import TyCon
import TyCoRep
import Unique
import Var as V

import qualified Data.Maybe  as MB

-- | Render any 'Outputable' value to a 'String', pretty-printing it with the
-- 'DynFlags' of a fresh GHC session rooted at 'libdir'.
mkIOStr :: (Outputable a) => a -> IO String
mkIOStr obj = runGhc (Just libdir) (fmap render getSessionDynFlags)
  where render dflags = showPpr dflags obj

-- | Given the project directory and the source file path, compiles the
-- module graph and translates every module in it into SSTG 'SL.Binding's.
--
-- Each module's Core is run through 'corePrepPgm' and then 'coreToStg';
-- the resulting STG bindings of all modules are converted with 'mkBinding'
-- and concatenated in module-graph order.
mkTargetBindings :: FilePath -> FilePath -> IO [SL.Binding]
mkTargetBindings proj src = do
    (pairs, dflags, env) <- mkCompileClosure proj src
    let (summaries, gutss) = unzip pairs
        mods = map ms_mod summaries
        locs = map ms_location summaries
        -- Core prep needs the module, its location, its Core bindings and
        -- the type constructors it defines.
        prep m l g = corePrepPgm env m l (mg_binds g) (mg_tcs g)
    prepped <- sequence (zipWith3 prep mods locs gutss)
    stgss   <- sequence (zipWith (coreToStg dflags) mods prepped)
    return (concatMap (map mkBinding) stgss)

-- | Captures a snapshot of the 'DynFlags' and 'HscEnv' in addition to the
-- 'ModGuts' in the 'ModuleGraph'. This allows compilation to be, in theory,
-- more portable across different applications, since 'ModGuts' is a crucial
-- intermediary for compilation in general.
--
-- The session uses the project directory as its only import path and loads
-- the source file as the sole target. Every module in the resulting graph is
-- then parsed, typechecked and desugared to obtain its 'ModGuts', which is
-- paired with its 'ModSummary'.
mkCompileClosure :: FilePath -> FilePath ->
                    IO ([(ModSummary, ModGuts)], DynFlags, HscEnv)
mkCompileClosure proj src = runGhc (Just libdir) $ do
    beta_flags <- getSessionDynFlags
    _          <- setSessionDynFlags beta_flags { importPaths = [proj] }
    target     <- guessTarget src Nothing
    setTargets [target]
    _          <- load LoadAllTargets
    -- Now that things are loaded, make the compilation closure.
    mod_graph  <- getModuleGraph
    pmods      <- mapM parseModule mod_graph
    tmods      <- mapM typecheckModule pmods
    dmods      <- mapM desugarModule tmods
    -- Snapshot the session only after loading, so the returned 'HscEnv'
    -- reflects the loaded modules rather than the pre-load session. Take the
    -- 'DynFlags' from that same session so both halves of the snapshot agree
    -- ('setSessionDynFlags' may adjust the flags it is given).
    env        <- getSession
    let dflags = hsc_dflags env
    return (zip mod_graph (map coreModule dmods), dflags, env)

-- | Translate an STG expression into an SSTG 'SL.Expr'.
--
-- Ticks are dropped. Let-no-escape bindings are translated exactly like
-- ordinary lets. STG lambdas are not supported and raise an error.
mkExpr :: StgExpr -> SL.Expr
mkExpr expr = case expr of
    StgLit lit                  -> SL.Atom (SL.LitAtom (mkLit lit))
    StgApp fun args             -> SL.FunApp (mkVar fun) (atoms args)
    StgConApp dc args           -> SL.ConApp (mkData dc) (atoms args)
    StgOpApp op args _          -> SL.PrimApp (mkPrimOp op) (atoms args)
    StgTick _ body              -> mkExpr body
    StgLam _ _                  -> error "mkExpr: StgLam detected"
    StgLet bnd body             -> letExpr bnd body
    StgLetNoEscape _ _ bnd body -> letExpr bnd body
    StgCase scrut _ _ bndr _ _ alts ->
        SL.Case (mkExpr scrut) (mkVar bndr) (map mkAlt alts)
  where
    atoms = map mkAtom
    letExpr bnd body = SL.Let (mkBinding bnd) (mkExpr body)

-- | Translate an STG argument (a variable or a literal) into an SSTG
-- 'SL.Atom'.
mkAtom :: StgArg -> SL.Atom
mkAtom arg = case arg of
    StgVarArg v -> SL.VarAtom (mkVar v)
    StgLitArg l -> SL.LitAtom (mkLit l)

-- | Translate a GHC 'Name' into an SSTG 'SL.Name'. The result keeps the
-- occurrence string, the name of the defining module (if any), the
-- namespace and the unique key.
mkName :: Name -> SL.Name
mkName name = SL.Name occ mdl ns unq
  where occname = nameOccName name
        occ     = occNameString occname
        ns      = mkNameSpace (occNameSpace occname)
        unq     = getKey (nameUnique name)
        -- 'Nothing' whenever GHC reports no defining module for the name.
        mdl     = fmap (moduleNameString . moduleName) (nameModule_maybe name)

-- | Translate a GHC 'NameSpace' into an SSTG 'SL.NameSpace'.
--
-- The type-variable test must come before the variable test. GHC's
-- 'isVarNameSpace' also returns 'True' for the type-variable namespace, so
-- testing it first would classify every type variable as 'SL.VarNSpace' and
-- leave the 'SL.TvNSpace' case unreachable.
mkNameSpace :: NameSpace -> SL.NameSpace
mkNameSpace ns | isTvNameSpace  ns     = SL.TvNSpace
               | isVarNameSpace ns     = SL.VarNSpace
               | isDataConNameSpace ns = SL.DataNSpace
               | isTcClsNameSpace ns   = SL.TcClsNSpace
               | otherwise = error "mkNameSpace: unrecognized namespace"

-- | Translate a GHC 'Var' into an SSTG 'SL.Var' carrying its translated
-- name and type.
mkVar :: Var -> SL.Var
mkVar var = SL.Var (mkName (V.varName var)) (mkType (varType var))

-- | Translate an STG binding group into an SSTG 'SL.Binding'. Whether the
-- group is recursive is preserved, and each binder is paired with its
-- translated right-hand side.
mkBinding :: StgBinding -> SL.Binding
mkBinding bnd = case bnd of
    StgNonRec b r -> SL.Binding SL.NonRec [pair b r]
    StgRec bs     -> SL.Binding SL.Rec [pair b r | (b, r) <- bs]
  where pair b r = (mkVar b, mkRhs r)

-- | Translate an STG right-hand side into an SSTG 'SL.BindRhs'.
-- Constructor applications become 'SL.ConForm'. Closures become
-- 'SL.FunForm' over their parameters and body; their other fields are
-- ignored.
mkRhs :: StgRhs -> SL.BindRhs
mkRhs rhs = case rhs of
    StgRhsCon _ dc args ->
        SL.ConForm (mkData dc) (map mkAtom args)
    StgRhsClosure _ _ _ _ _ params body ->
        SL.FunForm (map mkVar params) (mkExpr body)

-- | Translate a GHC 'Literal' into an SSTG 'SL.Lit', tagged with the
-- translated type of the literal.
--
-- 64-bit machine literals are folded into 'SL.MachInt' / 'SL.MachWord', and
-- 'LitInteger' is represented as 'SL.MachInt'. All of these go through
-- 'fromInteger'. NOTE(review): if the SSTG integer fields are bounded types,
-- out-of-range values will wrap; confirm the intended widths.
mkLit :: Literal -> SL.Lit
mkLit lit = case lit of
  MachChar chr      -> SL.MachChar chr ty
  -- The literal's bytes are UTF-8 encoded. Decode them to the actual string
  -- contents, as the 'MachLabel' case below does via 'unpackFS', rather than
  -- using 'show', which wraps the text in quotes and escapes it.
  MachStr bstr      -> SL.MachStr (unpackFS (mkFastStringByteString bstr)) ty
  MachInt i         -> SL.MachInt (fromInteger i) ty
  MachInt64 i       -> SL.MachInt (fromInteger i) ty
  MachWord i        -> SL.MachWord (fromInteger i) ty
  MachWord64 i      -> SL.MachWord (fromInteger i) ty
  MachFloat rat     -> SL.MachFloat rat ty
  MachDouble rat    -> SL.MachDouble rat ty
  LitInteger i _    -> SL.MachInt (fromInteger i) ty
  MachNullAddr      -> SL.MachNullAddr ty
  MachLabel f m _   -> SL.MachLabel (unpackFS f) m ty
  where ty = (mkType . literalType) lit

-- | Build the SSTG 'SL.ConTag' of a data constructor from its translated
-- name and its GHC constructor tag.
mkDataTag :: DataCon -> SL.ConTag
mkDataTag dc = SL.ConTag (mkName (dataConName dc)) (dataConTag dc)

-- | Translate a GHC 'DataCon' into an SSTG 'SL.DataCon'. The result holds
-- the constructor's tag, its representation type and the translated types
-- of its original arguments.
mkData :: DataCon -> SL.DataCon
mkData dc = SL.DataCon (mkDataTag dc) repTy argTys
  where repTy  = mkType (dataConRepType dc)
        argTys = map mkType (dataConOrigArgTys dc)

-- | Translate an STG operator into an SSTG 'SL.PrimFun'. Only primitive
-- operations are supported; their name has no module and uses the primop's
-- tag as its unique. Any other operator raises an error.
mkPrimOp :: StgOp -> SL.PrimFun
mkPrimOp op = case op of
    StgPrimOp prim -> SL.PrimFun (primName prim) (mkType (primOpType prim))
    _              -> error "mkPrimOp: got StgPrimCallOp or StgFCallOp"
  where
    primName prim =
        let occname = primOpOcc prim
        in SL.Name (occNameString occname)
                   Nothing
                   (mkNameSpace (occNameSpace occname))
                   (primOpTag prim)

-- | Translate an STG case alternative into an SSTG 'SL.Alt'. The third
-- component of the alternative tuple is ignored.
mkAlt :: StgAlt -> SL.Alt
mkAlt (con, vars, _, rhs) = SL.Alt altcon binders body
  where altcon  = mkAltCon con
        binders = map mkVar vars
        body    = mkExpr rhs

-- | Translate a case-alternative pattern head into an SSTG 'SL.AltCon'.
mkAltCon :: AltCon -> SL.AltCon
mkAltCon altcon = case altcon of
    DataAlt dc -> SL.DataAlt (mkData dc)
    LitAlt lit -> SL.LitAlt (mkLit lit)
    DEFAULT    -> SL.Default

-- | Translate a GHC 'Type' into an SSTG 'SL.Type'. Type variables carry
-- their translated name and the translated type (kind) of the variable.
mkType :: Type -> SL.Type
mkType ty = case ty of
    TyVarTy tv         -> SL.TyVarTy (mkName (V.varName tv))
                                     (mkType (varType tv))
    AppTy fun arg      -> SL.AppTy (mkType fun) (mkType arg)
    TyConApp tc args   -> SL.TyConApp (mkTyCon tc) (map mkType args)
    ForAllTy bndr body -> SL.ForAllTy (mkTyBndr bndr) (mkType body)
    LitTy tylit        -> SL.LitTy (mkTyLit tylit)
    CastTy inner co    -> SL.CastTy (mkType inner) (mkCoercion co)
    CoercionTy co      -> SL.CoercionTy (mkCoercion co)

-- | Translate a GHC 'TyCon' into an SSTG 'SL.TyCon' by dispatching on the
-- sort of type constructor. Algebraic type constructors also carry their
-- translated right-hand side. Promoted data constructors carry the
-- translated data constructor. Anything else raises an error.
mkTyCon :: TyCon -> SL.TyCon
mkTyCon tc | isFunTyCon         tc = SL.FunTyCon     name
           | isAlgTyCon         tc = SL.AlgTyCon     name algrhs
           | isFamilyTyCon      tc = SL.FamilyTyCon  name
           | isPrimTyCon        tc = SL.PrimTyCon    name
           | isTcTyCon          tc = SL.TcTyCon      name
           | isTypeSynonymTyCon tc = SL.SynonymTyCon name
           -- Bind the promoted data constructor directly instead of testing
           -- 'isPromotedDataCon' and then unwrapping with a partial fromJust.
           | Just dc <- isPromotedDataCon_maybe tc
                                   = SL.Promoted     name (mkData dc)
           | otherwise = error "mkTyCon: unrecognized TyCon"
  where name   = (mkName . tyConName) tc
        -- Only evaluated in the 'isAlgTyCon' branch, where it is defined.
        algrhs = (mkAlgTyConRhs . algTyConRhs) tc

-- | Translate the right-hand side of an algebraic type constructor into an
-- SSTG 'SL.AlgTyRhs', keeping only the tags of its data constructors.
mkAlgTyConRhs :: AlgTyConRhs -> SL.AlgTyRhs
mkAlgTyConRhs rhs = case rhs of
    AbstractTyCon b               -> SL.AbstractTyCon b
    DataTyCon  { data_cons = cs } -> SL.DataTyCon (map mkDataTag cs)
    TupleTyCon { data_con = c }   -> SL.TupleTyCon (mkDataTag c)
    NewTyCon   { data_con = c }   -> SL.NewTyCon (mkDataTag c)

-- | Translate a GHC 'TyBinder' into an SSTG 'SL.TyBinder'. Anonymous
-- binders keep their type. Named binders keep the variable's name and type;
-- their visibility flag is dropped.
mkTyBndr :: TyBinder -> SL.TyBinder
mkTyBndr bndr = case bndr of
    Anon ty   -> SL.AnonTyBndr (mkType ty)
    Named v _ -> SL.NamedTyBndr (mkName (V.varName v)) (mkType (varType v))

-- | Translate a type-level literal (numeric or string) into an SSTG
-- 'SL.TyLit'.
mkTyLit :: TyLit -> SL.TyLit
mkTyLit tylit = case tylit of
    NumTyLit n  -> SL.NumTyLit (fromInteger n)
    StrTyLit fs -> SL.StrTyLit (unpackFS fs)

-- | Translate a GHC 'Coercion' into an SSTG 'SL.Coercion' recording the two
-- types it relates, as reported by 'coercionKind'.
mkCoercion :: Coercion -> SL.Coercion
mkCoercion coer = SL.Coercion (mkType lhs) (mkType rhs)
  where Pair lhs rhs = coercionKind coer