{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}

module Language.REST.KBO (kbo, kboGTE) where

import           Data.List (foldl')
import qualified Data.Map as M

import           Language.REST.Internal.Util
import           Language.REST.OCAlgebra
import           Language.REST.Op
import           Language.REST.RuntimeTerm as RT
import           Language.REST.SMT

-- | All function symbols of a term in pre-order: the root symbol first, then
-- the symbols of each argument from left to right. A symbol that occurs
-- several times appears once per occurrence.
termOps :: RuntimeTerm -> [Op]
termOps term = collect term []
  where
    -- Prepend the symbols of one subterm onto the symbols that follow it.
    collect :: RuntimeTerm -> [Op] -> [Op]
    collect (App sym args) rest = sym : foldr collect rest args

-- | Lower bounds on the integer SMT term (@toSMT f@) of every function symbol
-- @f@ occurring in a term. 'kboGTE' sums these terms as symbol weights.
--
-- * Symbols applied to zero or one argument get the bound @>= 1@.
-- * Symbols applied to two or more arguments get the bound @>= 0@.
--
-- The result is the conjunction ('And') of one bound per distinct symbol.
--
-- If the same 'Op' occurs at several arities, the larger (stronger) bound is
-- kept. The result therefore does not depend on the order in which the term
-- is traversed.
arityConstraints :: RuntimeTerm -> SMTExpr Bool
arityConstraints t = toExpr (go M.empty t)
  where
    -- Record the bound for this node's symbol, then visit its arguments.
    go :: M.Map Op Int -> RuntimeTerm -> M.Map Op Int
    go m (App f [])     = require f 1 m
    go m (App f [targ]) = go (require f 1 m) targ
    -- Strict left fold: the map is an accumulator, so avoid a thunk chain.
    go m (App f ts)     = foldl' go (require f 0 m) ts

    -- Keep the larger bound when a symbol is seen again. A plain 'M.insert'
    -- would let the last-visited occurrence win, and could drop a
    -- constant's @>= 1@ bound in favour of a later @>= 0@.
    require :: Op -> Int -> M.Map Op Int -> M.Map Op Int
    require = M.insertWith max

    -- Conjoin one bound per distinct symbol.
    toExpr :: M.Map Op Int -> SMTExpr Bool
    toExpr m = And (map toConstraint (M.toList m))

    -- @toSMT sym >= n@
    toConstraint :: (Op, Int) -> SMTExpr Bool
    toConstraint (sym, n) = toSMT sym `smtGTE` Const n


-- | SMT constraint under which term @t@ weighs at least as much as term @u@.
--
-- The constraint is the conjunction of:
--
-- * the 'arityConstraints' of both terms, and
-- * a weight comparison: the summed @toSMT@ terms of @t@'s remaining symbols
--   must be @>=@ those of @u@'s remaining symbols.
--
-- The remaining symbols are what 'removeEqBy' returns when given the two
-- symbol lists from 'termOps' (presumably occurrences common to both terms
-- are cancelled — confirm against 'removeEqBy').
--
-- Only weights are compared; no precedence tie-break is encoded here.
kboGTE :: RuntimeTerm -> RuntimeTerm -> SMTExpr Bool
kboGTE t u =
  let (lhsOps, rhsOps) = removeEqBy (==) (termOps t) (termOps u)

      weightOf :: [Op] -> SMTExpr Int
      weightOf = smtAdd . map toSMT

      heavierOrEqual = weightOf lhsOps `smtGTE` weightOf rhsOps
  in arityConstraints t `smtAnd` arityConstraints u `smtAnd` heavierOrEqual


-- | Ordering-constraint algebra whose constraints are SMT boolean formulas
-- over symbol weights, discharged with the given solver.
--
-- * @isSat@: satisfiability check via 'checkSat''.
-- * @refine c t u@: @c@ conjoined with @'kboGTE' t u@.
-- * @top@: 'smtTrue' (no constraint).
-- * @union c1 c2@: the disjunction @Or [c1, c2]@.
-- * @notStrongerThan c1 c2@: satisfiability of @Implies c2 c1@.
--
-- NOTE(review): @notStrongerThan@ checks that @c2 => c1@ is /satisfiable/,
-- not that it is /valid/. If the 'OCAlgebra' contract expects it to hold
-- only when every model of @c2@ is also a model of @c1@, an unsatisfiability
-- check of @c2 /\\ not c1@ would be required instead — confirm the intended
-- semantics.
kbo :: SolverHandle -> OCAlgebra (SMTExpr Bool) RuntimeTerm IO
kbo solver =
  OCAlgebra
    { isSat           = sat
    , refine          = \c t u -> c `smtAnd` kboGTE t u
    , top             = smtTrue
    , union           = \c1 c2 -> Or [c1, c2]
    , notStrongerThan = \c1 c2 -> sat (Implies c2 c1)
    }
  where
    -- Every solver query in this algebra goes through the same handle.
    sat :: SMTExpr Bool -> IO Bool
    sat = checkSat' solver