{-|
Module      : Idris.Coverage
Description : Clause generation for coverage checking

License     : BSD3
Maintainer  : The Idris Community.
-}
{-# LANGUAGE FlexibleContexts, PatternGuards #-}
module Idris.Coverage(genClauses, validCoverageCase, recoverableCoverage,
                      mkPatTm) where

import Idris.AbsSyntax
import Idris.Core.CaseTree
import Idris.Core.Evaluate
import Idris.Core.TT
import Idris.Delaborate
import Idris.Elab.Utils
import Idris.Error

import Control.Monad.State.Strict
import Data.Char
import Data.List
import Data.Maybe

-- | Generate a pattern from an 'impossible' LHS.
--
-- We need this to eliminate the pattern clauses which have been
-- provided explicitly from new clause generation.
--
-- This takes a type directed approach to disambiguating names. If we
-- can't immediately disambiguate by looking at the expected type, it's an
-- error (we can't do this the usual way of trying it to see what type checks
-- since the whole point of an impossible case is that it won't type check!)
mkPatTm :: PTerm -> Idris Term
mkPatTm t = do i <- getIState
               -- Run addImpl' over the pattern first (presumably to insert
               -- implicit arguments; see Idris.AbsSyntax).
               let timp = addImpl' True [] [] [] i t
               -- The Int state is a counter used to number the fresh "imp"
               -- names given to sub-patterns that become pattern variables
               -- (see the final toTT equation).
               evalStateT (toTT Nothing timp) 0
  where
    -- Translate a source pattern into a 'Term'. The first argument is the
    -- expected type of the pattern, if known; it is only consulted when
    -- resolving a 'PAlternative'.
    toTT :: Maybe Type -> PTerm -> StateT Int Idris Term
    -- A bare name: use the name type recorded in its declaration if it has
    -- one, otherwise treat it as a plain reference.
    toTT ty (PRef _ _ n)
       = do i <- lift getIState
            case lookupDefExact n (tt_ctxt i) of
                 Just (TyDecl nt _) -> return $ P nt n Erased
                 _ -> return $ P Ref n Erased
    -- An application of a named function: when the function's type is
    -- known, its argument types become the expected types of the
    -- arguments, so alternatives among the arguments can be resolved.
    -- NOTE(review): zipWithM stops at the shorter list, so if getArgTys
    -- yields fewer types than there are arguments, the surplus arguments
    -- are dropped from the result -- confirm this cannot happen here.
    toTT ty (PApp _ t@(PRef _ _ n) args)
       = do i <- lift getIState
            let aTys = case lookupTyExact n (tt_ctxt i) of
                              Just nty -> map (Just . snd) (getArgTys nty)
                              Nothing -> map (const Nothing) args
            args' <- zipWithM toTT aTys (map getTm args)
            t' <- toTT Nothing t
            return $ mkApp t' args'
    -- Any other application: translate the head and the arguments with no
    -- expected types.
    toTT ty (PApp _ t args)
       = do t' <- toTT Nothing t
            args' <- mapM (toTT Nothing . getTm) args
            return $ mkApp t' args'
    -- Dependent pairs and pairs become applications of sigmaCon / pairCon
    -- with both type arguments Erased.
    toTT ty (PDPair _ _ _ l _ r)
       = do l' <- toTT Nothing l
            r' <- toTT Nothing r
            return $ mkApp (P Ref sigmaCon Erased) [Erased, Erased, l', r']
    toTT ty (PPair _ _ _ l r)
       = do l' <- toTT Nothing l
            r' <- toTT Nothing r
            return $ mkApp (P Ref pairCon Erased) [Erased, Erased, l', r']
    -- Alternatives: when the expected type is known, keep only the
    -- alternatives that pruneByType accepts for it. Exactly one must remain
    -- and is translated; otherwise the alternatives are reported as
    -- unresolvable. Without an expected type they cannot be resolved at all
    -- (we cannot try type checking, since impossible cases do not check).
    toTT (Just ty) (PAlternative _ _ as)
       | (hd, _) <- unApply ty
          = do i <- lift getIState
               case pruneByType True [] hd ty i as of
                    [a] -> toTT (Just ty) a
                    _ -> lift $ ierror $ CantResolveAlts (map getAltName as)
    toTT Nothing (PAlternative _ _ as)
                    = lift $ ierror $ CantResolveAlts (map getAltName as)
    -- Anything not handled above becomes a fresh bound variable, named
    -- with the next value of the counter.
    toTT ty _
       = do v <- get
            put (v + 1)
            return (P Bound (sMN v "imp") Erased)

    -- The name used to describe an alternative in a CantResolveAlts error.
    -- A three-argument application of "Delay" is looked through to its
    -- delayed argument.
    getAltName (PApp _ (PRef _ _ (UN l)) [_, _, arg])
             | l == txt "Delay" = getAltName (getTm arg)
    getAltName (PApp _ (PRef _ _ n) _) = n
    getAltName (PRef _ _ n) = n
    getAltName (PApp _ h _) = getAltName h
    getAltName (PHidden h) = getAltName h
    getAltName x = sUN "_" -- should never happen here

-- | Given a list of LHSs, generate extra clauses which cover the remaining
-- cases. The ones which haven't been provided are marked 'absurd' so
-- that the checker will make sure they can't happen.
--
-- This will only work after the given clauses have been typechecked and the
-- names are fully explicit!
genClauses :: FC -> Name -> [([Name], Term)] -> -- (Argument names, LHS)
              [PTerm] -> Idris [PTerm]
-- No clauses (only valid via elab reflection). We should probably still do
-- a check here somehow, e.g. that one of the arguments is an obviously
-- empty type. In practice, this should only really be used for Void elimination.
genClauses fc n lhs_tms [] = return []
genClauses fc n lhs_tms given
   = do i <- getIState

        -- Pair each typechecked LHS with its source pattern (nested
        -- applications flattened, then passed through stripUnmatchable).
        -- Sub-terms that came from placeholders are marked Inferred.
        let lhs_given = zipWith removePlaceholders lhs_tms
                            (map (stripUnmatchable i) (map flattenArgs given))

        logCoverage 5 $ "Building coverage tree for:\n" ++ showSep "\n" (map showTmImpls given)
        logCoverage 10 $ "Building coverage tree for:\n" ++ showSep "\n" (map show lhs_given)
        logCoverage 10 $ "From terms:\n" ++ showSep "\n" (map show lhs_tms)
        -- Argument positions that are given explicitly (not a placeholder)
        -- in every clause.
        let givenpos = mergePos (map getGivenPos given)

        -- Build a case tree from the given LHSs in coverage-checking mode.
        -- A failure is reported at fc.
        (cns, ctree_in) <-
                         case simpleCase False (UnmatchedCase "Undefined") False
                              (CoverageCheck givenpos) emptyFC [] []
                              lhs_given
                              (const []) of
                           OK (CaseDef cns ctree_in _) ->
                              return (cns, ctree_in)
                           Error e -> tclift $ tfail $ At fc e

        -- Expand default cases into explicit alternatives, then prune
        -- alternatives that cannot be reached.
        let ctree = trimOverlapping (addMissingCons i ctree_in)
        -- Read covered and missing clauses back out of the tree and
        -- delaborate them. A missing clause that delaborates to the same
        -- term as a covered one is dropped.
        let (coveredas, missingas) = mkNewClauses (tt_ctxt i) n cns ctree
        let covered = map (\t -> delab' i t True True) coveredas
        let missing = filter (\x -> x `notElem` covered) $
                          map (\t -> delab' i t True True) missingas

        logCoverage 5 $ "Coverage from case tree for " ++ show n ++ ": " ++ show ctree
        logCoverage 2 $ show (length missing) ++ " missing clauses for " ++ show n
        logCoverage 3 $ "Missing clauses:\n" ++ showSep "\n"
                              (map showTmImpls missing)
        logCoverage 10 $ "Covered clauses:\n" ++ showSep "\n"
                              (map showTmImpls covered)
        return missing
    where
        -- Turn a nested application ((f as) as') into the single
        -- application (f (as ++ as')), repeatedly.
        flattenArgs (PApp fc (PApp _ f as) as')
             = flattenArgs (PApp fc f (as ++ as'))
        flattenArgs t = t

-- | Zero-based positions of the arguments of an application pattern that
-- are written explicitly, i.e. are not a 'Placeholder'. Anything that is
-- not an application has no given positions.
getGivenPos :: PTerm -> [Int]
getGivenPos pat = case pat of
    PApp _ _ pargs ->
        [ pos | (pos, arg) <- zip [0 ..] (map getTm pargs), isGiven arg ]
    _ -> []
  where
    -- A placeholder argument does not count as given.
    isGiven Placeholder = False
    isGiven _ = True

-- | Keep only the positions that occur in every one of the given lists.
-- Order and multiplicity follow the first list, as with 'intersect'.
-- An empty list of lists yields no positions.
mergePos :: [[Int]] -> [Int]
mergePos [] = []
mergePos posLists = foldr1 intersect posLists

-- | Combine a typechecked LHS (with its argument names) and the source
-- pattern it came from into the triple that 'genClauses' hands to the
-- case-tree builder. Wherever the source pattern has a 'Placeholder', the
-- corresponding typechecked sub-term is wrapped in 'Inferred' (an 'Erased'
-- sub-term is kept as it is). The third component is always 'Erased'
-- (presumably the right-hand side, which coverage checking does not need).
removePlaceholders :: ([Name], Term) -> PTerm -> ([Name], Term, Term)
removePlaceholders (ns, tm) ptm = (ns, rp tm ptm, Erased)
  where
    rp Erased Placeholder = Erased
    rp tm Placeholder = Inferred tm
    -- Applications are walked in step. zipWith stops at the shorter of the
    -- two argument lists, so surplus arguments on either side are dropped.
    rp tm (PApp _ pf pargs)
       | (tf, targs) <- unApply tm
           = let tf' = rp tf pf
                 targs' = zipWith rp targs (map getTm pargs) in
                 mkApp tf' targs'
    -- Pairs and dependent pairs: only the two value arguments are walked,
    -- and the two type arguments are replaced by Erased. A term that is not
    -- a four-argument application falls through to the last equation.
    rp tm (PPair _ _ _ pl pr)
       | (tf, [tyl, tyr, tl, tr]) <- unApply tm
           = let tl' = rp tl pl
                 tr' = rp tr pr in
                 mkApp tf [Erased, Erased, tl', tr']
    rp tm (PDPair _ _ _ pl pt pr)
       | (tf, [tyl, tyr, tl, tr]) <- unApply tm
           = let tl' = rp tl pl
                 tr' = rp tr pr in
                 mkApp tf [Erased, Erased, tl', tr']
    -- Anything else is taken unchanged from the typechecked term.
    rp tm _ = tm

-- | Read clauses back out of a coverage case tree.
--
-- The result is @(covered, missing)@. Each clause is @fn@ applied to a list
-- of argument terms. The list starts as one reference per case-tree
-- argument name in @ns@, and the pattern chosen at each 'Case' node along
-- a path is substituted into it. Paths ending in an 'STerm' leaf give
-- covered clauses, paths ending in 'UnmatchedCase' give missing ones, and
-- 'ImpossibleCase' leaves give neither. In the final arguments, every name
-- that is not a constructor is replaced by 'Erased'.
mkNewClauses :: Context -> Name -> [Name] -> SC -> ([Term], [Term])
mkNewClauses ctxt fn ns sc
     = (clausesFor True, clausesFor False)
  where
    -- All covered (True) or all missing (False) clauses, each as an
    -- application of fn.
    clausesFor cov
        = map (mkPlApp (P Ref fn Erased)) $
              mkFromSC cov (map (\n -> P Ref n Erased) ns) sc

    -- Apply the function after erasing non-constructor names in the
    -- arguments.
    mkPlApp f args = mkApp f (map erasePs args)

    -- Replace every name that is not a constructor by Erased. Recurses into
    -- the arguments of applications; application heads are kept as-is.
    erasePs ap@(App _ _ _)
        | (f, args) <- unApply ap = mkApp f (map erasePs args)
    erasePs (P _ n _) | not (isConName n ctxt) = Erased
    erasePs tm = tm

    -- Collect the argument lists along every path, starting with no Case
    -- nodes recorded as expanded.
    mkFromSC cov args sc = evalState (mkFromSC' cov args sc) []

    -- The state holds argument lists whose Case node has already been
    -- expanded. A Case reached again with an identical argument list is not
    -- expanded a second time.
    -- NOTE(review): after expanding a node the state is reset to the list
    -- read *before* expanding it plus args, so entries recorded inside the
    -- sub-trees are discarded -- confirm this is intended.
    mkFromSC' :: Bool -> [Term] -> SC -> State [[Term]] [[Term]]
    mkFromSC' cov args (STerm _)
        = if cov then return [args] else return [] -- leaf of provided case
    mkFromSC' cov args (UnmatchedCase _)
        = if cov then return [] else return [args] -- leaf of missing case
    mkFromSC' cov args ImpossibleCase = return []
    mkFromSC' cov args (Case _ x alts)
       = do done <- get
            if (args `elem` done)
               then return []
               else do alts' <- mapM (mkFromAlt cov args x) alts
                       put (args : done)
                       return (concat alts')
    mkFromSC' cov args _ = return [] -- Should never happen

    -- Follow one alternative of a Case on variable x by substituting the
    -- matched pattern for x in every argument.
    mkFromAlt :: Bool -> [Term] -> Name -> CaseAlt -> State [[Term]] [[Term]]
    mkFromAlt cov args x (ConCase c t conargs sc)
       -- The arity recorded in the constructor reference is the number of
       -- the constructor's own arguments (conargs), not the number of
       -- arguments of the function being checked.
       = let argrep = mkApp (P (DCon t (length conargs) False) c Erased)
                            (map (\n -> P Ref n Erased) conargs)
             args' = map (subst x argrep) args in
             mkFromSC' cov args' sc
    mkFromAlt cov args x (ConstCase c sc)
       = let argrep = Constant c
             args' = map (subst x argrep) args in
             mkFromSC' cov args' sc
    -- A default alternative leaves the arguments unchanged.
    mkFromAlt cov args x (DefaultCase sc)
       = mkFromSC' cov args sc
    mkFromAlt cov _ _ _ = return []

-- | Post-process a generated case tree so that default cases become
-- explicit alternatives for the constructors (or constants) they stand
-- for. This is done here rather than by the case-tree builder because the
-- builder has no access to the context. The counter used for fresh
-- argument names starts at 0.
addMissingCons :: IState -> SC -> SC
addMissingCons ist = flip evalState 0 . addMissingConsSt ist

-- | Worker for 'addMissingCons'. The 'Int' state is a counter used to
-- number the argument names of generated constructor alternatives.
--
-- At a 'Case' node:
--
--   * if the alternatives match on constructors of a known data type, the
--     first 'DefaultCase' is replaced by one 'ConCase' (sharing the
--     default's right-hand side) for each constructor of that type not
--     matched by another alternative at this node;
--
--   * otherwise, if they match on constants, the first 'DefaultCase' is
--     replaced by one 'ConstCase' for each "next" constant (see nextConst)
--     that is not itself matched.
--
-- All sub-trees are processed recursively. Any other tree is returned
-- unchanged.
addMissingConsSt :: IState -> SC -> State Int SC
addMissingConsSt ist (Case t n alts) = liftM (Case t n) (addMissingAlts n alts)
  where
    -- Recurse into the right-hand side of a single alternative.
    addMissingAlt :: CaseAlt -> State Int CaseAlt
    addMissingAlt (ConCase n i ns sc)
         = liftM (ConCase n i ns) (addMissingConsSt ist sc)
    addMissingAlt (FnCase n ns sc)
         = liftM (FnCase n ns) (addMissingConsSt ist sc)
    addMissingAlt (ConstCase c sc)
         = liftM (ConstCase c) (addMissingConsSt ist sc)
    addMissingAlt (SucCase n sc)
         = liftM (SucCase n) (addMissingConsSt ist sc)
    addMissingAlt (DefaultCase sc)
         = liftM DefaultCase (addMissingConsSt ist sc)

    -- The first argument (the scrutinee name) is not needed to decide
    -- what is missing.
    addMissingAlts _ as
         | cons@(n:_) <- mapMaybe collectCons as,
           Just tyn <- getConType n,
           Just ti <- lookupCtxtExact tyn (idris_datatypes ist)
             -- Constructors already matched by an alternative at this node
             -- cannot be what the default case stands for.
             = let missing = con_names ti \\ cons in
                   do as' <- addCases missing as
                      mapM addMissingAlt as'
         | consts@(_:_) <- mapMaybe collectConsts as
             = let missing = nub (map nextConst consts) \\ consts in
                   mapM addMissingAlt (addCons missing as)
    addMissingAlts _ as = mapM addMissingAlt as

    -- Replace the first DefaultCase by alternatives for the missing
    -- constructors (names that are not data constructors are skipped).
    addCases missing [] = return []
    addCases missing (DefaultCase rhs : rest)
       = do missing' <- mapM (genMissingAlt rhs) missing
            return (mapMaybe id missing' ++ rest)
    addCases missing (c : rest)
       = liftM (c :) $ addCases missing rest

    -- Replace the first DefaultCase by alternatives for the missing
    -- constants.
    addCons missing [] = []
    addCons missing (DefaultCase rhs : rest)
       = map (genMissingConAlt rhs) missing ++ rest
    addCons missing (c : rest) = c : addCons missing rest

    -- A ConCase for constructor n whose argument names are numbered from
    -- the counter. Nothing if n is not declared as a data constructor.
    genMissingAlt rhs n
         | Just (TyDecl (DCon tag arity _) _) <- lookupDefExact n (tt_ctxt ist)
             = do name <- get
                  put (name + arity)
                  let args = map (name +) [0..arity-1]
                  return $ Just $ ConCase n tag (map (\i -> sMN i "m") args) rhs
         | otherwise = return Nothing

    genMissingConAlt rhs n = ConstCase n rhs

    collectCons (ConCase n i args sc) = Just n
    collectCons _ = Nothing

    collectConsts (ConstCase c sc) = Just c
    collectConsts _ = Nothing

    -- The name heading the (normalised) return type of a constructor.
    getConType n = do ty <- lookupTyExact n (tt_ctxt ist)
                      case unApply (getRetTy (normalise (tt_ctxt ist) [] ty)) of
                           (P _ tyn _, _) -> Just tyn
                           _ -> Nothing

    -- For each matched constant, take the next value so that constants
    -- not handled explicitly are represented by a concrete case.
    nextConst (I c) = I (c + 1)
    nextConst (BI c) = BI (c + 1)
    nextConst (Fl c) = Fl (c + 1)
    nextConst (B8 c) = B8 (c + 1)
    nextConst (B16 c) = B16 (c + 1)
    nextConst (B32 c) = B32 (c + 1)
    nextConst (B64 c) = B64 (c + 1)
    -- There is no Char after maxBound ('chr' fails on the next code point).
    -- Returning c itself means it is removed again by the (\\ consts)
    -- above, so no extra case is generated.
    nextConst (Ch c)
        | c < maxBound = Ch (chr $ ord c + 1)
        | otherwise = Ch c
    nextConst (Str c) = Str (c ++ "'")
    nextConst o = o

addMissingConsSt ist sc = return sc

-- | Remove case alternatives that can never be reached, given what is
-- already known about each variable on the path from the root of the tree.
--
-- Two kinds of knowledge are carried down:
--
--   * "must be": the variable was matched against a constructor (with these
--     argument names) further up. A nested 'Case' on it can only take that
--     constructor's alternative, whose argument names are renamed to the
--     ones already in scope.
--
--   * "can't be": an earlier sibling 'ConCase' on the variable already
--     handled these constructors. In the later siblings' sub-trees, a
--     nested 'Case' on the variable drops alternatives for them.
trimOverlapping :: SC -> SC
trimOverlapping sc = trim [] [] sc
  where
    trim :: [(Name, (Name, [Name]))] -> -- Variable - constructor+args already matched
            [(Name, [Name])] -> -- Variable - constructors which it can't be
            SC -> SC
    trim mustbes nots (Case t vn alts)
       | Just (c, args) <- lookup vn mustbes
            = Case t vn (trimAlts mustbes nots vn (substMatch (c, args) alts))
       | Just cantbe <- lookup vn nots
            = let alts' = filter (notConMatch cantbe) alts in
                  Case t vn (trimAlts mustbes nots vn alts')
       | otherwise = Case t vn (trimAlts mustbes nots vn alts)
    trim cs nots sc = sc

    -- Trim each alternative's sub-tree. A ConCase records its match for its
    -- own sub-tree and marks its constructor as impossible for the
    -- alternatives after it.
    trimAlts cs nots vn [] = []
    trimAlts cs nots vn (ConCase cn t args sc : rest)
        = ConCase cn t args (trim (addMatch vn (cn, args) cs) nots sc) :
            trimAlts cs (addCantBe vn cn nots) vn rest
    trimAlts cs nots vn (FnCase n ns sc : rest)
        = FnCase n ns (trim cs nots sc) : trimAlts cs nots vn rest
    trimAlts cs nots vn (ConstCase c sc : rest)
        = ConstCase c (trim cs nots sc) : trimAlts cs nots vn rest
    trimAlts cs nots vn (SucCase n sc : rest)
        = SucCase n (trim cs nots sc) : trimAlts cs nots vn rest
    trimAlts cs nots vn (DefaultCase sc : rest)
        = DefaultCase (trim cs nots sc) : trimAlts cs nots vn rest

    -- The scrutinee is known to be constructor c applied to args. Keep
    -- only the first alternative for c, with its argument names renamed to
    -- args. If there is no alternative for c, the DefaultCase (if any) is
    -- the branch that would be taken, so keep it. Otherwise every
    -- alternative would be dropped and the clauses below it silently lost.
    substMatch :: (Name, [Name]) -> [CaseAlt] -> [CaseAlt]
    substMatch (c, args) alts
        = case [ ConCase c t args (substNames (zip args' args) sc)
               | ConCase cn t args' sc <- alts, cn == c ] of
               (alt : _) -> [alt]
               [] -> [ alt | alt@(DefaultCase _) <- alts ]

    substNames [] sc = sc
    substNames ((n, n') : ns) sc
       = substNames ns (substSC n n' sc)

    notConMatch cs (ConCase cn t args sc) = cn `notElem` cs
    notConMatch cs _ = True

    addMatch vn cn cs = (vn, cn) : cs

    addCantBe :: Name -> Name -> [(Name, [Name])] -> [(Name, [Name])]
    addCantBe vn cn [] = [(vn, [cn])]
    addCantBe vn cn ((n, cbs) : nots)
          | vn == n = ((n, nub (cn : cbs)) : nots)
          | otherwise = ((n, cbs) : addCantBe vn cn nots)

-- | Does this error result rule out a case as valid when coverage checking?
--
-- A unification or conversion failure between two types that, once
-- normalised, are headed by the same name rules the case out. For a
-- unification failure, the case is also ruled out when the nested error
-- does so. Infinite unification always rules it out. Location and
-- elaboration wrappers are looked through, and any other error leaves the
-- case valid.
validCoverageCase :: Context -> Err -> Bool
validCoverageCase ctxt err = case err of
    CantUnify _ (topx, _) (topy, _) inner _ _ ->
        not (sameFam topx topy) && validCoverageCase ctxt inner
    InfiniteUnify _ _ _ -> False
    CantConvert topx topy _ -> not (sameFam topx topy)
    At _ inner -> validCoverageCase ctxt inner
    Elaborating _ _ _ inner -> validCoverageCase ctxt inner
    ElaboratingArg _ _ _ inner -> validCoverageCase ctxt inner
    _ -> True
  where
    -- Both types, after normalisation, are headed by the same name.
    sameFam x y
        = case (fst (unApply (normalise ctxt [] x)),
                fst (unApply (normalise ctxt [] y))) of
               (P _ a _, P _ b _) -> a == b
               _ -> False

-- | Check whether an error is recoverable in the sense needed for
-- coverage checking. Unification and conversion failures are judged by
-- 'checkRec' on the normalised types, starting from an empty state.
-- Infinite unification never is recoverable. Location and elaboration
-- wrappers are looked through, and every other error is unrecoverable.
recoverableCoverage :: Context -> Err -> Bool
recoverableCoverage ctxt err = case err of
    CantUnify _ (topx, _) (topy, _) _ _ _ -> recoverableTypes topx topy
    CantConvert topx topy _ -> recoverableTypes topx topy
    InfiniteUnify _ _ _ -> False -- always unrecoverable
    At _ inner -> recoverableCoverage ctxt inner
    Elaborating _ _ _ inner -> recoverableCoverage ctxt inner
    ElaboratingArg _ _ _ inner -> recoverableCoverage ctxt inner
    _ -> False
  where
    recoverableTypes x y
        = evalState (checkRec (normalise ctxt [] x) (normalise ctxt [] y)) []

-- | Decide whether a failure to unify two (normalised) terms is recoverable
-- for the purposes of coverage checking.
--
-- This is a different notion of recoverable than in unification, since we
-- have no metavariables here. We are just looking to see whether a
-- constructor is failing to unify with a function that may be reduced
-- later, or whether some variable would need two different constructor
-- forms.
--
-- The state maps each bound variable to the constructor-like term it has
-- already failed to unify with. When that variable is met again against
-- another constructor-like term, the two terms are compared instead.
checkRec :: Term -> Term -> State [(Name, Term)] Bool
checkRec (P Bound x _) tm
   | isCon tm = do nmap <- get
                   case lookup x nmap of
                        Nothing -> do put ((x, tm) : nmap)
                                      return True
                        Just y' -> checkRec tm y'
 where
   -- Constructor-like: a constant, or a term headed by a data or type
   -- constructor.
   isCon tm
       | (P yt _ _, _) <- unApply tm,
         conType yt = True
   isCon (Constant _) = True
   isCon _ = False

   conType (DCon _ _ _) = True
   conType (TCon _ _) = True
   conType _ = False

-- Symmetric case: a constructor-like term against a bound variable.
checkRec tm (P Bound y _)
   | isCon tm = do nmap <- get
                   case lookup y nmap of
                        Nothing -> do put ((y, tm) : nmap)
                                      return True
                        Just x' -> checkRec tm x'
 where
   isCon tm
       | (P yt _ _, _) <- unApply tm,
         conType yt = True
   isCon (Constant _) = True
   isCon _ = False

   conType (DCon _ _ _) = True
   conType (TCon _ _) = True
   conType _ = False

-- An application against a name or a constant: only the heads can be
-- compared.
checkRec (App _ f a) p@(P _ _ _) = checkRec f p
checkRec (App _ f a) p@(Constant _) = checkRec f p
checkRec p@(P _ _ _) (App _ f a) = checkRec p f
checkRec p@(Constant _) (App _ f a) = checkRec p f
checkRec fa@(App _ _ _) fa'@(App _ _ _)
    | (f, as) <- unApply fa,
      (f', as') <- unApply fa'
         = if (length as /= length as')
              then checkRec f f'
              -- Same function but different args is recoverable,
              -- and vice versa, if it's an ordinary function.
              -- If a constructor, everything has to be recoverable.
              else do fok <- checkRec f f'
                      -- Only the arguments are compared here; the heads
                      -- were compared just above. (Comparing the head with
                      -- itself added nothing for named heads, and made
                      -- argok False for any head no equation handles.)
                      argok <- checkRecs as as'
                      return (if conType f then fok && argok
                                           else fok || argok)
  where
    checkRecs [] [] = return True
    checkRecs (a : as) (b : bs) = do aok <- checkRec a b
                                     asok <- checkRecs as bs
                                     return (aok && asok)
    -- Not reached: the argument lists have equal length here.
    checkRecs _ _ = return False
    conType (P (DCon _ _ _) _ _) = True
    conType (P (TCon _ _) _ _) = True
    conType (Constant _) = True
    conType _ = False

checkRec (P xt x _) (P yt y _)
   | x == y = return True
   | ntRec xt yt = return True
 where
    -- If either name is a reference or a bound variable, then further
    -- development may fix the error, so consider it recoverable.
    -- If both names are constructors, and the name is different, then
    -- it's not recoverable
    ntRec x y | Ref <- x = True
              | Ref <- y = True
              | Bound <- x = True
              | Bound <- y = True
              | otherwise = False -- name is different, unrecoverable
-- A function reference against a constant might be recoverable if we get to
-- reduce the function
checkRec (P Ref _ _) (Constant _) = return True
checkRec (Constant _) (P Ref _ _) = return True
checkRec (TType _) (TType _) = return True
checkRec _ _ = return False