module Haskus.Utils.Solver
(
PredState (..)
, PredOracle
, makeOracle
, oraclePredicates
, emptyOracle
, predIsSet
, predIsUnset
, predIsUndef
, predIs
, predState
, Constraint (..)
, simplifyConstraint
, constraintReduce
, Rule (..)
, orderedNonTerminal
, mergeRules
, evalsTo
, MatchResult (..)
, Predicated (..)
, createPredicateTable
, initP
, applyP
, resultP
)
where
import Haskus.Utils.Maybe
import Haskus.Utils.Flow
import Haskus.Utils.List
import Haskus.Utils.Map.Strict (Map)
import qualified Haskus.Utils.Map.Strict as Map
import Data.Bits
import Control.Arrow (first,second)
import Prelude hiding (pred)
-- | State of a predicate as known by an oracle.
data PredState
   = SetPred     -- ^ The predicate is set (reduces to @CBool True@)
   | UnsetPred   -- ^ The predicate is unset (reduces to @CBool False@)
   | UndefPred   -- ^ Nothing is known about the predicate
   deriving (Show, Eq, Ord)
-- | Predicate oracle: maps predicates to their known state. Lookups through
-- 'predState' treat predicates that are absent from the map as 'UndefPred'.
type PredOracle p = Map p PredState
-- | Check whether the oracle reports the predicate as 'SetPred'.
predIsSet :: Ord p => PredOracle p -> p -> Bool
predIsSet oracle p = predState oracle p == SetPred
-- | Check whether the oracle reports the predicate as 'UnsetPred'.
predIsUnset :: Ord p => PredOracle p -> p -> Bool
predIsUnset oracle p = predState oracle p == UnsetPred
-- | Check whether the oracle reports the predicate as 'UndefPred' (this
-- includes predicates that are absent from the oracle).
predIsUndef :: Ord p => PredOracle p -> p -> Bool
predIsUndef oracle p = predState oracle p == UndefPred
-- | Check whether the state the oracle holds for a predicate is the given
-- one.
predIs :: Ord p => PredOracle p -> p -> PredState -> Bool
predIs oracle p expected = expected == predState oracle p
-- | State of a predicate according to the oracle; 'UndefPred' when the
-- predicate is not in the oracle.
predState :: Ord p => PredOracle p -> p -> PredState
predState oracle p = maybe UndefPred id (Map.lookup p oracle)
-- | Build an oracle from a list of (predicate, state) associations.
makeOracle :: Ord p => [(p,PredState)] -> PredOracle p
makeOracle assocs = Map.fromList assocs
-- | List the (predicate, state) pairs of an oracle, leaving out the entries
-- whose state is 'UndefPred'.
oraclePredicates :: Ord p => PredOracle p -> [(p,PredState)]
oraclePredicates oracle =
   [ entry
   | entry@(_, st) <- Map.toList oracle
   , st /= UndefPred
   ]
-- | Oracle containing no entry: every predicate looked up through
-- 'predState' is reported as 'UndefPred'.
emptyOracle :: PredOracle p
emptyOracle = Map.empty
-- | Boolean constraint over predicates of type @p@.
--
-- The @e@ parameter is not used by any constructor.
data Constraint e p
   = Predicate p            -- ^ A single predicate
   | Not (Constraint e p)   -- ^ Negation
   | And [Constraint e p]   -- ^ Conjunction
   | Or [Constraint e p]    -- ^ Disjunction
   | Xor [Constraint e p]   -- ^ Exclusive or: simplified to True when exactly one operand is True
   | CBool Bool             -- ^ Constant
   deriving (Show, Eq, Ord)
-- | Map a function over every predicate of a constraint, keeping its shape.
instance Functor (Constraint e) where
   fmap f = go
      where
         go = \case
            Predicate p -> Predicate (f p)
            CBool b     -> CBool b
            Not c       -> Not (go c)
            And cs      -> And (fmap go cs)
            Or cs       -> Or (fmap go cs)
            Xor cs      -> Xor (fmap go cs)
