-- | Translation of GHC-compiled Haskell programs (via Core and STG) into SSTG.
module SSTG.Core.Translation.Haskell
    ( CompileClosure
    , mkCompileClosure
    , mkTargetBinds
    , mkIOString
    ) where

import qualified SSTG.Core.Language as SL

import Coercion
import CorePrep
import CoreSyn
import CoreToStg
import DataCon
import FastString
import GHC
import GHC.Paths
import HscTypes
import Literal
import Name
import Outputable
import Pair
import PrimOp
import StgSyn
import TyCon
import TyCoRep
import Unique
import Var as V

import qualified Data.Maybe as MB

-- | Pretty-print an 'Outputable' value to a 'String', rendering it with the
-- 'DynFlags' of a fresh GHC session started from 'libdir'.
mkIOString :: (Outputable a) => a -> IO String
mkIOString obj = runGhc (Just libdir) render
  where
    render = do
        flags <- getSessionDynFlags
        return (showPpr flags obj)

-- | Given the project directory and the source file path, compile the module
-- graph and translate every resulting STG binding into an SSTG 'SL.Bind'.
--
-- All modules are run through CorePrep first, then all through CoreToStg.
mkTargetBinds :: FilePath -> FilePath -> IO [SL.Bind]
mkTargetBinds proj src = do
    (closure, dflags, env) <- mkCompileClosure proj src
    let mods = map (ms_mod . fst) closure
        -- CorePrep needs each module, its location, its Core program and the
        -- type constructors it defines.
        prepJobs = [ corePrepPgm env (ms_mod summ) (ms_location summ)
                                 (mg_binds guts) (mg_tcs guts)
                   | (summ, guts) <- closure ]
    prepped <- sequence prepJobs
    stgPrograms <- sequence (zipWith (coreToStg dflags) mods prepped)
    return (map mkBind (concat stgPrograms))

-- | Result of 'mkCompileClosure': every module's 'ModSummary' paired with its
-- desugared 'ModGuts', plus the 'DynFlags' and 'HscEnv' that were captured
-- alongside them.
type CompileClosure =
    ( [(ModSummary, ModGuts)]
    , DynFlags
    , HscEnv
    )

-- | Compile @src@ (with @proj@ as the import path) and capture a snapshot of
-- the session: each module's 'ModSummary' paired with its desugared
-- 'ModGuts', together with the session's 'DynFlags' and 'HscEnv'. Because
-- 'ModGuts' is the common input to later phases (CorePrep, CoreToStg), this
-- lets those phases run outside of the 'Ghc' monad.
--
-- The 'DynFlags' and 'HscEnv' are read back from the session only after
-- loading and desugaring, so that:
--
--  * the returned flags are the ones the session actually holds. NOTE: the
--    GHC API's 'setSessionDynFlags' may adjust the flags it is given (e.g.
--    package initialisation), so the locally built value is not used.
--  * the returned 'HscEnv' reflects the session after 'load', rather than
--    the pre-load session that predates the target modules.
mkCompileClosure :: FilePath -> FilePath -> IO CompileClosure
mkCompileClosure proj src = runGhc (Just libdir) $ do
    beta_flags <- getSessionDynFlags
    let init_flags = beta_flags { importPaths = [proj] }
    _ <- setSessionDynFlags init_flags
    target <- guessTarget src Nothing
    _ <- setTargets [target]
    -- NOTE(review): the SuccessFlag returned by load is discarded here.
    _ <- load LoadAllTargets
    -- Now that things are loaded, make the compilation closure.
    mod_graph <- getModuleGraph
    pmods <- mapM parseModule mod_graph
    tmods <- mapM typecheckModule pmods
    dmods <- mapM desugarModule tmods
    let mod_gutss = map coreModule dmods
    -- Snapshot the session last, so it matches the modules captured above.
    dflags <- getSessionDynFlags
    env <- getSession
    return (zip mod_graph mod_gutss, dflags, env)

-- | Translate an STG expression into an SSTG 'SL.Expr'.
--
-- Ticks are dropped. Let-no-escape bindings are translated as ordinary lets.
-- An 'StgLam' is not supported and raises an error.
mkExpr :: StgExpr -> SL.Expr
mkExpr expr = case expr of
    StgLit lit -> SL.Atom (SL.LitAtom (mkLit lit))
    StgApp fun args -> SL.FunApp (mkVar fun) (atoms args)
    StgConApp dcon args -> SL.ConApp (mkData dcon) (atoms args)
    StgOpApp op args _ -> SL.PrimApp (mkPrimOp op) (atoms args)
    StgTick _ inner -> mkExpr inner
    StgLam _ _ -> error "mkExpr: StgLam detected"
    StgLet bnd body -> SL.Let (mkBind bnd) (mkExpr body)
    StgLetNoEscape _ _ bnd body -> SL.Let (mkBind bnd) (mkExpr body)
    StgCase scrut _ _ bndr _ _ alts ->
        SL.Case (mkExpr scrut) (mkVar bndr) (map mkAlt alts)
  where
    atoms = map mkAtom

