{-# LANGUAGE BangPatterns #-}
module Data.SBV.Utils.SExpr (SExpr(..), parenDeficit, parseSExpr) where
import Data.Bits (setBit, testBit)
import Data.Word (Word32, Word64)
import Data.Char (isDigit, ord, isSpace)
import Data.List (isPrefixOf)
import Data.Maybe (fromMaybe, listToMaybe)
import Numeric (readInt, readDec, readHex, fromRat)
import Data.SBV.Core.AlgReals
import Data.SBV.Core.Data (nan, infinity, RoundingMode(..))
import Data.Numbers.CrackNum (wordToFloat, wordToDouble)
data SExpr = ECon String
| ENum (Integer, Maybe Int)
| EReal AlgReal
| EFloat Float
| EDouble Double
| EApp [SExpr]
deriving Show
tokenize :: String -> [String]
tokenize inp = go inp []
where go "" sofar = reverse sofar
go (c:cs) sofar
| isSpace c = go (dropWhile isSpace cs) sofar
go ('(':cs) sofar = go cs ("(" : sofar)
go (')':cs) sofar = go cs (")" : sofar)
go (':':':':cs) sofar = go cs ("::" : sofar)
go (':':cs) sofar = case break (`elem` stopper) cs of
(pre, rest) -> go rest ((':':pre) : sofar)
go ('|':r) sofar = case span (/= '|') r of
(pre, '|':rest) -> go rest (pre : sofar)
(pre, rest) -> go rest (pre : sofar)
go ('"':r) sofar = go rest (finalStr : sofar)
where grabString [] acc = (reverse acc, [])
grabString ('"' :'"':cs) acc = grabString cs ('"' :acc)
grabString ('"':cs) acc = (reverse acc, cs)
grabString (c:cs) acc = grabString cs (c:acc)
(str, rest) = grabString r []
finalStr = '"' : str ++ "\""
go cs sofar = case span (`notElem` stopper) cs of
(pre, post) -> go post (pre : sofar)
stopper = " \t\n():|\""
parenDeficit :: String -> Int
parenDeficit = go 0 . tokenize
where go :: Int -> [String] -> Int
go !balance [] = balance
go !balance ("(" : rest) = go (balance+1) rest
go !balance (")" : rest) = go (balance-1) rest
go !balance (_ : rest) = go balance rest
parseSExpr :: String -> Either String SExpr
parseSExpr inp = do (sexp, extras) <- parse inpToks
if null extras
then case sexp of
EApp [ECon "error", ECon er] -> Left $ "Solver returned an error: " ++ er
_ -> return sexp
else die "Extra tokens after valid input"
where inpToks = tokenize inp
die w = fail $ "SBV.Provers.SExpr: Failed to parse S-Expr: " ++ w
++ "\n*** Input : <" ++ inp ++ ">"
parse [] = die "ran out of tokens"
parse ("(":toks) = do (f, r) <- parseApp toks []
f' <- cvt (EApp f)
return (f', r)
parse (")":_) = die "extra tokens after close paren"
parse [tok] = do t <- pTok tok
return (t, [])
parse _ = die "ill-formed s-expr"
parseApp [] _ = die "failed to grab s-expr application"
parseApp (")":toks) sofar = return (reverse sofar, toks)
parseApp ("(":toks) sofar = do (f, r) <- parse ("(":toks)
parseApp r (f : sofar)
parseApp (tok:toks) sofar = do t <- pTok tok
parseApp toks (t : sofar)
pTok "false" = return $ ENum (0, Nothing)
pTok "true" = return $ ENum (1, Nothing)
pTok ('0':'b':r) = mkNum (Just (length r)) $ readInt 2 (`elem` "01") (\c -> ord c - ord '0') r
pTok ('b':'v':r) | not (null r) && all isDigit r = mkNum Nothing $ readDec (takeWhile (/= '[') r)
pTok ('#':'b':r) = mkNum (Just (length r)) $ readInt 2 (`elem` "01") (\c -> ord c - ord '0') r
pTok ('#':'x':r) = mkNum (Just (4 * length r)) $ readHex r
pTok n
| not (null n) && isDigit (head n)
= if '.' `elem` n then getReal n
else mkNum Nothing $ readDec n
pTok n = return $ ECon (constantMap n)
mkNum l [(n, "")] = return $ ENum (n, l)
mkNum _ _ = die "cannot read number"
getReal n = return $ EReal $ mkPolyReal (Left (exact, n'))
where exact = not ("?" `isPrefixOf` reverse n)
n' | exact = n
| True = init n
cvt (EApp [ECon "to_int", EReal a]) = return $ EReal a
cvt (EApp [ECon "to_real", EReal a]) = return $ EReal a
cvt (EApp [ECon "/", EReal a, EReal b]) = return $ EReal (a / b)
cvt (EApp [ECon "/", EReal a, ENum b]) = return $ EReal (a / fromInteger (fst b))
cvt (EApp [ECon "/", ENum a, EReal b]) = return $ EReal (fromInteger (fst a) / b )
cvt (EApp [ECon "/", ENum a, ENum b]) = return $ EReal (fromInteger (fst a) / fromInteger (fst b))
cvt (EApp [ECon "-", EReal a]) = return $ EReal (-a)
cvt (EApp [ECon "-", ENum a]) = return $ ENum (-(fst a), snd a)
cvt (EApp [ECon "_", ENum a, ENum _b]) = return $ ENum a
cvt (EApp [ECon "root-obj", EApp (ECon "+":trms), ENum k]) = do ts <- mapM getCoeff trms
return $ EReal $ mkPolyReal (Right (fst k, ts))
cvt (EApp [ECon "as", n, EApp [ECon "_", ECon "FloatingPoint", ENum (11, _), ENum (53, _)]]) = getDouble n
cvt (EApp [ECon "as", n, EApp [ECon "_", ECon "FloatingPoint", ENum ( 8, _), ENum (24, _)]]) = getFloat n
cvt (EApp [ECon "as", n, ECon "Float64"]) = getDouble n
cvt (EApp [ECon "as", n, ECon "Float32"]) = getFloat n
cvt (EApp [ECon "fp", ENum (s, Just 1), ENum ( e, Just 8), ENum (m, Just 23)]) = return $ EFloat $ getTripleFloat s e m
cvt (EApp [ECon "fp", ENum (s, Just 1), ENum ( e, Just 11), ENum (m, Just 52)]) = return $ EDouble $ getTripleDouble s e m
cvt (EApp [ECon "_", ECon "NaN", ENum ( 8, _), ENum (24, _)]) = return $ EFloat nan
cvt (EApp [ECon "_", ECon "NaN", ENum (11, _), ENum (53, _)]) = return $ EDouble nan
cvt (EApp [ECon "_", ECon "+oo", ENum ( 8, _), ENum (24, _)]) = return $ EFloat infinity
cvt (EApp [ECon "_", ECon "+oo", ENum (11, _), ENum (53, _)]) = return $ EDouble infinity
cvt (EApp [ECon "_", ECon "-oo", ENum ( 8, _), ENum (24, _)]) = return $ EFloat (-infinity)
cvt (EApp [ECon "_", ECon "-oo", ENum (11, _), ENum (53, _)]) = return $ EDouble (-infinity)
cvt (EApp [ECon "_", ECon "+zero", ENum ( 8, _), ENum (24, _)]) = return $ EFloat 0
cvt (EApp [ECon "_", ECon "+zero", ENum (11, _), ENum (53, _)]) = return $ EDouble 0
cvt (EApp [ECon "_", ECon "-zero", ENum ( 8, _), ENum (24, _)]) = return $ EFloat (-0)
cvt (EApp [ECon "_", ECon "-zero", ENum (11, _), ENum (53, _)]) = return $ EDouble (-0)
cvt x = return x
getCoeff (EApp [ECon "*", ENum k, EApp [ECon "^", ECon "x", ENum p]]) = return (fst k, fst p)
getCoeff (EApp [ECon "*", ENum k, ECon "x" ] ) = return (fst k, 1)
getCoeff ( EApp [ECon "^", ECon "x", ENum p] ) = return ( 1, fst p)
getCoeff ( ECon "x" ) = return ( 1, 1)
getCoeff ( ENum k ) = return (fst k, 0)
getCoeff x = die $ "Cannot parse a root-obj,\nProcessing term: " ++ show x
getDouble (ECon s) = case (s, rdFP (dropWhile (== '+') s)) of
("plusInfinity", _ ) -> return $ EDouble infinity
("minusInfinity", _ ) -> return $ EDouble (-infinity)
("oo", _ ) -> return $ EDouble infinity
("-oo", _ ) -> return $ EDouble (-infinity)
("zero", _ ) -> return $ EDouble 0
("-zero", _ ) -> return $ EDouble (-0)
("NaN", _ ) -> return $ EDouble nan
(_, Just v) -> return $ EDouble v
_ -> die $ "Cannot parse a double value from: " ++ s
getDouble (EApp [_, s, _, _]) = getDouble s
getDouble (EReal r) = return $ EDouble $ fromRat $ toRational r
getDouble x = die $ "Cannot parse a double value from: " ++ show x
getFloat (ECon s) = case (s, rdFP (dropWhile (== '+') s)) of
("plusInfinity", _ ) -> return $ EFloat infinity
("minusInfinity", _ ) -> return $ EFloat (-infinity)
("oo", _ ) -> return $ EFloat infinity
("-oo", _ ) -> return $ EFloat (-infinity)
("zero", _ ) -> return $ EFloat 0
("-zero", _ ) -> return $ EFloat (-0)
("NaN", _ ) -> return $ EFloat nan
(_, Just v) -> return $ EFloat v
_ -> die $ "Cannot parse a float value from: " ++ s
getFloat (EReal r) = return $ EFloat $ fromRat $ toRational r
getFloat (EApp [_, s, _, _]) = getFloat s
getFloat x = die $ "Cannot parse a float value from: " ++ show x
rdFP :: (Read a, RealFloat a) => String -> Maybe a
rdFP s = case break (`elem` "pe") s of
(m, 'p':e) -> rd m >>= \m' -> rd e >>= \e' -> return $ m' * ( 2 ** e')
(m, 'e':e) -> rd m >>= \m' -> rd e >>= \e' -> return $ m' * (10 ** e')
(m, "") -> rd m
_ -> Nothing
where rd v = case reads v of
[(n, "")] -> Just n
_ -> Nothing
getTripleFloat :: Integer -> Integer -> Integer -> Float
getTripleFloat s e m = wordToFloat w32
where sign = [s == 1]
expt = [e `testBit` i | i <- [ 7, 6 .. 0]]
mantissa = [m `testBit` i | i <- [22, 21 .. 0]]
positions = [i | (i, b) <- zip [31, 30 .. 0] (sign ++ expt ++ mantissa), b]
w32 = foldr (flip setBit) (0::Word32) positions
getTripleDouble :: Integer -> Integer -> Integer -> Double
getTripleDouble s e m = wordToDouble w64
where sign = [s == 1]
expt = [e `testBit` i | i <- [10, 9 .. 0]]
mantissa = [m `testBit` i | i <- [51, 50 .. 0]]
positions = [i | (i, b) <- zip [63, 62 .. 0] (sign ++ expt ++ mantissa), b]
w64 = foldr (flip setBit) (0::Word64) positions
constantMap :: String -> String
constantMap n = fromMaybe n (listToMaybe [to | (from, to) <- special, n `elem` from])
where special = [ (["RNE", "roundNearestTiesToEven"], show RoundNearestTiesToEven)
, (["RNA", "roundNearestTiesToAway"], show RoundNearestTiesToAway)
, (["RTP", "roundTowardPositive"], show RoundTowardPositive)
, (["RTN", "roundTowardNegative"], show RoundTowardNegative)
, (["RTZ", "roundTowardZero"], show RoundTowardZero)
]