-- | Reduce a constraint with the knowledge held by an oracle.
--
-- The constraint is first simplified; predicates that are set or unset in
-- the oracle are replaced with constants and the resulting constants are
-- propagated upwards. Undefined predicates are left in place.
constraintReduce :: (Ord p, Eq p, Eq e) => PredOracle p -> Constraint e p -> Constraint e p
constraintReduce oracle = go . simplifyConstraint
   where
      reduce = constraintReduce oracle

      go = \case
         lit@(CBool _) -> lit
         Predicate p
            | predIsSet oracle p   -> CBool True
            | predIsUnset oracle p -> CBool False
            | otherwise            -> Predicate p
         Not inner -> case reduce inner of
            CBool v -> CBool (not v)
            other   -> Not other
         And cs -> junction "Empty And constraint" True And (fmap reduce cs)
         Or cs  -> junction "Empty Or constraint" False Or (fmap reduce cs)
         Xor cs -> case fmap reduce cs of
            []      -> error "Empty Xor constraint"
            reduced -> simplifyConstraint (Xor reduced)

      -- Shared And/Or reduction. @neutral@ is the constant that can be
      -- dropped from the operand list (True for And, False for Or); its
      -- negation decides the whole junction.
      junction emptyMsg neutral rebuild reduced
         | null reduced                                 = error emptyMsg
         | all (constraintIsBool neutral) reduced       = CBool neutral
         | any (constraintIsBool (not neutral)) reduced = CBool (not neutral)
         | otherwise =
            case filter (not . constraintIsBool neutral) reduced of
               [single]  -> single
               remaining -> rebuild remaining
-- | Check that a constraint is the given boolean constant.
constraintIsBool :: Bool -> Constraint e p -> Bool
constraintIsBool expected = \case
   CBool actual -> expected == actual
   _            -> False
-- | Collect the predicates used in a constraint, in traversal order.
-- Duplicates are not removed.
getConstraintPredicates :: Constraint e p -> [p]
getConstraintPredicates constraint = case constraint of
   Predicate p -> [p]
   CBool _     -> []
   Not c       -> getConstraintPredicates c
   And cs      -> fromOperands cs
   Or cs       -> fromOperands cs
   Xor cs      -> fromOperands cs
   where
      fromOperands cs = concatMap getConstraintPredicates cs