-- | Translate an STG argument (a variable or a literal) into an SSTG 'SL.Atom'.
mkAtom :: StgArg -> SL.Atom
mkAtom arg = case arg of
    StgVarArg v -> SL.VarAtom (mkVar v)
    StgLitArg l -> SL.LitAtom (mkLit l)

-- | Translate a GHC 'Name' into an SSTG 'SL.Name'. It keeps the occurrence
-- string, the defining module's name (when 'nameModule_maybe' reports one),
-- the namespace, and the unique's key.
mkName :: Name -> SL.Name
mkName name = SL.Name (occNameString occ) modName nspace key
  where
    occ = nameOccName name
    nspace = mkNameSpace (occNameSpace occ)
    key = getKey (nameUnique name)
    modName = fmap (moduleNameString . moduleName) (nameModule_maybe name)

-- | Translate a GHC 'NameSpace' into an SSTG 'SL.NameSpace'.
--
-- The type-variable test must come first. GHC's 'isVarNameSpace' also holds
-- for the type-variable namespace, so testing it first would classify every
-- type variable as 'SL.VarNSpace' and make 'SL.TvNSpace' unreachable.
mkNameSpace :: NameSpace -> SL.NameSpace
mkNameSpace ns | isTvNameSpace ns = SL.TvNSpace
               | isVarNameSpace ns = SL.VarNSpace
               | isDataConNameSpace ns = SL.DataNSpace
               | isTcClsNameSpace ns = SL.TcClsNSpace
               | otherwise = error "mkNameSpace: unrecognized"

-- | Translate a GHC 'Var' into an SSTG 'SL.Var' carrying its name and type.
mkVar :: Var -> SL.Var
mkVar v = SL.Var (mkName (V.varName v)) (mkType (varType v))

-- | Translate an STG binding group into an SSTG 'SL.Bind', tagged as
-- non-recursive (a single binder) or recursive (a list of binders).
mkBind :: StgBinding -> SL.Bind
mkBind binding = case binding of
    StgNonRec b rhs -> SL.Bind SL.NonRec [translate (b, rhs)]
    StgRec group -> SL.Bind SL.Rec (map translate group)
  where
    translate (b, rhs) = (mkVar b, mkRhs rhs)

-- | Translate an STG right-hand side. A constructor application becomes
-- 'SL.ConForm'. A closure becomes 'SL.FunForm' over its parameters and body.
mkRhs :: StgRhs -> SL.BindRhs
mkRhs rhs = case rhs of
    StgRhsCon _ con args ->
        SL.ConForm (mkData con) (map mkAtom args)
    StgRhsClosure _ _ _ _ _ params body ->
        SL.FunForm (map mkVar params) (mkExpr body)

-- | Translate a GHC 'Literal' into an SSTG 'SL.Lit', annotated with the
-- translation of the literal's own 'literalType'.
--
-- 64-bit and 'LitInteger' literals are folded into 'SL.MachInt' /
-- 'SL.MachWord' via 'fromInteger'. NOTE(review): this may truncate if those
-- SSTG fields are bounded; confirm against SSTG.Core.Language.
mkLit :: Literal -> SL.Lit
mkLit lit = withType (mkType (literalType lit))
  where
    withType = case lit of
        MachChar c -> SL.MachChar c
        MachStr bytes -> SL.MachStr (show bytes)
        MachInt n -> SL.MachInt (fromInteger n)
        MachInt64 n -> SL.MachInt (fromInteger n)
        MachWord n -> SL.MachWord (fromInteger n)
        MachWord64 n -> SL.MachWord (fromInteger n)
        MachFloat r -> SL.MachFloat r
        MachDouble r -> SL.MachDouble r
        LitInteger n _ -> SL.MachInt (fromInteger n)
        MachNullAddr -> SL.MachNullAddr
        MachLabel fs arity _ -> SL.MachLabel (unpackFS fs) arity

-- | The SSTG 'SL.Name' of a data constructor.
mkDataName :: DataCon -> SL.Name
mkDataName = mkName . dataConName

-- | Translate a 'DataCon' into an SSTG 'SL.DataCon'. It carries the
-- constructor's name, its representation type ('dataConRepType') and the
-- types of its original arguments ('dataConOrigArgTys').
mkData :: DataCon -> SL.DataCon
mkData dc =
    SL.DataCon (mkDataName dc)
               (mkType (dataConRepType dc))
               (map mkType (dataConOrigArgTys dc))

