{-# LANGUAGE Safe, PatternGuards, MultiWayIf #-}
module Cryptol.TypeCheck.Solver.Numeric
  ( cryIsEqual, cryIsNotEqual, cryIsGeq
  ) where

import           Control.Applicative(Alternative(..))
import           Control.Monad (guard,mzero)
import qualified Control.Monad.Fail as Fail
import           Data.List (sortBy)

import Cryptol.Utils.Patterns
import Cryptol.TypeCheck.PP
import Cryptol.TypeCheck.Type hiding (tMul)
import Cryptol.TypeCheck.TypePat
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.Interval
import Cryptol.TypeCheck.SimpType as Simp

{- Convention for comments:

  K1, K2 ...          Concrete constants
  s1, s2, t1, t2 ...  Arbitrary type expressions
  a, b, c ...         Type variables

-}


-- | Try to solve @t1 = t2@.
-- The strategies in 'attempts' are tried from top to bottom; the first one
-- that applies decides the result, and if none applies the goal is left
-- 'Unsolved'.
cryIsEqual :: Ctxt -> Type -> Type -> Solved
cryIsEqual ctxt t1 t2 = matchDefault Unsolved attempts
  where
  attempts =
        pBin PEqual (==) t1 t2
    <|> (aNat' t1 >>= tryEqK ctxt t2)
    <|> (aNat' t2 >>= tryEqK ctxt t1)
    <|> (aTVar t1 >>= tryEqVar t2)
    <|> (aTVar t2 >>= tryEqVar t1)
    <|> (SolvedIf [] <$ guard (t1 == t2))
    <|> tryEqMin t1 t2
    <|> tryEqMin t2 t1
    <|> tryEqMins t1 t2
    <|> tryEqMins t2 t1
    <|> tryEqMulConst t1 t2
    <|> tryEqAddInf ctxt t1 t2
    <|> tryAddConst (=#=) t1 t2
    <|> tryCancelVar ctxt (=#=) t1 t2
    <|> tryLinearSolution t1 t2
    <|> tryLinearSolution t2 t1

-- | Try to solve @t1 /= t2@.  Only decided by direct evaluation of both
-- sides; the context is not consulted.
cryIsNotEqual :: Ctxt -> Type -> Type -> Solved
cryIsNotEqual _ctxt lhs rhs =
  matchDefault Unsolved $ pBin PNeq (/=) lhs rhs

-- | Try to solve @t1 >= t2@.
-- The strategies in 'attempts' are tried from top to bottom; the first one
-- that applies decides the result, and if none applies the goal is left
-- 'Unsolved'.
cryIsGeq :: Ctxt -> Type -> Type -> Solved
cryIsGeq i t1 t2 = matchDefault Unsolved attempts
  where
  attempts =
        pBin PGeq (>=) t1 t2
    <|> (aNat' t1 >>= tryGeqKThan i t2)
    <|> (aNat' t2 >>= tryGeqThanK i t1)
    <|> (aTVar t2 >>= tryGeqThanVar i t1)
    <|> tryGeqThanSub i t1 t2
    <|> geqByInterval i t1 t2
    <|> (SolvedIf [] <$ guard (t1 == t2))
    <|> tryAddConst (>==) t1 t2
    <|> tryCancelVar i (>==) t1 t2
    <|> tryMinIsGeq t1 t2
    -- XXX: k >= width e
    -- XXX: width e >= k


  -- XXX: max t 10 >= 2 --> True
  -- XXX: max t 2 >= 10 --> a >= 10

-- | Try to solve a binary comparison by evaluating both sides.
-- If either side is a numeric type error, the goal is 'Unsolvable' with that
-- error.  If both sides are 'Nat'' constants, the predicate @p@ decides the
-- goal.  @tf@ is only used to print the comparison that failed.
pBin :: PC -> (Nat' -> Nat' -> Bool) -> Type -> Type -> Match Solved
pBin tf p t1 t2 = errorIn t1 <|> errorIn t2 <|> evaluated
  where
  errorIn t = Unsolvable <$> anError KNum t

  evaluated =
    do x <- aNat' t1
       y <- aNat' t2
       pure (decide x y)

  decide x y
    | p x y     = SolvedIf []
    | otherwise =
        Unsolvable $ TCErrorMessage $
          "It is not the case that " ++
          show (pp (TCon (PC tf) [ tNat' x, tNat' y ]))


--------------------------------------------------------------------------------
-- GEQ

-- | Try to solve @K >= t@.
-- An infinite @K@ bounds everything.  For a finite @K@, only goals of the
-- form @K1 >= K2 * t@ are handled, by bounding @t@ using @K1@ and @K2@.
tryGeqKThan :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqKThan _ _ Inf = return (SolvedIf [])
tryGeqKThan _ ty (Nat n) =
  do (factor, rest) <- aMul ty
     k              <- aNat' factor
     return (SolvedIf (boundFor k rest))
  where
  -- inf * t <= n needs t = 0;  0 * t <= n always holds;
  -- otherwise k * t <= n  iff  t <= n `div` k
  boundFor Inf     rest = [ rest =#= tZero ]
  boundFor (Nat 0) _    = []
  boundFor (Nat k) rest = [ tNum (div n k) >== rest ]

-- | Try to solve @t >= K@.
-- @t >= inf@ forces @t = inf@.  For a finite @K@, only goals of the form
-- @K1 + t >= K2@ are handled.
tryGeqThanK :: Ctxt -> Type -> Nat' -> Match Solved
tryGeqThanK _ t Inf = return (SolvedIf [ t =#= tInf ])
tryGeqThanK _ t (Nat k) =
  do (offset, rest) <- anAdd t
     n              <- aNat offset
     let remaining
           | n >= k    = []                          -- K1 alone reaches K2
           | otherwise = [ rest >== tNum (k - n) ]   -- t must cover the gap
     return (SolvedIf remaining)
  -- XXX: K1 ^^ n >= K2


-- | Try to solve @t1 >= t1 - t2@, which always holds.
tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved
tryGeqThanSub _ x y =
  do (minuend, _subtrahend) <- (|-|) y
     guard (x == minuend)
     pure (SolvedIf [])

-- | Try to solve @(t + a) >= a@ (or @(a + t) >= a@), which always holds.
tryGeqThanVar :: Ctxt -> Type -> TVar -> Match Solved
tryGeqThanVar _ctxt ty x =
  do (lhs, rhs) <- anAdd ty
     isVar lhs <|> isVar rhs
  where
  -- succeeds when this summand is exactly the variable @x@
  isVar side = do x' <- aTVar side
                  guard (x == x')
                  pure (SolvedIf [])

-- | Try to prove GEQ by considering the known intervals for the given types.
-- Succeeds when the lower bound of @x@ is at least the (known) upper bound
-- of @y@; fails if @y@ has no upper bound or the bounds do not settle it.
geqByInterval :: Ctxt -> Type -> Type -> Match Solved
geqByInterval ctxt x y =
  case iUpper (intervalOf y) of
    Just hi | iLower (intervalOf x) >= hi -> return (SolvedIf [])
    _                                     -> mzero
  where
  intervalOf = typeInterval (intervals ctxt)

-- min K1 t >= K2 ~~> t >= K2, if K1 >= K2;  Err otherwise
-- (min K1 t can never exceed K1)
tryMinIsGeq :: Type -> Type -> Match Solved
tryMinIsGeq t1 t2 =
  do (bound, rest) <- aMin t1
     k1            <- aNat bound
     k2            <- aNat t2
     pure $ if k1 < k2
              then Unsolvable $ TCErrorMessage $
                     show k1 ++ " can't be greater than " ++ show k2
              else SolvedIf [ rest >== t2 ]

--------------------------------------------------------------------------------

-- | Cancel finite positive variables from both sides.
-- @(fin a, a >= 1) =>  a * t1 == a * t2 ~~~> t1 == t2@
-- @(fin a, a >= 1) =>  a * t1 >= a * t2 ~~~> t1 >= t2@
--
-- Both sides are flattened into their multiplicative factors.  A factor is
-- a cancellation candidate when it is a type variable whose interval in the
-- context satisfies 'iIsPosFin'.  The smallest candidate variable (by
-- 'TVar''s 'Ord') that occurs on both sides is removed from each side, and
-- the relation @p@ is re-stated on the remaining factors; an empty product
-- becomes @1@.  At most one variable is cancelled per call.  Fails (via
-- 'Fail.fail') when the two sides share no candidate variable.
tryCancelVar :: Ctxt -> (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryCancelVar ctxt p t1 t2 =
  let lhs = preproc t1
      rhs = preproc t2
  in case check [] [] lhs rhs of
       Nothing -> Fail.fail "tryCancelVar"
       Just x  -> return x


  where
  -- Merge-style walk (in the Maybe monad) over the two sorted factor lists,
  -- looking for a candidate variable present on both sides.  'doneLHS' and
  -- 'doneRHS' collect the factors skipped so far; they stay in the result.
  -- Candidates sort first, so a non-candidate head ('Nothing') on either
  -- side means nothing is left to match, and the walk gives up.
  check doneLHS doneRHS lhs@((a,mbA) : moreLHS) rhs@((b, mbB) : moreRHS) =
    do x <- mbA
       y <- mbB
       case compare x y of
         LT -> check (a : doneLHS) doneRHS moreLHS rhs
         EQ -> return $ SolvedIf [ p (term (doneLHS ++ map fst moreLHS))
                                     (term (doneRHS ++ map fst moreRHS)) ]
         GT -> check doneLHS (b : doneRHS) lhs moreRHS
  check _ _ _ _ = Nothing

  -- Rebuild a product from a list of factors; the empty product is 1.
  term xs = case xs of
              [] -> tNum (1::Int)
              _  -> foldr1 tMul xs

  -- Split a type into its factors, tag each with its candidate variable
  -- (if any), and put the candidates first.
  preproc t = let fs = splitMul t []
              in sortBy cmpFact (zip fs (map cancelVar fs))

  -- Flatten nested multiplications into a list of factors, left to right.
  splitMul t rest = case matchMaybe (aMul t) of
                      Just (a,b) -> splitMul a (splitMul b rest)
                      Nothing    -> t : rest

  -- A factor is a candidate if it is a type variable whose interval in the
  -- context passes 'iIsPosFin' (the "fin a, a >= 1" side condition above).
  cancelVar t = matchMaybe $ do x <- aTVar t
                                guard (iIsPosFin (tvarInterval (intervals ctxt) x))
                                return x

  -- cancellable variables go first, ordered by 'TVar''s 'Ord' instance (the
  -- merge in 'check' relies on this order); all other factors compare EQ,
  -- so 'sortBy' (a stable sort) keeps their relative order
  cmpFact (_,mbA) (_,mbB) =
    case (mbA,mbB) of
      (Just x, Just y)  -> compare x y
      (Just _, Nothing) -> LT
      (Nothing, Just _) -> GT
      _                 -> EQ



-- min t1 t2 = t1 ~> t1 <= t2   (and symmetrically when the min equals t2)
tryEqMin :: Type -> Type -> Match Solved
tryEqMin x y =
  do (a, b) <- aMin x
     keep a b <|> keep b a
  where
  -- if the 'chosen' argument of the min is the other side of the equation,
  -- then the min selects it, so it is bounded above by 'other'
  keep chosen other =
    do guard (chosen == y)
       pure (SolvedIf [ other >== chosen ])


-- t1 == min (K + t1) t2 ~~> t1 == t2, if K >= 1
-- (also if (K + t1) is one term in a multi-way min)
-- Every argument of the (possibly nested) min that has the form K + t1 with
-- K > 0 is dropped.  If nothing is dropped, the goal stays 'Unsolved'.  If
-- everything is dropped, the min is replaced by inf.
tryEqMins :: Type -> Type -> Match Solved
tryEqMins x y =
  do (a, b) <- aMin y
     let args = splitMin a ++ splitMin b
         kept = [ t | t <- args, not (isGt t) ]
         newMin
           | null kept = tInf
           | otherwise = foldr1 Simp.tMin kept
     return $ if any isGt args
                then SolvedIf [x =#= newMin]
                else Unsolved
  where
    -- flatten nested mins into the list of their arguments
    splitMin :: Type -> [Type]
    splitMin ty =
      maybe [ty] (\(l, r) -> splitMin l ++ splitMin r) (matchMaybe (aMin ty))

    -- is this argument K + x for some constant K > 0?
    isGt :: Type -> Bool
    isGt t =
      case matchMaybe (asAddK t) of
        Nothing      -> False
        Just (k, t') -> k > 0 && t' == x

    asAddK :: Type -> Match (Integer, Type)
    asAddK t = matches t (anAdd, aNat, __)


-- | Try to solve @t = a@ for the type variable @a@ (the argument @x@), for a
-- few shapes of @t@ in which @a@ occurs again.  Alternatives are tried in
-- order; the first that matches wins.
tryEqVar :: Type -> TVar -> Match Solved
tryEqVar ty x = selfPlusK <|> minWithSelf <|> kPlusMin
  where
  -- a = K + a --> a = inf   (K >= 1)
  selfPlusK =
    do (k, tv) <- matches ty (anAdd, aNat, aTVar)
       guard (tv == x && k >= 1)
       return (SolvedIf [ TVar x =#= tInf ])

  -- a = min (K + a) t --> a = t   (K >= 1, possibly inf)
  minWithSelf =
    do (l, r) <- aMin ty
       selfArg l r <|> selfArg r l

  selfArg this other =
    do (k, x') <- matches this (anAdd, aNat', aTVar)
       guard (x == x' && k >= Nat 1)
       return (SolvedIf [ TVar x =#= other ])

  -- a = K + min t a --> a = K + t   (K >= 1)
  kPlusMin =
    do (k, (l, r)) <- matches ty (anAdd, aNat, aMin)
       guard (k >= 1)
       varArg k l r <|> varArg k r l

  varArg k v other =
    do x' <- aTVar v
       guard (x' == x)
       return (SolvedIf [ TVar x =#= tAdd (tNum k) other ])







-- | Try to solve @K = t@ (e.g., @10 = t@), where the constant side has
-- already been evaluated to @lk@.  Each alternative below handles one
-- syntactic shape of @t@; the first one that matches decides the result.
tryEqK :: Ctxt -> Type -> Nat' -> Match Solved
tryEqK ctxt ty lk =

  -- (t1 + t2 = inf, fin t1) ~~~> t2 = inf
  do guard (lk == Inf)
     (a,b) <- anAdd ty
     let check x y = do guard (iIsFin (typeInterval (intervals ctxt) x))
                        return $ SolvedIf [ y =#= tInf ]
     check a b <|> check b a
  <|>

  -- (K1 + t = K2, K2 >= K1) ~~~> t = (K2 - K1)
  do (rk, b) <- matches ty (anAdd, aNat', __)
     return $
       case nSub lk rk of
         -- NOTE: (Inf - Inf) shouldn't be possible
         Nothing -> Unsolvable
                      $ TCErrorMessage
                      $ "Adding " ++ showNat' rk ++ " will always exceed "
                                  ++ showNat' lk

         Just r -> SolvedIf [ b =#= tNat' r ]
  <|>

  -- (lk = t - rk) ~~> t = lk + rk
  do (t,rk) <- matches ty ((|-|) , __, aNat')
     return (SolvedIf [ t =#= tNat' (nAdd lk rk) ])

  <|>
  do (rk, b) <- matches ty (aMul, aNat', __)
     return $
       case (lk,rk) of
         -- Inf * t = Inf ~~~>  t >= 1
         (Inf,Inf)    -> SolvedIf [ b >== tOne ]

         -- 0 * t = Inf ~~~> ERR      (0 * t is 0 for every t, even inf)
         -- Like `0 * t = K2` below, this should normally have been
         -- simplified away, but it must not be answered with `t = Inf`.
         (Inf,Nat 0)  -> Unsolvable
                       $ TCErrorMessage
                       $ "inf != 0 * anything"

         -- K * t = Inf ~~~> t = Inf      (K /= 0)
         (Inf,Nat _)  -> SolvedIf [ b =#= tInf ]

         -- Inf * t = 0 ~~~> t = 0
         (Nat 0, Inf) -> SolvedIf [ b =#= tZero ]

         -- Inf * t = K ~~~> ERR      (K /= 0)
         (Nat k, Inf) -> Unsolvable
                       $ TCErrorMessage
                       $ show k ++ " != inf * anything"

         (Nat lk', Nat rk')
           -- 0 * t = K2 ~~> K2 = 0
           | rk' == 0 -> SolvedIf [ tNat' lk =#= tZero ]
              -- shouldn't happen, as `0 * t` should have been simplified to `0`

           -- K1 * t = K2 ~~> t = K2/K1
           | (q,0) <- divMod lk' rk' -> SolvedIf [ b =#= tNum q ]
           | otherwise ->
               Unsolvable
             $ TCErrorMessage
             $ showNat' lk ++ " != " ++ showNat' rk ++ " * anything"

  <|>
  -- K1 == K2 ^^ t    ~~> t = logBase K2 K1
  do (rk, b) <- matches ty ((|^|), aNat, __)
     return $ case lk of
                Inf | rk > 1 -> SolvedIf [ b =#= tInf ]
                -- Degenerate bases are handled before 'genLog':
                -- 0 ^^ 0 = 1 and 0 ^^ t = 0 for t >= 1, while 1 ^^ t = 1
                -- for all t.  So `0 = 0 ^^ t` and `1 = 1 ^^ t` have many
                -- solutions, which a single exact logarithm cannot express.
                Nat 0 | rk == 0 -> SolvedIf [ b >== tOne ]
                Nat 1 | rk == 0 -> SolvedIf [ b =#= tZero ]
                      | rk == 1 -> SolvedIf []
                Nat n | Just (a,True) <- genLog n rk -> SolvedIf [ b =#= tNum a]
                _ -> Unsolvable $ TCErrorMessage
                       $ show rk ++ " ^^ anything != " ++ showNat' lk

  -- XXX: Min, Max, etx
  -- 2  = min (10,y)  --> y = 2
  -- 2  = min (2,y)   --> y >= 2
  -- 10 = min (2,y)   --> impossible


-- | @K1 * t1 + K2 * t2 + ... = K3 * t3 + K4 * t4 + ...@
-- If the constant terms and all coefficients on both sides share a common
-- divisor @d > 1@, replace the equation with the one where @d@ is divided out.
tryEqMulConst :: Type -> Type -> Match Solved
tryEqMulConst l r =
  do lhs@(lc, ls) <- matchLinear l
     rhs@(rc, rs) <- matchLinear r
     let d = foldr1 gcd (lc : rc : map fst (ls ++ rs))
     guard (d > 1)
     return (SolvedIf [ rebuild d lhs =#= rebuild d rhs ])
  where
  -- reassemble @k + c1 * t1 + ...@ with every constant divided by @d@
  rebuild d (k, ts) = foldr (tAdd . scaled d) (divided d k) ts
  scaled d (k, t)   = tMul (divided d k) t
  divided d x       = tNum (x `div` d)


-- | @(t1 + t2 = Inf, fin t1)  ~~> t2 = Inf@, with the sum on either side.
tryEqAddInf :: Ctxt -> Type -> Type -> Match Solved
tryEqAddInf ctxt l r = check l r <|> check r l
  where
  isFinite t = iIsFin (typeInterval (intervals ctxt) t)

  -- check for x = a + b /\ x = inf
  check x y =
    do (x1, x2) <- anAdd x
       aInf y
       return $! solveFor x1 x2 y

  -- if one summand is known to be finite, the other must equal inf;
  -- x1 is checked first
  solveFor x1 x2 y
    | isFinite x1 = SolvedIf [ x2 =#= y ]
    | isFinite x2 = SolvedIf [ x1 =#= y ]
    | otherwise   = Unsolved



-- | Check for addition of constants to both sides of a relation.
--  @((K1 + K2) + t1) `R` (K1 + t2)  ~~>   (K2 + t1) `R` t2@
--
-- This relies on the fact that constants are floated left during
-- simplification.  The smaller constant is subtracted from both sides, so
-- the remaining constant ends up on the side that had the larger one.
tryAddConst :: (Type -> Type -> Prop) -> Type -> Type -> Match Solved
tryAddConst rel l r =
  do (k1, x2) <- matches l (anAdd, aNat, __)
     (k2, y2) <- matches r (anAdd, aNat, __)
     return $ SolvedIf
       [ if k1 > k2
           then tAdd (tNum (k1 - k2)) x2 `rel` y2
           else x2 `rel` tAdd (tNum (k2 - k1)) y2 ]


-- | Check for situations where a unification variable is involved in
--   a sum of terms not containing additional unification variables,
--   and replace it with a solution and an inequality.
--   @s1 = ?a + s2 ~~> (?a = s1 - s2, s1 >= s2)@
tryLinearSolution :: Type -> Type -> Match Solved
tryLinearSolution s1 t =
  do (a, xs) <- matchLinearUnifier t
     guard (noFreeVariables s1)
     -- matchLinearUnifier only succeeds with a nonempty xs, so foldr1 is safe
     pure (solveFor a (foldr1 Simp.tAdd xs))
  where
  -- ?a + s2 = s1 with ?a >= 0 implies s1 >= s2, so that is required too
  solveFor a s2 = SolvedIf [ TVar a =#= Simp.tSub s1 s2, s1 >== s2 ]


-- | Match a sum of the form @(s1 + ... + ?a + ... sn)@ where
--   @s1@ through @sn@ do not contain any free variables.
--
--   Note: a successful match should only occur if @s1 ... sn@ is
--   not empty.
matchLinearUnifier :: Pat Type (TVar,[Type])
matchLinearUnifier = go []
 where
  -- 'seen' holds the variable-free summands passed so far, most recent first
  go seen t = atEnd <|> (anAdd t >>= atAdd)
    where
    -- the free variable is the last summand; it needs at least one other term
    atEnd =
      do v <- aFreeTVar t
         guard (not (null seen))
         return (v, seen)

    -- the current term is an addition @x + y@
    atAdd (x, y) = inMiddle <|> keepWalking
      where
      -- the free variable is @x@ and everything to its right is closed
      inMiddle =
        do v <- aFreeTVar x
           guard (noFreeVariables y)
           return (v, reverse (y : seen))
      -- @x@ is closed: remember it and continue into @y@
      keepWalking =
        do guard (noFreeVariables x)
           go (x : seen) y


-- | Is this a sum of products, where the products have constant coefficients?
-- On success, returns the sum of all constant terms and the list of
-- @(coefficient, term)@ pairs.
matchLinear :: Pat Type (Integer, [(Integer,Type)])
matchLinear = go (0, [])
  where
  go acc@(c, ts) t = constant <|> scaled <|> summed
    where
    constant = do n <- aNat t
                  return (n + c, ts)
    scaled   = do (x, y) <- aMul t
                  n      <- aNat x
                  return (c, (n, y) : ts)
    summed   = do (l, r) <- anAdd t
                  acc'   <- go acc l
                  go acc' r



-- | Render a 'Nat'' for error messages, spelling infinity as @inf@.
showNat' :: Nat' -> String
showNat' n =
  case n of
    Nat k -> show k
    Inf   -> "inf"