{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveFunctor         #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE DerivingStrategies    #-}
{-# LANGUAGE DerivingVia           #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MonoLocalBinds        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE TypeSynonymInstances  #-}
{-# LANGUAGE ViewPatterns          #-}
{-# LANGUAGE ViewPatterns          #-}
{-# OPTIONS_GHC -fno-warn-orphans  #-}

module Ide.Plugin.Tactic.Machinery
  ( module Ide.Plugin.Tactic.Machinery
  ) where

import           Class (Class(classTyVars))
import           Control.Arrow
import           Control.Monad.Error.Class
import           Control.Monad.Reader
import           Control.Monad.State (MonadState(..))
import           Control.Monad.State.Class (gets, modify)
import           Control.Monad.State.Strict (StateT (..))
import           Data.Bool (bool)
import           Data.Coerce
import           Data.Either
import           Data.Foldable
import           Data.Functor ((<&>))
import           Data.Generics (mkQ, everything, gcount)
import           Data.List (sortBy)
import qualified Data.Map as M
import           Data.Ord (comparing, Down(..))
import           Data.Set (Set)
import qualified Data.Set as S
import           Development.IDE.GHC.Compat
import           Ide.Plugin.Tactic.Judgements
import           Ide.Plugin.Tactic.Types
import           OccName (HasOccName(occName))
import           Refinery.ProofState
import           Refinery.Tactic
import           Refinery.Tactic.Internal
import           TcType
import           Type
import           Unify


------------------------------------------------------------------------------
-- | Apply a type substitution to the 'Type' wrapped inside a 'CType'.
substCTy :: TCvSubst -> CType -> CType
substCTy subst = CType . substTy subst . unCType


------------------------------------------------------------------------------
-- | Produce a subgoal that must be solved before we can solve the original
-- goal.
--
-- Before being handed to 'subgoal', the judgement is passed through
-- 'unsetIsTopHole' and then has the current 'ts_unifier' applied to it via
-- 'substJdg'.
newSubgoal
    :: Judgement
    -> Rule
newSubgoal j =
    gets ts_unifier >>= \unifier ->
      subgoal (substJdg unifier (unsetIsTopHole j))


------------------------------------------------------------------------------
-- | Attempt to generate a term of the right type using in-scope bindings, and
-- a given tactic.
--
-- Every type variable mentioned by the goal or by any hypothesis is recorded
-- in 'ts_skolems' (which 'unify' refuses to bind), and the top-level
-- hypotheses are recorded in 'ts_unused_top_vals'.
--
-- If the tactic yields no solutions, at most 50 of the collected errors are
-- returned. Otherwise the solutions are ranked by 'scoreSolution'; the best
-- one becomes the primary result, and the five best are also returned in
-- reverse rank order (best last).
runTactic
    :: Context
    -> Judgement
    -> TacticsM ()       -- ^ Tactic to use
    -> Either [TacticError] RunTacticResults
runTactic ctx jdg t =
    case partitionEithers outcomes of
      (errs, []) -> Left $ take 50 errs
      (_, raw)   ->
        case sortBy (comparing rank) $ fmap assoc23 raw of
          ranked@(((tr, ext), _) : _) ->
            Right
              $ RunTacticResults tr ext
              $ reverse
              $ fmap fst
              $ take 5 ranked
          -- Unreachable: 'raw' is non-empty in this branch, so the sorted
          -- list is too.
          [] -> Left []
  where
    -- Every outcome of running the tactic from the initial state; 'ExtractM'
    -- is discharged by supplying the 'Context'.
    outcomes = runReader (unExtractM $ runTacticT t jdg initialState) ctx

    -- Sort key: wrapping the score in 'Down' sends the best-scoring
    -- solutions to the front of an ascending sort.
    rank ((_, ext), (ts, holes)) = Down $ scoreSolution ext ts holes

    skolems :: Set TyVar
    skolems
      = S.fromList
      $ foldMap (tyCoVarsOfTypeWellScoped . unCType)
      $ jGoal jdg : fmap hi_type (toList $ jHypothesis jdg)

    unusedTopVals :: Set OccName
    unusedTopVals
      = M.keysSet
      $ M.filter (isTopLevel . hi_provenance)
      $ jHypothesis jdg

    initialState :: TacticState
    initialState = defaultTacticState
      { ts_skolems         = skolems
      , ts_unused_top_vals = unusedTopVals
      }

------------------------------------------------------------------------------
-- | Regroup a triple so that its last two components form a nested pair.
assoc23 :: (a, b, c) -> (a, (b, c))
assoc23 (x, y, z) = (x, (y, z))


------------------------------------------------------------------------------
-- | Build a 'Trace' from a label with no children, using 'rose'.
tracePrim :: String -> Trace
tracePrim label = rose label []


------------------------------------------------------------------------------
-- | Run a tactic, rewriting the 'Trace' component of every extract it
-- produces to @rose label [trace]@. The goal state and the rest of the
-- extract are left untouched.
tracing
    :: Functor m
    => String
    -> TacticT jdg (Trace, ext) err s m a
    -> TacticT jdg (Trace, ext) err s m a
tracing label (TacticT inner) =
  TacticT . StateT $ \goal ->
    let relabel tr = rose label [tr]
    in  mapExtract' (first relabel) (runStateT inner goal)


------------------------------------------------------------------------------
-- | Recursion is allowed only when we can prove it is on a structurally
-- smaller argument. The top of the 'ts_recursion_stack' witnesses the smaller
-- pattern val.
--
-- Returns 'Nothing' when recursion is permitted, and @'Just' 'NoProgress'@
-- otherwise. An empty recursion stack provides no witness, so recursion is
-- rejected instead of crashing on a partial 'head'.
guardStructurallySmallerRecursion
    :: TacticState
    -> Maybe TacticError
guardStructurallySmallerRecursion s =
  case ts_recursion_stack s of
    Just _ : _ -> Nothing
    _          -> Just NoProgress


------------------------------------------------------------------------------
-- | Mark that the current recursive call is structurally smaller, due to
-- having been matched on a pattern value.
--
-- Implemented by replacing the top of the 'ts_recursion_stack' with
-- @'Just' pv@. An empty stack is left empty.
markStructuralySmallerRecursion :: MonadState TacticState m => PatVal -> m ()
markStructuralySmallerRecursion pv =
    modify (withRecursionStack replaceTop)
  where
    replaceTop (_ : rest) = Just pv : rest
    replaceTop []         = []


------------------------------------------------------------------------------
-- | Given the results of running a tactic, score the solutions by
-- desirability.
--
-- The tuple is compared lexicographically, so earlier components dominate.
-- 'Reward' components prefer larger values; 'Penalize' components invert the
-- ordering so that smaller values are preferred.
--
-- TODO(sandy): This function is completely unprincipled and was just hacked
-- together to produce the right test results.
scoreSolution
    :: LHsExpr GhcPs
    -> TacticState
    -> [Judgement]
    -> ( Penalize Int  -- number of holes
       , Reward Bool   -- all bindings used
       , Penalize Int  -- unused top-level bindings
       , Penalize Int  -- number of introduced bindings
       , Reward Int    -- number used bindings
       , Penalize Int  -- number of recursive calls
       , Penalize Int  -- size of extract
       )
scoreSolution ext st holes =
  ( Penalize $ length holes
    -- every introduced binding also appears among the used bindings
  , Reward   $ ts_intro_vals st `S.isSubsetOf` ts_used_vals st
  , Penalize $ S.size $ ts_unused_top_vals st
  , Penalize $ S.size $ ts_intro_vals st
  , Reward   $ S.size $ ts_used_vals st
  , Penalize $ ts_recursion_count st
  , Penalize $ solutionSize ext
  )


------------------------------------------------------------------------------
-- | Compute the number of 'LHsExpr' nodes; used as a rough metric for code
-- size.
--
-- 'gcount' already traverses the entire tree. Wrapping it in a further
-- @'everything' (+)@ would add up the 'LHsExpr' count of every subtree,
-- counting each expression once per enclosing node rather than once.
solutionSize :: LHsExpr GhcPs -> Int
solutionSize = gcount (mkQ False isExpr)
  where
    isExpr :: LHsExpr GhcPs -> Bool
    isExpr _ = True


------------------------------------------------------------------------------
-- | A score component for which /smaller/ wrapped values are preferable.
--
-- All instances are borrowed from @'Down' a@, so the ordering is reversed:
-- @Penalize 1 > Penalize 2@.
newtype Penalize a = Penalize a
  deriving (Eq, Ord) via (Down a)
  deriving Show via (Down a)

------------------------------------------------------------------------------
-- | A score component for which /larger/ wrapped values are preferable.
--
-- All instances are exactly those of the wrapped type.
newtype Reward a = Reward a
  deriving (Eq, Ord) via a
  deriving Show via a


------------------------------------------------------------------------------
-- | Like 'tcUnifyTy', but takes a set of skolems to prevent unification of.
--
-- Type variables in the set are bound as 'Skolem'; every other variable is
-- 'BindMe'. The instantiated type is unified against the goal. The resulting
-- substitution is returned only for a 'Unifiable' result; any other
-- 'UnifyResult' yields 'Nothing'.
tryUnifyUnivarsButNotSkolems :: Set TyVar -> CType -> CType -> Maybe TCvSubst
tryUnifyUnivarsButNotSkolems skolems goal inst =
    case tcUnifyTysFG bindFlag [unCType inst] [unCType goal] of
      Unifiable subst -> Just subst
      _               -> Nothing
  where
    bindFlag tv
      | tv `S.member` skolems = Skolem
      | otherwise             = BindMe



------------------------------------------------------------------------------
-- | Try to unify the goal type with another type, never binding any of the
-- skolems recorded in 'ts_skolems'. On success, the resulting substitution is
-- merged into 'ts_unifier'; on failure, a 'UnificationError' is thrown.
unify :: CType -- ^ The goal type
      -> CType -- ^ The type we are trying to unify the goal type with
      -> RuleM ()
unify goal inst =
  gets ts_skolems >>= \skolems ->
    case tryUnifyUnivarsButNotSkolems skolems goal inst of
      Nothing    -> throwError $ UnificationError inst goal
      Just subst -> modify $ extendUnifier subst
  where
    -- Merge a fresh substitution into the accumulated unifier; the fresh
    -- substitution is the first argument to 'unionTCvSubst'.
    extendUnifier subst st =
      st { ts_unifier = unionTCvSubst subst (ts_unifier st) }


------------------------------------------------------------------------------
-- | Get the class methods of a 'PredType', correctly dealing with
-- instantiation of quantified class types.
--
-- Returns 'Nothing' when the 'PredType' is not an application of a class
-- 'TyCon'. Otherwise it returns:
--
-- * the methods of every superclass, found recursively, followed by
-- * the class's own methods.
--
-- In both cases the class's type variables are replaced by the arguments the
-- class was applied to.
--
-- A superclass constraint that is not itself a class application contributes
-- no methods. It does not cause the whole lookup to fail. Examples are
-- equality constraints (whose classes bottom out in a primitive, non-class
-- superclass), type family applications and quantified constraints.
methodHypothesis :: PredType -> Maybe [(OccName, HyInfo CType)]
methodHypothesis ty = do
  (tc, apps) <- splitTyConApp_maybe ty
  cls <- tyConClass_maybe tc
  let methods = classMethods cls
      tvs     = classTyVars cls
      -- Instantiate the class's type variables with the arguments it was
      -- applied to, so superclass and method types mention the real types.
      subst   = zipTvSubst tvs apps
      -- 'fold' turns an uninterpretable superclass ('Nothing') into no
      -- methods, rather than discarding this class's methods as well.
      sc_methods =
        foldMap (fold . methodHypothesis . substTy subst) $ classSCTheta cls
  pure $ mappend sc_methods $ methods <&> \method ->
    -- Drop the leading quantifiers and constraints of the method selector's
    -- type (which include the class constraint itself).
    let (_, _, method_ty) = tcSplitSigmaTy $ idType method
    in ( occName method
       , HyInfo (ClassMethodPrv $ Uniquely cls)
           $ CType
           $ substTy subst method_ty
       )


------------------------------------------------------------------------------
-- | Run the given tactic only if the current hole's goal type mentions no
-- type variables other than skolems; otherwise fail with 'TooPolymorphic'.
-- Skolems and already-decided univars are OK.
requireConcreteHole :: TacticsM a -> TacticsM a
requireConcreteHole m = do
  hole_ty <- unCType . jGoal <$> goal
  skolems <- gets ts_skolems
  -- Every type/coercion variable in the goal that is not a skolem.
  let free_univars =
        S.fromList (tyCoVarsOfTypeWellScoped hole_ty) S.\\ skolems
  if S.null free_univars
    then m
    else throwError TooPolymorphic