{-|
Module      : What4.Protocol.PolyRoot
Description : Representation for algebraic reals
Copyright   : (c) Galois Inc, 2016-2020
License     : BSD3
Maintainer  : jhendrix@galois.com

Defines a numeric data-type where each number is represented as the root of a
polynomial over a single variable.

This currently only defines operations for parsing roots from the format
generated by Yices, and for approximating such a root (of a polynomial with
rational coefficients) by the rational value of a nearby double.
-}

{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
module What4.Protocol.PolyRoot
  ( Root
  , approximate
  , fromYicesText
  , parseYicesRoot
  ) where

import           Control.Applicative
import           Control.Lens
import qualified Data.Attoparsec.Text as Atto
import qualified Data.Map as Map
import           Data.Ratio
import           Data.Text (Text)
import qualified Data.Text as Text

import qualified Data.Vector as V
import           Text.PrettyPrint.ANSI.Leijen as PP hiding ((<$>))

-- | Run a parser enclosed in angle brackets, returning the inner result.
atto_angle :: Atto.Parser a -> Atto.Parser a
atto_angle p = do
  _ <- Atto.char '<'
  result <- p
  _ <- Atto.char '>'
  pure result

-- | Run a parser enclosed in parentheses, returning the inner result.
atto_paren :: Atto.Parser a -> Atto.Parser a
atto_paren p = (\_ inner _ -> inner) <$> Atto.char '(' <*> p <*> Atto.char ')'

-- | A polynomial in a single variable @x@.
--
-- The coefficient of @x^i@ is stored at index @i@ of the vector, so index 0
-- holds the constant term.
newtype SingPoly coef
  = SingPoly (V.Vector coef)
  deriving (Functor, Foldable, Traversable, Show)

-- | Renders the non-zero terms from the highest power down, separated by
-- @+@, with negative coefficients parenthesised (e.g. @2*x^2 + (-3)@).
-- A polynomial with no non-zero coefficient renders as @0@.
instance (Ord coef, Num coef, Pretty coef) => Pretty (SingPoly coef) where
  pretty (SingPoly coeffs)
    | null terms = text "0"
    | otherwise  = foldr1 joinTerms terms
    where
      -- Documents for the non-zero terms, highest power first.
      terms =
        [ term i c
        | i <- reverse [0 .. V.length coeffs - 1]
        , let c = coeffs V.! i
        , c /= 0
        ]

      -- The constant term has no variable; x^1 is written as plain "*x".
      term 0 c = coefDoc c
      term 1 c = coefDoc c <> text "*x"
      term i c = coefDoc c <> text "*x^" <> pretty i

      coefDoc c
        | c < 0     = parens (pretty c)
        | otherwise = pretty c

      joinTerms a b = a <+> text "+" <+> b

-- | Build a polynomial from its coefficients, listed from the constant term
-- upwards (the i-th element is the coefficient of @x^i@).
fromList :: [c] -> SingPoly c
fromList coeffs = SingPoly (V.fromList coeffs)

-- | Create a polynomial from a map from powers to coefficients.
--
-- Zero coefficients are discarded first, so the coefficient vector ends at
-- the highest power with a non-zero coefficient. If no non-zero coefficient
-- remains (including when the map is empty), the result is the zero
-- polynomial, represented by an empty coefficient vector. Previously this
-- case crashed in 'Map.findMax'.
--
-- Powers are assumed to be non-negative; the parsers in this module only
-- produce non-negative powers.
fromMap :: (Eq c, Num c) => Map.Map Int c -> SingPoly c
fromMap m0 =
  case Map.lookupMax nonZero of
    Nothing     -> SingPoly V.empty
    Just (n, _) -> SingPoly (V.generate (n + 1) coefAt)
  where
    nonZero  = Map.filter (/= 0) m0
    -- Powers missing from the map have coefficient zero.
    coefAt i = Map.findWithDefault 0 i nonZero

-- | Parse a monomial with a non-negative coefficient, written @c@, @c*x@ or
-- @c*x^n@. Returns the coefficient and the power of @x@. The power is 0
-- without a @*x@ suffix, and 1 when @*x@ has no explicit exponent.
pos_mono :: Integral c => Atto.Parser (c, Int)
pos_mono = do
  coef  <- Atto.decimal
  power <- Atto.option 0 variablePart
  pure (coef, power)
  where
    -- "*x", optionally followed by "^n".
    variablePart :: Atto.Parser Int
    variablePart = Atto.char '*' *> Atto.char 'x' *> Atto.option 1 explicitPower

    explicitPower :: Atto.Parser Int
    explicitPower = Atto.char '^' *> Atto.decimal


-- | Parse a monomial and return its coefficient and power.
--
-- A negative monomial is written in parentheses with a leading minus sign,
-- e.g. @(-3*x^2)@. Otherwise the monomial is parsed by 'pos_mono'.
mono :: Integral c => Atto.Parser (c, Int)
mono = atto_paren (Atto.char '-' *> fmap negateCoef pos_mono) <|> pos_mono
  where
    negateCoef ~(coef, power) = (negate coef, power)

-- | Parse one or more monomials separated by @" + "@, e.g. @1*x^2 + (-2)@.
-- Coefficients that share a power are summed.
parseYicesPoly :: Integral c => Atto.Parser (SingPoly c)
parseYicesPoly = toPoly <$> (mono `Atto.sepBy1` plusSep)
  where
    plusSep = Atto.char ' ' *> Atto.char '+' *> Atto.char ' '

    toPoly terms =
      fromMap (Map.fromListWith (+) [ (power, coef) | (coef, power) <- terms ])


-- | Evaluate a polynomial at a specific value.
--
-- Terms are accumulated from the constant term upwards. Each power of the
-- argument is computed by repeated multiplication. Due to rounding, the
-- result may not be exact when using finite precision arithmetic.
eval :: forall c . Num c => SingPoly c -> c -> c
eval (SingPoly coeffs) x = finish (V.foldl' step (1, 0) coeffs)
  where
    -- Accumulator: (current power of x, running sum), both forced each step.
    step :: (c, c) -> c -> (c, c)
    step (power, total) k =
      let power' = power * x
          total' = total + power * k
      in power' `seq` total' `seq` (power', total')

    finish :: (c, c) -> c
    finish (_, total) = total

-- | A number represented as a root of a single-variable polynomial, together
-- with bounds expected to bracket that root.
data Root c
  = Root
    { rootPoly   :: !(SingPoly c) -- ^ Polynomial the number is a root of.
    , rootLbound :: !c            -- ^ Lower bound on the root.
    , rootUbound :: !c            -- ^ Upper bound on the root.
    }
  deriving (Show)

-- | Construct a root from a rational constant @r@. The polynomial is
-- @x - r@ and both bounds are @r@ itself.
rootFromRational :: Num c => c -> Root c
rootFromRational r = Root (fromList [negate r, 1]) r r

-- Renders a root as @\<poly, (lower, upper)\>@.
instance (Ord c, Num c, Pretty c) => Pretty (Root c) where
  pretty (Root poly lower upper) =
    angles (pretty poly <> comma <+> parens (pretty lower <> comma <+> pretty upper))

-- | Approximate a root by a rational value.
--
-- The result is exact in three cases: the bounds coincide, one of the bounds
-- is a root of the polynomial, or a bisection midpoint happens to be a root.
-- Otherwise the bounds are rounded to 'Double' and bisected until the
-- bracketing doubles are adjacent (their midpoint equals one of them). The
-- endpoint where the polynomial is smaller in absolute value is returned,
-- converted to a 'Rational'. In that case the result is a double-precision
-- approximation, not an exact answer.
--
-- Every polynomial evaluation uses exact rational arithmetic, so sign tests
-- are reliable, but this makes the operation relatively slow. If
-- performance is a concern, there are faster algorithms for computing this.
--
-- Calls 'error' if the polynomial has the same non-zero sign at both bounds.
approximate :: Root Rational -> Rational
approximate (Root poly lo hi)
  | lo == hi  = lo
  | otherwise =
      case (compare loVal 0, compare hiVal 0) of
        (EQ, _)  -> lo
        (_, EQ)  -> hi
        (LT, GT) -> refine (fromRational lo) (fromRational hi)
        (GT, LT) -> refine (fromRational hi) (fromRational lo)
        _        -> error "Closest root given bad root."
  where
    loVal = eval poly lo
    hiVal = eval poly hi

    -- refine neg pos: bisect between a double where the polynomial is
    -- negative and one where it is positive, until they converge.
    refine :: Double -> Double -> Rational
    refine neg pos
      -- No double lies strictly between the ends: return the closer one.
      | mid == neg || mid == pos = toRational (closer neg pos)
      | otherwise =
          case compare (valueAt mid) 0 of
            EQ -> toRational mid  -- The midpoint is an exact root.
            LT -> refine mid pos  -- Midpoint becomes the negative end.
            GT -> refine neg mid  -- Midpoint becomes the positive end.
      where
        mid = (neg + pos) / 2
        valueAt d = eval poly (toRational d)
        -- Prefer the end whose polynomial value is smaller in magnitude.
        closer a b
          | abs (valueAt a) <= abs (valueAt b) = a
          | otherwise                          = b


-- | Parse @x@, then a comma followed by a single space, then @y@, and
-- combine the two results with @f@.
atto_pair :: (a -> b -> r) -> Atto.Parser a -> Atto.Parser b -> Atto.Parser r
atto_pair f x y = liftA2 f x (commaSpace *> y)
  where commaSpace = Atto.char ',' *> Atto.char ' '

-- | Parse a decimal integer with an optional leading minus sign (a leading
-- @+@ is not accepted).
atto_sdecimal :: Integral c => Atto.Parser c
atto_sdecimal = negative <|> Atto.decimal
  where negative = negate <$> (Atto.char '-' *> Atto.decimal)

-- | Parse a rational written as @n@ or @n/d@. The numerator may carry a
-- leading minus sign; the denominator is an unsigned decimal.
--
-- A zero denominator makes the parse fail. Building @n % 0@ would instead
-- produce a value that throws an exception when forced. Such an exception
-- escapes callers that report parse failure through 'Maybe'.
atto_rational :: Integral c => Atto.Parser (Ratio c)
atto_rational = do
  num <- atto_sdecimal
  den <- (Atto.char '/' *> Atto.decimal) <|> pure 1
  if den == 0
    then fail "atto_rational: zero denominator"
    else pure (num % den)

-- | Parse a root as printed by Yices. This is either
-- @\<poly, (lower, upper)\>@, with integer polynomial coefficients, or a
-- plain rational constant.
parseYicesRoot :: Atto.Parser (Root Rational)
parseYicesRoot = atto_angle algebraicRoot <|> fmap rootFromRational atto_rational
  where
    algebraicRoot :: Atto.Parser (Root Rational)
    algebraicRoot = atto_pair (\poly (lo, hi) -> Root poly lo hi) rationalPoly bounds

    -- Coefficients are parsed as integers and then widened to rationals.
    rationalPoly :: Atto.Parser (SingPoly Rational)
    rationalPoly = fmap (fmap fromInteger) (parseYicesPoly :: Atto.Parser (SingPoly Integer))

    bounds :: Atto.Parser (Rational, Rational)
    bounds = atto_paren (atto_pair (,) atto_rational atto_rational)

-- | Convert text to a root. Returns 'Nothing' if the text does not parse
-- or if any input remains after the root. The parsed root is forced before
-- being returned.
fromYicesText :: Text -> Maybe (Root Rational)
fromYicesText t =
  case Atto.parseOnly (parseYicesRoot <* Atto.endOfInput) t of
    Left _     -> Nothing
    Right root -> Just $! root