-- | Translate an STG primitive-operation target into an SSTG 'SL.PrimFun'.
--
-- The resulting name has no module, and the primop's tag is used as its
-- unique. Only 'StgPrimOp' is supported; any other 'StgOp' raises an error.
mkPrimOp :: StgOp -> SL.PrimFun
mkPrimOp stgOp = case stgOp of
    StgPrimOp op ->
        let occ = primOpOcc op
            name = SL.Name (occNameString occ)
                           Nothing
                           (mkNameSpace (occNameSpace occ))
                           (primOpTag op)
        in SL.PrimFun name (mkType (primOpType op))
    _ -> error "mkPrimOp: got StgPrimCallOp or StgFCallOp"

-- | Translate an STG case alternative into an SSTG 'SL.Alt'. The third tuple
-- component is ignored.
mkAlt :: StgAlt -> SL.Alt
mkAlt (con, bndrs, _, rhs) = SL.Alt altCon vars body
  where
    altCon = mkAltCon con
    vars = map mkVar bndrs
    body = mkExpr rhs

-- | Translate a case-alternative pattern head: data constructor, literal, or
-- the default alternative.
mkAltCon :: AltCon -> SL.AltCon
mkAltCon con = case con of
    DataAlt dc -> SL.DataAlt (mkData dc)
    LitAlt l -> SL.LitAlt (mkLit l)
    DEFAULT -> SL.Default

-- | Translate a GHC 'Type' into an SSTG 'SL.Type', constructor by
-- constructor. A type variable is paired with the translation of its
-- 'varType' (for a type variable, its kind).
mkType :: Type -> SL.Type
mkType t = case t of
    TyVarTy tv -> SL.TyVarTy (mkName (V.varName tv)) (mkType (varType tv))
    AppTy fun arg -> SL.AppTy (mkType fun) (mkType arg)
    TyConApp tc args -> SL.TyConApp (mkTyCon tc) (map mkType args)
    ForAllTy bndr body -> SL.ForAllTy (mkTyBinder bndr) (mkType body)
    LitTy l -> SL.LitTy (mkTyLit l)
    CastTy inner co -> SL.CastTy (mkType inner) (mkCoercion co)
    CoercionTy co -> SL.CoercionTy (mkCoercion co)

-- | Translate a GHC 'TyCon' into an SSTG 'SL.TyCon'.
--
-- The first matching test wins: function, algebraic, family, primitive,
-- synonym, then promoted data constructor. Any other 'TyCon' is an error.
mkTyCon :: TyCon -> SL.TyCon
mkTyCon tc
    | isFunTyCon tc = SL.FunTyCon name binders
    | isAlgTyCon tc = SL.AlgTyCon name tyvars (mkAlgTyConRhs (algTyConRhs tc))
    | isFamilyTyCon tc = SL.FamilyTyCon name tyvars
    | isPrimTyCon tc = SL.PrimTyCon name binders
    | isTypeSynonymTyCon tc = SL.SynonymTyCon name tyvars
    | Just dc <- isPromotedDataCon_maybe tc = SL.Promoted name binders (mkData dc)
    | otherwise = error "mkTyCon: unrecognized TyCon"
  where
    name = mkName (tyConName tc)
    binders = map mkTyBinder (tyConBinders tc)
    tyvars = map (mkName . V.varName) (tyConTyVars tc)

-- | Translate the right-hand side of an algebraic type constructor. The
-- 'AbstractTyCon' flag is passed through unchanged. For the other forms only
-- the data constructors' names are kept.
mkAlgTyConRhs :: AlgTyConRhs -> SL.AlgTyRhs
mkAlgTyConRhs rhs = case rhs of
    AbstractTyCon flag -> SL.AbstractTyCon flag
    DataTyCon { data_cons = cons } -> SL.DataTyCon (map mkDataName cons)
    TupleTyCon { data_con = con } -> SL.TupleTyCon (mkDataName con)
    NewTyCon { data_con = con } -> SL.NewTyCon (mkDataName con)

-- | Translate a 'TyBinder'. An anonymous binder keeps no information. A named
-- binder keeps only its variable's name; its second field is dropped.
mkTyBinder :: TyBinder -> SL.TyBinder
mkTyBinder bndr = case bndr of
    Named tv _ -> SL.NamedTyBndr (mkName (V.varName tv))
    Anon _ -> SL.AnonTyBndr

-- | Translate a type-level literal: numbers go through 'fromInteger', and
-- strings are unpacked from their 'FastString'.
mkTyLit :: TyLit -> SL.TyLit
mkTyLit tl = case tl of
    NumTyLit n -> SL.NumTyLit (fromInteger n)
    StrTyLit fs -> SL.StrTyLit (unpackFS fs)

-- | Represent a 'Coercion' by the two types related by its 'coercionKind'.
mkCoercion :: Coercion -> SL.Coercion
mkCoercion co = SL.Coercion (mkType lhs) (mkType rhs)
  where
    Pair lhs rhs = coercionKind co