{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}

{-# OPTIONS_GHC -Wno-unused-top-binds #-}

module IfSat.Plugin.Compat
  ( wrapTcS, getRestoreTcS )
  where

-- base

import Unsafe.Coerce
  ( unsafeCoerce )

-- ghc

#if MIN_VERSION_ghc(9,4,0)
import GHC.Tc.Solver.InertSet
  ( WorkList, InertSet )
#endif
import GHC.Tc.Solver.Monad
  ( TcS
#if MIN_VERSION_ghc(9,1,0)
  , TcLevel, wrapTcS
#endif
#if !MIN_VERSION_ghc(9,4,0)
  , WorkList, InertSet
#endif
  )
import GHC.Tc.Types
  ( TcM, TcRef )
import GHC.Tc.Types.Evidence
  ( EvBindsVar(..) )

-- ghc-tcplugin-api

import GHC.TcPlugin.API
  ( readTcRef, writeTcRef )

--------------------------------------------------------------------------------


-- | Capture the current 'TcS' state, returning an action which restores
-- the fields of 'TcSEnv' as appropriate after running a test-run
-- of 'solveSimpleWanteds' and deciding to backtrack.
--
-- Running the outer action takes a snapshot of the mutable references
-- reachable from the solver environment (evidence bindings, the
-- unification counter, the unification level on GHC >= 9.1, and the
-- step counter). Running the returned inner action writes the
-- snapshotted values back into those same references.
getRestoreTcS :: TcS (TcS ())
getRestoreTcS = do
  shim_tcs_env <- getShimTcSEnv
  let ev_binds_var   = shim_tcs_ev_binds shim_tcs_env
      unif_var       = shim_tcs_unified  shim_tcs_env
#if MIN_VERSION_ghc(9,1,0)
      unif_lvl_var   = shim_tcs_unif_lvl shim_tcs_env
#endif
      unit_count_var = shim_tcs_count    shim_tcs_env
  wrapTcS $ do
    -- Snapshot the evidence bindings. The result is itself an action
    -- which, when run later, writes the saved contents back.
    restore_evBinds <- case ev_binds_var of
      -- Full evidence bindings: save both the binding map and the
      -- coercion variable set.
      EvBindsVar { ebv_binds = ev_binds_ref
                 , ebv_tcvs  = ev_cvs_ref } ->
        do ev_binds <- readTcRef ev_binds_ref
           ev_cvs   <- readTcRef ev_cvs_ref
           return do
             writeTcRef ev_binds_ref ev_binds
             writeTcRef ev_cvs_ref   ev_cvs
      -- Coercion-only evidence bindings: only the coercion variable
      -- set exists, so that is all there is to save.
      CoEvBindsVar { ebv_tcvs = ev_cvs_ref } ->
        do ev_cvs   <- readTcRef ev_cvs_ref
           return do
             writeTcRef ev_cvs_ref   ev_cvs

    -- Snapshot the remaining mutable fields.
    unif         <- readTcRef unif_var
#if MIN_VERSION_ghc(9,1,0)
    unif_lvl     <- readTcRef unif_lvl_var
#endif
    count        <- readTcRef unit_count_var

    -- The restore action: write every snapshotted value back.
    return $ wrapTcS $ do
      restore_evBinds
      writeTcRef unif_var       unif
#if MIN_VERSION_ghc(9,1,0)
      writeTcRef unif_lvl_var   unif_lvl
#endif
      writeTcRef unit_count_var count

  -- NB: no need to reset 'tcs_inerts' or 'tcs_worklist', because
  -- 'solveSimpleWanteds' calls 'nestTcS', which appropriately resets
  -- both of those fields.


#if !MIN_VERSION_ghc(9,1,0)
-- | Lift a 'TcM' computation into 'TcS'.
--
-- Fallback definition for GHC versions older than 9.1, where 'wrapTcS'
-- is not imported from "GHC.Tc.Solver.Monad" (see the CPP-guarded import
-- at the top of this module).
--
-- The 'unsafeCoerce' treats a function that ignores its argument and
-- returns the given 'TcM' action as a 'TcS' action; this relies on
-- 'TcS a' being representationally a function from the solver
-- environment to 'TcM a'.
wrapTcS :: TcM a -> TcS a
wrapTcS tcm = unsafeCoerce ( \ _tcs_env -> tcm )
#endif


-- Obtain the 'TcSEnv' underlying the 'TcS' monad (in the form of a 'ShimTcSEnv').
--
-- Safety of the 'unsafeCoerce': this reinterprets the function
-- @return :: ShimTcSEnv -> TcM ShimTcSEnv@ as a 'TcS' action, which relies on
--
--   * 'TcS a' being representationally a function from the solver
--     environment to 'TcM a', and
--   * 'ShimTcSEnv' having exactly the same field layout as GHC's 'TcSEnv'
--     for the GHC version being compiled against.
--
-- If either assumption is violated, the result is undefined behaviour
-- (typically a segfault).
getShimTcSEnv :: TcS ShimTcSEnv
getShimTcSEnv = unsafeCoerce ( return :: ShimTcSEnv -> TcM ShimTcSEnv )

-- | A shim copy of "GHC.Tc.Solver.Monad.TcSEnv", to work around the
-- fact that it isn't exported.
--
-- Needs to be manually kept in sync with 'TcSEnv' to avoid segfaults due
-- to the use of 'unsafeCoerce' in 'getShimTcSEnv'.
--
-- The number, order and types of the fields are significant (hence the
-- CPP guards for fields that only exist in some GHC versions): do not
-- reorder, remove or change the strictness of any field.
data ShimTcSEnv
  = ShimTcSEnv
  { shim_tcs_ev_binds           :: EvBindsVar
  , shim_tcs_unified            :: TcRef Int
#if MIN_VERSION_ghc(9,1,0)
  , shim_tcs_unif_lvl           :: TcRef (Maybe TcLevel)
#endif
  , shim_tcs_count              :: TcRef Int
  , shim_tcs_inerts             :: TcRef InertSet
#if MIN_VERSION_ghc(9,3,0)
  , shim_tcs_abort_on_insoluble :: Bool
#endif
  , shim_tcs_worklist           :: TcRef WorkList
  }