-- |
-- Module      :  Cryptol.TypeCheck.Solver.Numeric.Interval
-- Copyright   :  (c) 2015-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable
--
-- An interval interpretation of types.

{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE BangPatterns #-}

module Cryptol.TypeCheck.Solver.Numeric.Interval where

import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.Utils.PP hiding (int)

import           Data.Map ( Map )
import qualified Data.Map as Map
import           Data.Maybe (catMaybes)


-- | Only meaningful for numeric types
typeInterval :: Map TVar Interval -> Type -> Interval
typeInterval varInfo = go
  where
  go ty =
    case ty of
      TUser _ _ t -> go t
      TCon tc ts ->
        case (tc, ts) of
          (TC TCInf, [])      -> iConst Inf
          (TC (TCNum n), [])  -> iConst (Nat n)
          (TF TCAdd, [x,y])   -> iAdd (go x) (go y)
          (TF TCSub, [x,y])   -> iSub (go x) (go y)
          (TF TCMul, [x,y])   -> iMul (go x) (go y)
          (TF TCDiv, [x,y])   -> iDiv (go x) (go y)
          (TF TCMod, [x,y])   -> iMod (go x) (go y)
          (TF TCExp, [x,y])   -> iExp (go x) (go y)
          (TF TCWidth, [x])   -> iWidth (go x)
          (TF TCMin, [x,y])   -> iMin (go x) (go y)
          (TF TCMax, [x,y])   -> iMax (go x) (go y)

          (TF TCCeilDiv, [x,y]) -> iCeilDiv (go x) (go y)
          (TF TCCeilMod, [x,y]) -> iCeilMod (go x) (go y)

          (TF TCLenFromThenTo, [x,y,z]) ->
            iLenFromThenTo (go x) (go y) (go z)
          _ -> iAny

      TVar x -> tvarInterval varInfo x

      _ -> iAny

tvarInterval :: Map TVar Interval -> TVar -> Interval
tvarInterval varInfo x = Map.findWithDefault iAny x varInfo


data IntervalUpdate = NoChange
                    | InvalidInterval TVar
                    | NewIntervals (Map TVar Interval)
                      deriving (Show)

updateInterval :: (TVar,Interval) -> Map TVar Interval -> IntervalUpdate
updateInterval (x,int) varInts =
  case Map.lookup x varInts of
    Just int' ->
      case iIntersect int int' of
        Just val | int' /= val -> NewIntervals (Map.insert x val varInts)
                 | otherwise   -> NoChange
        Nothing                -> InvalidInterval x

    Nothing   -> NewIntervals (Map.insert x int varInts)


computePropIntervals :: Map TVar Interval -> [Prop] -> IntervalUpdate
computePropIntervals ints ps0 = go (3 :: Int) False ints ps0
  where
  go !_n False _ [] = NoChange

  go !n True  is []
    | n > 0     = changed is (go (n-1) False is ps0)
    | otherwise = NewIntervals is

  go !n new   is (p:ps) =
    case add False (propInterval is p) is of
      InvalidInterval i -> InvalidInterval i
      NewIntervals is'  -> go n True is' ps
      NoChange          -> go n new  is  ps

  add ch [] int = if ch then NewIntervals int else NoChange
  add ch (i:is) int = case updateInterval i int of
                        InvalidInterval j -> InvalidInterval j
                        NoChange          -> add ch is int
                        NewIntervals is'  -> add True is is'

  changed a x = case x of
                  NoChange -> NewIntervals a
                  r        -> r


-- | What we learn about variables from a single prop.
propInterval :: Map TVar Interval -> Prop -> [(TVar,Interval)]
propInterval varInts prop = catMaybes
  [ do ty <- pIsFin prop
       x  <- tIsVar ty
       return (x,iAnyFin)

  , do (l,r) <- pIsEqual prop
       x     <- tIsVar l
       return (x,typeInterval varInts r)

  , do (l,r) <- pIsEqual prop
       x     <- tIsVar r
       return (x,typeInterval varInts l)

  , do (l,r) <- pIsGeq prop
       x     <- tIsVar l
       let int = typeInterval varInts r
       return (x,int { iUpper = Just Inf })

  , do (l,r) <- pIsGeq prop
       x     <- tIsVar r
       let int = typeInterval varInts l
       return (x,int { iLower = Nat 0 })

    -- k >= width x
  , do (l,r) <- pIsGeq prop
       x     <- tIsVar =<< pIsWidth r
           -- record the exact upper bound when it produces values within 128
           -- bits
       let ub = case iIsExact (typeInterval varInts l) of
                  Just (Nat val) | val < 128 -> Just (Nat (2 ^ val - 1))
                                 | otherwise -> Nothing
                  upper                      -> upper

       return (x, Interval { iLower = Nat 0, iUpper = ub })

    , do (e,_) <- pIsValidFloat prop
         x <- tIsVar e
         pure (x, iAnyFin)

    , do (_,p) <- pIsValidFloat prop
         x <- tIsVar p
         pure (x, iAnyFin)

  ]


--------------------------------------------------------------------------------

data Interval = Interval
  { iLower :: Nat'          -- ^ lower bound (inclusive)
  , iUpper :: Maybe Nat'    -- ^ upper bound (inclusive)
                            -- If there is no upper bound,
                            -- then all *natural* numbers.
  } deriving (Eq,Show)

ppIntervals :: Map TVar Interval -> Doc
ppIntervals  = vcat . map ppr . Map.toList
  where
  ppr (var,i) = pp var <.> char ':' <+> ppInterval i

ppInterval :: Interval -> Doc
ppInterval x = brackets (hsep [ ppr (iLower x)
                              , text ".."
                              , maybe (text "fin") ppr (iUpper x)])
  where
  ppr a = case a of
           Nat n -> integer n
           Inf   -> text "inf"


iIsExact :: Interval -> Maybe Nat'
iIsExact i = if iUpper i == Just (iLower i) then Just (iLower i) else Nothing

iIsFin :: Interval -> Bool
iIsFin i = case iUpper i of
             Just Inf -> False
             _        -> True

-- | Finite positive number. @[1 ..  inf)@.
iIsPosFin :: Interval -> Bool
iIsPosFin i = iLower i >= Nat 1 && iIsFin i


-- | Returns 'True' when the intervals definitely overlap, and 'False'
-- otherwise.
iOverlap :: Interval -> Interval -> Bool
iOverlap
  (Interval (Nat l1) (Just (Nat h1)))
  (Interval (Nat l2) (Just (Nat h2))) =
    or [ h1 > l2 && h1 < h2, l1 > l2 && l1 < h2 ]
iOverlap _ _ = False


-- | Intersect two intervals, yielding a new one that describes the space where
-- they overlap.  If the two intervals are disjoint, the result will be
-- 'Nothing'.
iIntersect :: Interval -> Interval -> Maybe Interval
iIntersect i j =
  case (lower,upper) of
    (Nat l, Just (Nat u)) | l <= u -> ok
    (Nat _, Just  Inf)             -> ok
    (Nat _, Nothing)               -> ok
    (Inf,   Just Inf)              -> ok
    _                              -> Nothing
  where

  ok    = Just (Interval lower upper)

  lower = nMax (iLower i) (iLower j)

  upper = case (iUpper i, iUpper j) of
            (Just a, Just b)            -> Just (nMin a b)
            (Nothing,Nothing)           -> Nothing
            (Just l,Nothing) | l /= Inf -> Just l
            (Nothing,Just r) | r /= Inf -> Just r
            _                           -> Nothing


-- | Any value
iAny :: Interval
iAny = Interval (Nat 0) (Just Inf)

-- | Any finite value
iAnyFin :: Interval
iAnyFin = Interval (Nat 0) Nothing

-- | Exactly this value
iConst :: Nat' -> Interval
iConst x = Interval x (Just x)




iAdd :: Interval -> Interval -> Interval
iAdd i j = Interval { iLower = nAdd (iLower i) (iLower j)
                    , iUpper = case (iUpper i, iUpper j) of
                                 (Nothing, Nothing) -> Nothing
                                 (Just x, Just y)   -> Just (nAdd x y)
                                 (Nothing, Just y)  -> upper y
                                 (Just x, Nothing)  -> upper x
                    }
  where
  upper x = case x of
              Inf -> Just Inf
              _   -> Nothing

iMul :: Interval -> Interval -> Interval
iMul i j = Interval { iLower = nMul (iLower i) (iLower j)
                    , iUpper = case (iUpper i, iUpper j) of
                                 (Nothing, Nothing) -> Nothing
                                 (Just x, Just y)   -> Just (nMul x y)
                                 (Nothing, Just y)  -> upper y
                                 (Just x, Nothing)  -> upper x
                    }
  where
  upper x = case x of
              Inf   -> Just Inf
              Nat 0 -> Just (Nat 0)
              _     -> Nothing

iExp :: Interval -> Interval -> Interval
iExp i j = Interval { iLower = nExp (iLower i) (iLower j)
                    , iUpper = case (iUpper i, iUpper j) of
                                 (Nothing, Nothing) -> Nothing
                                 (Just x, Just y)   -> Just (nExp x y)
                                 (Nothing, Just y)  -> upperR y
                                 (Just x, Nothing)  -> upperL x
                    }
  where
  upperL x = case x of
               Inf   -> Just Inf
               Nat 0 -> Just (Nat 0)
               Nat 1 -> Just (Nat 1)
               _     -> Nothing

  upperR x = case x of
               Inf   -> Just Inf
               Nat 0 -> Just (Nat 1)
               _     -> Nothing

iMin :: Interval -> Interval -> Interval
iMin i j = Interval { iLower = nMin (iLower i) (iLower j)
                    , iUpper = case (iUpper i, iUpper j) of
                                 (Nothing, Nothing)   -> Nothing
                                 (Just x, Just y)     -> Just (nMin x y)
                                 (Nothing, Just Inf)  -> Nothing
                                 (Nothing, Just y)    -> Just y
                                 (Just Inf, Nothing)  -> Nothing
                                 (Just x, Nothing)    -> Just x
                    }

iMax :: Interval -> Interval -> Interval
iMax i j = Interval { iLower = nMax (iLower i) (iLower j)
                    , iUpper = case (iUpper i, iUpper j) of
                                 (Nothing, Nothing)   -> Nothing
                                 (Just x, Just y)     -> Just (nMax x y)
                                 (Nothing, Just Inf)  -> Just Inf
                                 (Nothing, Just _)    -> Nothing
                                 (Just Inf, Nothing)  -> Just Inf
                                 (Just _, Nothing)    -> Nothing
                    }

iSub :: Interval -> Interval -> Interval
iSub i j = Interval { iLower = lower, iUpper = upper }
  where
  lower = case iUpper j of
            Nothing -> Nat 0
            Just x  -> case nSub (iLower i) x of
                         Nothing -> Nat 0
                         Just y  -> y


  upper = case iUpper i of
            Nothing -> Nothing
            Just x  -> case nSub x (iLower j) of
                         Nothing -> Just Inf {- malformed subtraction -}
                         Just y  -> Just y


iDiv :: Interval -> Interval -> Interval
iDiv i j = Interval { iLower = lower, iUpper = upper }
  where
  lower = case iUpper j of
            Nothing -> Nat 0
            Just x  -> case nDiv (iLower i) x of
                         Nothing -> Nat 0   -- malformed division
                         Just y  -> y

  upper = case iUpper i of
            Nothing -> Nothing
            Just x  -> case nDiv x (nMax (iLower i) (Nat 1)) of
                         Nothing -> Just Inf
                         Just y  -> Just y


iMod :: Interval -> Interval -> Interval
iMod _ j = Interval { iLower = Nat 0, iUpper = upper }
  where
  upper = case iUpper j of
            Just (Nat n) | n > 0 -> Just (Nat (n - 1))
            _                    -> Nothing


iCeilDiv :: Interval -> Interval -> Interval
iCeilDiv i j = Interval { iLower = lower, iUpper = upper }
  where
  lower = case iUpper j of
            Nothing -> if iLower i == Nat 0 then Nat 0 else Nat 1
            Just x  -> case nCeilDiv (iLower i) x of
                         Nothing -> Nat 0   -- malformed division
                         Just y  -> y

  upper = case iUpper i of
            Nothing -> Nothing
            Just x  -> case nCeilDiv x (nMax (iLower i) (Nat 1)) of
                         Nothing -> Just Inf
                         Just y  -> Just y


iCeilMod :: Interval -> Interval -> Interval
iCeilMod = iMod -- bounds are the same as for Mod

iWidth :: Interval -> Interval
iWidth i = Interval { iLower = nWidth (iLower i)
                    , iUpper = case iUpper i of
                                 Nothing -> Nothing
                                 Just n  -> Just (nWidth n)
                    }

iLenFromThenTo :: Interval -> Interval -> Interval -> Interval
iLenFromThenTo i j k
  | Just x <- iIsExact i, Just y <- iIsExact j, Just z <- iIsExact k
  , Just r <- nLenFromThenTo x y z = iConst r
  | otherwise = iAnyFin