-- | Boolean values a constraint may evaluate to.
--
-- A conjunction is reported as definitely False only when one of its
-- operands is definitely False (and a disjunction as definitely True only
-- when one operand is definitely True). An operand that merely /may/ take
-- that value leaves both outcomes possible: @And [Predicate p]@ yields
-- @[True,False]@, not @[False]@.
--
-- Empty And/Or/Xor constraints yield no terminal.
getConstraintTerminals :: Constraint e p -> [Bool]
getConstraintTerminals = \case
   Predicate _ -> [True,False]
   CBool v     -> [v]
   Not c       -> fmap not (getConstraintTerminals c)
   And cs      -> junction False cs
   Or cs       -> junction True cs
   Xor cs
      | null cs   -> []
      | otherwise -> xo False (fmap getConstraintTerminals cs)
   where
      -- Shared And/Or logic. @absorbing@ is the value that decides the whole
      -- junction as soon as one operand is certain to have it (False for
      -- And, True for Or).
      junction absorbing cs
         | null cs                       = []
         | any (sing absorbing) ts       = [absorbing]
         | all (sing (not absorbing)) ts = [not absorbing]
         | otherwise                     = [True,False]
         where
            ts = fmap getConstraintTerminals cs

      -- Walk the operands of a Xor; the flag records whether an operand that
      -- is certainly True has already been seen.
      xo t     []           = [t]
      xo False ([True]:xs)  = xo True xs
      xo True  ([True]:_)   = [False]       -- second certain True
      xo False ([False]:xs) = xo False xs
      xo True  ([False]:xs) = xo True xs
      xo _     ([]:_)       = []            -- operand without any terminal
      xo _     _            = [True,False]  -- undetermined operand

      -- The terminal list is exactly the given value
      sing v [v'] = v == v'
      sing _ _    = False
-- | Rule tree: terminals of type @a@ guarded by constraints over predicates
-- of type @p@, with failures of type @e@.
data Rule e p a
   = Terminal a                                  -- ^ Final value
   | NonTerminal [(Constraint e p, Rule e p a)]  -- ^ Sub-rules, each guarded by a constraint
   | Fail e                                      -- ^ Failure carrying an error
   deriving (Show, Eq, Ord)
-- | Map a function over the terminals of a rule; constraints and failures
-- are left untouched.
instance Functor (Rule e p) where
   fmap f = \case
      Terminal a       -> Terminal (f a)
      Fail e           -> Fail e
      NonTerminal alts -> NonTerminal (second (f <$>) <$> alts)
-- | Build a 'NonTerminal' whose alternatives are tried in order: the
-- constraint of each alternative is conjoined with the negation of the
-- constraints of all the alternatives that precede it, then simplified.
orderedNonTerminal :: [(Constraint e p, Rule e p a)] -> Rule e p a
orderedNonTerminal alts = NonTerminal (exclude [] alts)
   where
      -- @earlier@ holds the original constraints already seen, most recent
      -- first
      exclude _ [] = []
      exclude earlier ((c, r) : rest) =
         (simplifyConstraint guarded, r) : exclude (c : earlier) rest
         where
            guarded
               | null earlier = c
               | otherwise    = And (c : fmap Not earlier)
-- | Simplify a constraint without any oracle: push negations inwards,
-- unwrap single-operand And/Or/Xor and fold constants.
simplifyConstraint :: Constraint e p -> Constraint e p
simplifyConstraint = \case
   leaf@(Predicate _) -> leaf
   lit@(CBool _)      -> lit
   Not inner          -> negation inner
   And [c]            -> simplifyConstraint c
   Or [c]             -> simplifyConstraint c
   Xor [c]            -> simplifyConstraint c
   And cs             -> junction False And cs
   Or cs              -> junction True Or cs
   Xor cs             -> exclusive (fmap simplifyConstraint cs)
   where
      -- Simplify @Not c@ given @c@
      negation = \case
         p@(Predicate _) -> Not p
         CBool v         -> CBool (not v)
         Not c           -> simplifyConstraint c
         Or cs           -> simplifyConstraint (And (fmap Not cs))
         And cs          -> simplifyConstraint (Or (fmap Not cs))
         x@(Xor _)       -> case simplifyConstraint x of
            x'@(Xor _) -> Not x'
            other      -> simplifyConstraint (Not other)

      -- Shared And/Or logic; @absorbing@ is the constant that decides the
      -- junction on its own (False for And, True for Or)
      junction absorbing rebuild cs
         | any (constraintIsBool absorbing) cs'       = CBool absorbing
         | all (constraintIsBool (not absorbing)) cs' = CBool (not absorbing)
         | otherwise                                  = rebuild cs'
         where
            cs' = fmap simplifyConstraint cs

      -- Xor over already simplified operands
      exclusive cs'
         | trues > 1                             = CBool False
         | trues == 1 && trues + falses == total = CBool True
         | total == falses                       = CBool False
         | otherwise                             = Xor cs'
         where
            trues  = length (filter (constraintIsBool True) cs')
            falses = length (filter (constraintIsBool False) cs')
            total  = length cs'
-- | Merge two rules into a rule producing pairs of terminals.
--
-- A failure on either side wins (the left one first). When the left rule is
-- a non-terminal, the right rule is merged into each of its alternatives;
-- otherwise the left terminal is merged into each alternative of the right
-- rule.
mergeRules :: Rule e p a -> Rule e p b -> Rule e p (a,b)
mergeRules left right = case (left, right) of
   (Fail e, _)                  -> Fail e
   (_, Fail e)                  -> Fail e
   (Terminal a, Terminal b)     -> Terminal (a, b)
   (Terminal _, NonTerminal bs) -> NonTerminal (fmap (second (mergeRules left)) bs)
   (NonTerminal as, _)          -> NonTerminal (fmap (second (`mergeRules` right)) as)
-- | Reduce a rule with the knowledge held by an oracle.
--
-- For a non-terminal, each alternative's constraint is reduced and the
-- alternatives whose constraint is False are dropped. The distinct results
-- of the alternatives whose constraint is True are then examined:
--
-- * no alternative left: 'NoMatch'
-- * at least one failure: 'MatchFail'
-- * at least two distinct terminals: 'MatchDiverge'
-- * nested non-terminals: they are flattened into the current level and the
--   reduction starts again
-- * a single terminal and no undecided alternative: 'Match'
-- * otherwise: 'DontMatch' with the reduced alternatives
ruleReduce :: forall e p a.
   ( Ord p, Eq e, Eq p, Eq a) => PredOracle p -> Rule e p a -> MatchResult e (Rule e p a) a
ruleReduce oracle = \case
   Terminal a       -> Match a
   Fail e           -> MatchFail [e]
   NonTerminal alts -> reduceAlternatives alts
   where
      reduceAlternatives :: [(Constraint e p, Rule e p a)] -> MatchResult e (Rule e p a) a
      reduceAlternatives alts
         | null live           = NoMatch
         | not (null failures) = MatchFail failures
         | diverges            = MatchDiverge (fmap Terminal terminals)
         | not (null nested)   = ruleReduce oracle (NonTerminal flattened)
         | [Terminal a] <- results
         , null undecided      = Match a
         | otherwise           = DontMatch (NonTerminal live)
         where
            -- alternatives whose reduced constraint is not False
            live = filter (not . constraintIsBool False . fst)
                          (fmap (first (constraintReduce oracle)) alts)

            -- constraint reduced to True / still depending on predicates
            (decided, undecided) = partition (constraintIsBool True . fst) live

            -- distinct rules selected by a True constraint
            results = nub (fmap snd decided)

            -- results split by constructor, most recently encountered first
            failures  = reverse [ e  | Fail e         <- results ]
            terminals = reverse [ t  | Terminal t     <- results ]
            nested    = reverse [ xs | NonTerminal xs <- results ]

            -- two or more distinct terminals match
            diverges = not (null (drop 1 terminals))

            -- nested alternatives lifted next to the undecided ones
            flattened =
               fmap (\t -> (CBool True, Terminal t)) terminals
                  ++ undecided
                  ++ concat nested
-- | Collect the terminals of a rule, in traversal order. Failures
-- contribute nothing and duplicates are kept.
getRuleTerminals :: Rule e p a -> [a]
getRuleTerminals = \case
   Fail _           -> []
   Terminal a       -> [a]
   NonTerminal alts -> [ t | (_, sub) <- alts, t <- getRuleTerminals sub ]
-- | Collect the distinct predicates used by the constraints of a rule,
-- nested rules included.
getRulePredicates :: Eq p => Rule e p a -> [p]
getRulePredicates = \case
   Fail _           -> []
   Terminal _       -> []
   NonTerminal alts -> nub (concat
      [ getConstraintPredicates c ++ getRulePredicates sub
      | (c, sub) <- alts
      ])
-- | Constraint that holds when a predicated value evaluates to the given
-- terminal.
--
-- The full predicate table of the value is built; every row that yields the
-- terminal becomes a conjunction of its predicates and the rows are joined
-- by a disjunction. When the value reduces to a terminal without any
-- predicate, the result is the constant comparing that terminal with the
-- requested one.
--
-- When no row yields the terminal the result is @CBool False@ (empty
-- disjunction): the value can never evaluate to it.
evalsTo :: (Ord (Pred a), Eq a, Eq (PredTerm a), Eq (Pred a), Predicated a) => a -> PredTerm a -> Constraint e (Pred a)
evalsTo s a = case createPredicateTable s (const True) True of
   Left x   -> CBool (x == a)
   Right xs -> orConstraints
                  [ andPredicates (oraclePredicates oracle)
                  | (oracle, t) <- xs
                  , t == a
                  ]
   where
      -- conjunction of the predicates of one row (empty conjunction: True)
      andPredicates []  = CBool True
      andPredicates [x] = makePred x
      andPredicates xs  = And (fmap makePred xs)

      -- disjunction of the matching rows (empty disjunction: False)
      orConstraints []  = CBool False
      orConstraints [x] = x
      orConstraints xs  = Or xs

      makePred (p, UnsetPred) = Not (Predicate p)
      makePred (p, SetPred)   = Predicate p
      -- 'oraclePredicates' filters out undefined predicates
      makePred (_, UndefPred) = error "evalsTo: unexpected undefined predicate in predicate table"
-- | Values that can be reduced by a predicate oracle.
class Predicated a where
   -- | Type of the errors reported by a reduction
   type PredErr a  :: *

   -- | Type of the predicates
   type Pred a     :: *

   -- | Type of the terminals
   type PredTerm a :: *

   -- | Wrap a terminal into a predicated value
   liftTerminal :: PredTerm a -> a

   -- | Reduce a predicated value with the given oracle
   reducePredicates :: PredOracle (Pred a) -> a -> MatchResult (PredErr a) a (PredTerm a)

   -- | Terminals of a predicated value
   getTerminals :: a -> [PredTerm a]

   -- | Predicates used by a predicated value
   getPredicates :: a -> [Pred a]
-- | Rules are predicated values: errors are @e@, predicates are @p@ and
-- terminals are @a@.
instance (Ord p, Eq e, Eq a, Eq p) => Predicated (Rule e p a) where
   type PredErr  (Rule e p a) = e
   type Pred     (Rule e p a) = p
   type PredTerm (Rule e p a) = a

   liftTerminal     = Terminal
   reducePredicates = ruleReduce
   getTerminals     = getRuleTerminals
   getPredicates    = getRulePredicates
-- | Constraints are predicated values whose terminals are booleans: a
-- constraint that reduces to a constant matches, any other reduced
-- constraint does not.
instance (Ord p, Eq e, Eq p) => Predicated (Constraint e p) where
   type PredErr  (Constraint e p) = e
   type Pred     (Constraint e p) = p
   type PredTerm (Constraint e p) = Bool

   liftTerminal  = CBool
   getTerminals  = getConstraintTerminals
   getPredicates = getConstraintPredicates

   reducePredicates oracle = toResult . constraintReduce oracle
      where
         toResult (CBool v) = Match v
         toResult residual  = DontMatch residual
-- | Result of reducing a predicated value.
data MatchResult e nt t
   = NoMatch            -- ^ Nothing can match
   | Match t            -- ^ Reduced to a terminal
   | DontMatch nt       -- ^ Not reduced to a terminal; carries the reduced non-terminal
   | MatchFail [e]      -- ^ Reduction hit some failures
   | MatchDiverge [nt]  -- ^ Several distinct results match
   deriving (Show, Eq, Ord)
-- | Map a function over the terminal of a 'Match'; every other result is
-- returned as is.
instance Functor (MatchResult e nt) where
   fmap _ NoMatch           = NoMatch
   fmap _ (MatchDiverge xs) = MatchDiverge xs
   fmap _ (MatchFail es)    = MatchFail es
   fmap _ (DontMatch nt)    = DontMatch nt
   fmap f (Match t)         = Match (f t)
-- | Apply a match result holding constructor functions to the match result
-- of an argument.
--
-- 'NoMatch' on either side gives 'NoMatch'; otherwise failures are
-- propagated (concatenated when both sides fail). A divergence on either
-- side gives a divergence over every combination. The result is a 'Match'
-- only when both sides match; in the mixed cases the argument's terminal is
-- lifted with 'liftTerminal' to build the non-terminal.
applyP ::
   ( Predicated ntb
   ) => MatchResult e (ntb -> nt) (ntb -> nt, PredTerm ntb -> t) -> MatchResult e ntb (PredTerm ntb) -> MatchResult e nt (nt,t)
applyP lhs rhs = case (lhs, rhs) of
   (NoMatch, _)                       -> NoMatch
   (_, NoMatch)                       -> NoMatch
   (MatchFail xs, MatchFail ys)       -> MatchFail (xs ++ ys)
   (MatchFail xs, _)                  -> MatchFail xs
   (_, MatchFail ys)                  -> MatchFail ys
   (MatchDiverge fs, MatchDiverge ys) -> MatchDiverge (fs <*> ys)
   (MatchDiverge fs, Match b)         -> MatchDiverge (fmap ($ liftTerminal b) fs)
   (MatchDiverge fs, DontMatch b)     -> MatchDiverge (fmap ($ b) fs)
   (DontMatch f, MatchDiverge ys)     -> MatchDiverge (fmap f ys)
   (DontMatch f, DontMatch b)         -> DontMatch (f b)
   (DontMatch f, Match b)             -> DontMatch (f (liftTerminal b))
   (Match (fnt, _), MatchDiverge ys)  -> MatchDiverge (fmap fnt ys)
   (Match (fnt, _), DontMatch b)      -> DontMatch (fnt b)
   (Match (fnt, ft), Match b)         -> Match (fnt (liftTerminal b), ft b)
-- | Starting point of an 'applyP' chain: a 'Match' holding both the given
-- non-terminal and terminal values.
initP :: nt -> t -> MatchResult e nt (nt,t)
initP nt t = Match seed
   where
      seed = (nt, t)
-- | Keep only the terminal component of a matching pair; other results are
-- unchanged.
resultP :: MatchResult e nt (nt,t) -> MatchResult e nt t
resultP res = snd <$> res
-- | Build the predicate table of a predicated value.
--
-- Returns @Left t@ when the value already reduces to the terminal @t@ with
-- an empty oracle. Otherwise returns the list of (oracle, terminal) pairs
-- for every generated oracle that is accepted by the checker and reduces
-- the value to a terminal.
--
-- When the third argument is True, only complete oracles are generated
-- (every predicate is either set or unset); otherwise every non-empty
-- partial assignment is generated too.
--
-- Predicates are deduplicated first: a predicate appearing several times in
-- the value would otherwise double the size of the full table for each
-- extra occurrence and produce duplicate rows.
createPredicateTable ::
   ( Ord (Pred a)
   , Eq (Pred a)
   , Eq a
   , Predicated a
   ) => a -> (PredOracle (Pred a) -> Bool) -> Bool -> Either (PredTerm a) [(PredOracle (Pred a),PredTerm a)]
createPredicateTable s oracleChecker fullTable =
   case reducePredicates emptyOracle s of
      Match x -> Left x
      _       -> Right (mapMaybe matching oracles)
   where
      -- keep the oracles reducing the value to a terminal
      matching oracle = case reducePredicates oracle s of
         Match x -> Just (oracle,x)
         _       -> Nothing

      oracles = filter oracleChecker (fmap makeOracle predSets)

      -- distinct predicates, sorted
      preds = nub (sort (getPredicates s))

      predSets
         | fullTable = makeFullSets preds
         | otherwise = makeSets preds []

      -- one assignment per number in [0 .. 2^n - 1]: bit i of the number
      -- gives the state of the i-th predicate
      makeFullSets ps = fmap (makeFullSet ps) ([0 .. 2^(length ps) - 1] :: [Word])

      makeFullSet ps n = fmap (setB n) (ps `zip` [0..])

      setB n (p,i) = if testBit n i
         then (p,SetPred)
         else (p,UnsetPred)

      -- every non-empty partial assignment: for each predicate, add it alone,
      -- add it to each assignment built so far, and keep those assignments
      makeSets []     os = os
      makeSets (p:ps) os =
         let ns = [(p,SetPred),(p,UnsetPred)]
         in makeSets ps $ concat
               [ [ [n] | n <- ns ]
               , [ (n:o) | o <- os, n <- ns ]
               , os
               ]