{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiWayIf #-}
module Haskus.Utils.Solver
(
PredState (..)
, PredOracle
, makeOracle
, oraclePredicates
, emptyOracle
, oracleUnion
, predIsSet
, predIsUnset
, predIsUndef
, predIsInvalid
, predIs
, predState
, predAdd
, Constraint (..)
, constraintOptimize
, constraintSimplify
, Rule (..)
, ruleSimplify
, evalsTo
, MatchResult (..)
, Predicated (..)
, createPredicateTable
, initP
, applyP
, resultP
)
where
import Haskus.Utils.Maybe
import Haskus.Utils.Flow
import Haskus.Utils.List
import Haskus.Utils.Map.Strict (Map)
import qualified Haskus.Utils.Map.Strict as Map
import Control.Arrow (first,second)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (pred,length)
-- | State of a predicate, as recorded by a 'PredOracle'.
--
-- NOTE: the constructor order determines the derived 'Ord' instance.
data PredState
   = SetPred      -- ^ The predicate is set: @Predicate p@ simplifies to True
   | UnsetPred    -- ^ The predicate is unset: @Predicate p@ simplifies to False
   | UndefPred    -- ^ The predicate state is unknown (also returned by 'predState' for absent predicates)
   | InvalidPred  -- ^ The predicate is invalid: @IsValid p@ simplifies to False, @Predicate p@ to an error
   deriving (Show,Eq,Ord)

-- | Predicate oracle: maps predicates to their known state.
--
-- Predicates missing from the map are reported as 'UndefPred' by 'predState'.
type PredOracle p = Map p PredState
-- | Test whether a predicate is 'SetPred' in the oracle.
predIsSet :: Ord p => PredOracle p -> p -> Bool
predIsSet oracle = flip (predIs oracle) SetPred

-- | Test whether a predicate is 'UnsetPred' in the oracle.
predIsUnset :: Ord p => PredOracle p -> p -> Bool
predIsUnset oracle = flip (predIs oracle) UnsetPred

-- | Test whether a predicate is 'UndefPred' in the oracle
-- (this includes predicates the oracle does not mention).
predIsUndef :: Ord p => PredOracle p -> p -> Bool
predIsUndef oracle = flip (predIs oracle) UndefPred

-- | Test whether a predicate is 'InvalidPred' in the oracle.
predIsInvalid :: Ord p => PredOracle p -> p -> Bool
predIsInvalid oracle = flip (predIs oracle) InvalidPred

-- | Test whether a predicate has the given state in the oracle.
predIs :: Ord p => PredOracle p -> p -> PredState -> Bool
predIs oracle q expected = expected == predState oracle q
-- | Look up the state of a predicate in the oracle.
--
-- Predicates absent from the oracle are reported as 'UndefPred'.
predState :: Ord p => PredOracle p -> p -> PredState
predState oracle p = maybe UndefPred id (Map.lookup p oracle)
-- | Build an oracle from a list of (predicate, state) pairs.
makeOracle :: Ord p => [(p,PredState)] -> PredOracle p
makeOracle pairs = Map.fromList pairs

-- | List the (predicate, state) pairs whose state is known, i.e. not
-- 'UndefPred', in the order given by 'Map.toList'.
oraclePredicates :: Ord p => PredOracle p -> [(p,PredState)]
oraclePredicates oracle =
   [ entry | entry@(_, st) <- Map.toList oracle, st /= UndefPred ]

-- | Union of two oracles ('Map.union'; left-biased in "Data.Map").
oracleUnion :: Ord p => PredOracle p -> PredOracle p -> PredOracle p
oracleUnion lhs rhs = Map.union lhs rhs

-- | Add the given predicate states to an oracle; the new pairs are the
-- left operand of 'oracleUnion'.
predAdd :: Ord p => [(p,PredState)] -> PredOracle p -> PredOracle p
predAdd cs oracle = makeOracle cs `oracleUnion` oracle

-- | Oracle that knows nothing: every predicate is 'UndefPred'.
emptyOracle :: PredOracle p
emptyOracle = Map.empty
-- | Boolean constraint over predicates of type @p@, with errors of type @e@.
data Constraint e p
   = Predicate p                -- ^ Holds when the predicate is set (error if the predicate is invalid)
   | IsValid p                  -- ^ Holds unless the predicate is 'InvalidPred'
   | Not (Constraint e p)       -- ^ Negation
   | And [Constraint e p]       -- ^ Conjunction
   | Or [Constraint e p]        -- ^ Disjunction
   | Xor [Constraint e p]       -- ^ Exactly one operand holds (see 'constraintOptimize')
   | CBool Bool                 -- ^ Constant
   | CErr (Either String e)     -- ^ Error: 'Left' is used for this module's internal messages
   deriving (Show,Eq,Ord)
-- | Map a function over every predicate of a constraint, keeping its shape.
instance Functor (Constraint e) where
   fmap f = go
      where
         go = \case
            Predicate p -> Predicate (f p)
            IsValid p   -> IsValid (f p)
            Not c       -> Not (go c)
            And cs      -> And (map go cs)
            Or cs       -> Or (map go cs)
            Xor cs      -> Xor (map go cs)
            CBool b     -> CBool b
            CErr e      -> CErr e
-- | Simplify a constraint using the predicate states known by the oracle.
--
-- The constraint is first normalised with 'constraintOptimize', then:
--
--   * 'IsValid' / 'Predicate' on a known predicate become 'CBool' (or 'CErr'
--     when a 'Predicate' refers to an 'InvalidPred');
--   * 'And' short-circuits on a 'False' operand and drops 'True' operands
--     (error operands are kept unless every operand is an error);
--   * 'Or' drops error and 'False' operands and short-circuits on 'True';
--   * 'Xor' becomes an error if any operand simplifies to an error.
constraintSimplify :: (Ord p, Eq p, Eq e) => PredOracle p -> Constraint e p -> Constraint e p
constraintSimplify oracle c = case constraintOptimize c of
   CErr e -> CErr e
   IsValid p -> case predState oracle p of
      UndefPred -> IsValid p
      InvalidPred -> CBool False
      SetPred -> CBool True
      UnsetPred -> CBool True
   Predicate p -> case predState oracle p of
      UndefPred -> Predicate p
      InvalidPred -> CErr (Left "Invalid predicate")
      SetPred -> CBool True
      UnsetPred -> CBool False
   Not c' -> case constraintSimplify oracle c' of
      CBool v -> CBool (not v)
      CErr e -> CErr e
      c'' -> Not c''
   -- NOTE(review): constraintOptimize rewrites @And []@ to @CBool True@, so the
   -- empty-list branch below appears unreachable.
   And cs -> case fmap (constraintSimplify oracle) cs of
      [] -> CErr (Left "Empty And constraint")
      cs' | all (constraintIsBool True) cs' -> CBool True
      cs' | any (constraintIsBool False) cs' -> CBool False
      cs' | all constraintIsError cs' -> CErr (Left "And expression only contains Error constraints")
      -- Drop the neutral True operands; a single survivor replaces the And
      cs' -> case filter (not . constraintIsBool True) cs' of
         [c'] -> c'
         cs'' -> And cs''
   -- Error operands are discarded before the Or is evaluated
   Or cs -> case filter (not . constraintIsError) <| fmap (constraintSimplify oracle) cs of
      [] -> CErr (Left "Empty Or constraint")
      cs' | all (constraintIsBool False) cs' -> CBool False
      cs' | any (constraintIsBool True) cs' -> CBool True
      -- Drop the neutral False operands; a single survivor replaces the Or
      cs' -> case filter (not . constraintIsBool False) cs' of
         [c'] -> c'
         cs'' -> Or cs''
   -- NOTE(review): constraintOptimize rewrites @Xor []@ to @CBool False@, so the
   -- empty-list branch below appears unreachable.
   Xor cs -> case fmap (constraintSimplify oracle) cs of
      cs' | any constraintIsError cs' -> CErr (Left "Xor expression contains Error constraint")
      [] -> CErr (Left "Empty Xor constraint")
      cs' -> constraintOptimize (Xor cs')
   c'@(CBool _) -> c'
-- | @constraintIsBool v c@ holds iff @c@ is exactly the constant @CBool v@.
constraintIsBool :: Bool -> Constraint e p -> Bool
constraintIsBool expected = \case
   CBool actual -> actual == expected
   _            -> False

-- | Holds iff the constraint is an error ('CErr').
constraintIsError :: Constraint e p -> Bool
constraintIsError = \case
   CErr _ -> True
   _      -> False
-- | Collect every predicate mentioned (through 'Predicate' or 'IsValid')
-- anywhere in a constraint.
getConstraintPredicates :: Ord p => Constraint e p -> Set p
getConstraintPredicates (Predicate p) = Set.singleton p
getConstraintPredicates (IsValid p)   = Set.singleton p
getConstraintPredicates (Not c)       = getConstraintPredicates c
getConstraintPredicates (And cs)      = foldMap getConstraintPredicates cs
getConstraintPredicates (Or cs)       = foldMap getConstraintPredicates cs
getConstraintPredicates (Xor cs)      = foldMap getConstraintPredicates cs
getConstraintPredicates (CBool _)     = Set.empty
getConstraintPredicates (CErr _)      = Set.empty
-- | Set of Boolean values a constraint may evaluate to.
--
-- Predicates ('Predicate', 'IsValid') may evaluate to either value.
-- Errors ('CErr') and empty 'And' / 'Or' / 'Xor' yield no value.
--
-- 'And' is certainly False when some operand is certainly False, and
-- certainly True when every operand is certainly True; otherwise both values
-- are possible. 'Or' is the dual. An operand that merely /may/ be False (such
-- as a 'Predicate') does not make an 'And' always False.
--
-- 'Xor' follows the "exactly one operand is True" semantics of
-- 'constraintOptimize'.
getConstraintTerminals :: Constraint e p -> Set Bool
getConstraintTerminals = \case
   CErr _      -> Set.empty
   IsValid _   -> tf
   Predicate _ -> tf
   CBool v     -> Set.singleton v
   Not c       -> Set.map not (getConstraintTerminals c)
   And cs      -> let cs' = fmap getConstraintTerminals cs
                  in if | null cs                   -> Set.empty
                        | any (certainly False) cs' -> Set.singleton False
                        | all (certainly True) cs'  -> Set.singleton True
                        | otherwise                 -> tf
   Or cs       -> let cs' = fmap getConstraintTerminals cs
                  in if | null cs                   -> Set.empty
                        | any (certainly True) cs'  -> Set.singleton True
                        | all (certainly False) cs' -> Set.singleton False
                        | otherwise                 -> tf
   Xor cs      -> let cs' = fmap (Set.toList . getConstraintTerminals) cs
                  in if | null cs   -> Set.empty
                        | otherwise -> xo False cs'
   where
      -- Both Boolean values are possible
      tf = Set.fromList [True,False]

      -- The operand's terminal set is exactly {v}: it can only evaluate to v
      certainly v s = s == Set.singleton v

      -- Walk the Xor operands' terminal lists; the flag records whether an
      -- operand that is certainly True has already been seen.
      xo t [] = Set.singleton t
      xo False ([True]:xs) = xo True xs
      xo True ([True]:_) = Set.singleton False
      xo False ([False]:xs) = xo False xs
      xo True ([False]:xs) = xo True xs
      xo _ ([]:_) = Set.empty
      xo _ _ = tf
-- | Structural normalisation of a constraint, without consulting any oracle.
--
-- Pushes negations inwards (De Morgan, double negation), unwraps
-- single-operand 'And' / 'Or' / 'Xor', and folds constant operands.
-- 'Xor' is treated as "exactly one operand is True": more than one 'True'
-- operand makes it 'False'.
--
-- With these rules @And []@ becomes @CBool True@, while @Or []@ and
-- @Xor []@ become @CBool False@.
constraintOptimize :: Constraint e p -> Constraint e p
constraintOptimize x = case x of
   CErr _ -> x
   Not (CErr e) -> CErr e
   IsValid _ -> x
   Predicate _ -> x
   CBool _ -> x
   Not (IsValid _) -> x
   Not (Predicate _) -> x
   Not (CBool v) -> CBool (not v)
   Not (Not c) -> constraintOptimize c
   -- De Morgan
   Not (Or cs) -> constraintOptimize (And (fmap Not cs))
   Not (And cs) -> constraintOptimize (Or (fmap Not cs))
   -- Keep the negated Xor if it survives optimisation, otherwise negate the result
   Not (Xor cs) -> case constraintOptimize (Xor cs) of
      Xor cs' -> Not (Xor cs')
      r -> constraintOptimize (Not r)
   -- Single-operand connectives reduce to their optimised operand
   And [c] -> constraintOptimize c
   Or [c] -> constraintOptimize c
   -- NOTE(review): every branch below yields a value equal to c'
   Xor [c] -> let c' = constraintOptimize c
              in if | constraintIsBool True c' -> CBool True
                    | constraintIsBool False c' -> CBool False
                    | otherwise -> c'
   And cs -> let cs' = fmap constraintOptimize cs
             in if | any (constraintIsBool False) cs' -> CBool False
                   | all (constraintIsBool True) cs' -> CBool True
                   | otherwise -> And cs'
   Or cs -> let cs' = fmap constraintOptimize cs
            in if | any (constraintIsBool True) cs' -> CBool True
                  | all (constraintIsBool False) cs' -> CBool False
                  | otherwise -> Or cs'
   -- "Exactly one": >1 True is False; one True with all others False is True;
   -- all False is False
   Xor cs -> let cs' = fmap constraintOptimize cs
                 countTrue = length (filter (constraintIsBool True) cs')
                 countFalse = length (filter (constraintIsBool False) cs')
                 countAll = length cs'
             in if | countTrue > 1 -> CBool False
                   | countTrue == 1 && countTrue + countFalse == countAll -> CBool True
                   | countAll == countFalse -> CBool False
                   | otherwise -> Xor cs'
-- | Decision tree whose leaves are terminals of type @a@ and whose branches
-- are guarded by constraints over predicates of type @p@.
data Rule e p a
   = Terminal a                                          -- ^ Result value
   | OrderedNonTerminal [(Constraint e p, Rule e p a)]   -- ^ Guarded alternatives examined in order (see 'ruleReduce')
   | NonTerminal [(Constraint e p, Rule e p a)]          -- ^ Unordered guarded alternatives; several matching terminals diverge
   | Fail e                                              -- ^ Failure
   deriving (Show,Eq,Ord)
-- | Map a function over every terminal of a rule, keeping guards unchanged.
instance Functor (Rule e p) where
   fmap f = \case
         Terminal a            -> Terminal (f a)
         NonTerminal xs        -> NonTerminal (mapAlts xs)
         OrderedNonTerminal xs -> OrderedNonTerminal (mapAlts xs)
         Fail e                -> Fail e
      where
         -- Lazy pattern: same strictness as 'second' on pairs
         mapAlts = fmap (\ ~(guard', sub) -> (guard', fmap f sub))
-- | Simplify a rule using the predicate states known by the oracle.
--
-- Guards are simplified with 'constraintSimplify' and sub-rules recursively.
-- Alternatives whose guard becomes 'False' are dropped. In a 'NonTerminal',
-- an alternative guarded by 'True' whose sub-rule is itself a 'NonTerminal'
-- is flattened into its parent.
ruleSimplify ::
   ( Ord p, Eq e
   ) => PredOracle p -> Rule e p a -> Rule e p a
ruleSimplify oracle = \case
      OrderedNonTerminal alts -> OrderedNonTerminal (simplifyAlts alts)
      NonTerminal alts        -> NonTerminal (concatMap flatten (simplifyAlts alts))
      leaf                    -> leaf
   where
      simplifyAlts alts =
         [ (c', ruleSimplify oracle sub)
         | (c, sub) <- alts
         , let c' = constraintSimplify oracle c
         , not (constraintIsBool False c')
         ]

      flatten (c, NonTerminal nested)
         | constraintIsBool True c = nested
      flatten alt = [alt]
-- | Try to reduce a rule to a single terminal, given an oracle.
--
-- The rule is simplified with 'ruleSimplify' first. Then:
--
--   * an empty non-terminal gives 'NoMatch';
--   * an 'OrderedNonTerminal' reduces its first alternative if that
--     alternative's guard is 'True', skips it if 'False', and otherwise
--     stays undecided ('DontMatch');
--   * a 'NonTerminal' fails if an alternative guarded by 'True' leads to
--     'Fail', and diverges if such alternatives lead to several distinct
--     terminals. It matches only when exactly one terminal is reached through
--     'True' guards and no alternative remains undecided.
ruleReduce :: forall e p a.
   ( Ord p, Eq e, Eq p, Eq a) => PredOracle p -> Rule e p a -> MatchResult e (Rule e p a) a
ruleReduce oracle r = case ruleSimplify oracle r of
   Terminal a -> Match a
   Fail e -> MatchFail [e]
   NonTerminal [] -> NoMatch
   OrderedNonTerminal [] -> NoMatch
   OrderedNonTerminal ((c,x):xs)
      | constraintIsBool True c -> ruleReduce oracle x
      | constraintIsBool False c -> ruleReduce oracle (OrderedNonTerminal xs)
      | otherwise -> DontMatch (OrderedNonTerminal ((c,x):xs))
   NonTerminal rs ->
      let
         -- Alternatives whose guard is already True vs. still undecided
         (matchingRules,mayMatchRules) = partition (constraintIsBool True . fst) rs
         matchingResults = nub $ fmap snd $ matchingRules
         (failingResults,terminalResults,hasNonTerminalResults) = go [] [] False matchingResults
         -- Classify the selected sub-rules (accumulators are built in reverse order)
         go fr tr ntr = \case
            [] -> (fr,tr,ntr)
            (Fail x:xs) -> go (x:fr) tr ntr xs
            (Terminal x:xs) -> go fr (x:tr) ntr xs
            (NonTerminal _:xs) -> go fr tr True xs
            (OrderedNonTerminal _:xs) -> go fr tr True xs
         -- Two or more distinct terminals are reachable
         divergence = case terminalResults of
            (_:_:_) -> True
            _ -> False
      in
      if | not (null failingResults) -> MatchFail failingResults
         | divergence -> MatchDiverge (fmap Terminal terminalResults)
         | hasNonTerminalResults -> DontMatch (NonTerminal rs)
         | otherwise ->
            case (terminalResults,mayMatchRules) of
               ([a], []) -> Match a
               _ -> DontMatch (NonTerminal rs)
-- | Collect every terminal value reachable in a rule.
getRuleTerminals :: Ord a => Rule e p a -> Set a
getRuleTerminals = \case
   Terminal a            -> Set.singleton a
   Fail _                -> Set.empty
   NonTerminal xs        -> foldMap (getRuleTerminals . snd) xs
   OrderedNonTerminal xs -> foldMap (getRuleTerminals . snd) xs

-- | Collect every predicate mentioned by the guards of a rule, recursively.
getRulePredicates :: (Eq p,Ord p) => Rule e p a -> Set p
getRulePredicates = \case
   Terminal _            -> Set.empty
   Fail _                -> Set.empty
   NonTerminal xs        -> foldMap altPredicates xs
   OrderedNonTerminal xs -> foldMap altPredicates xs
   where
      altPredicates (c, sub) = getConstraintPredicates c `Set.union` getRulePredicates sub
-- | Build a constraint that holds exactly when @s@ reduces to the terminal @a@.
--
-- If @s@ already reduces without any predicate, the result is the constant
-- @CBool (x == a)@. Otherwise each predicate assignment produced by
-- 'createPredicateTable' that reduces @s@ to @a@ becomes a conjunction of
-- predicate tests, and these conjunctions are or-ed together. When no
-- assignment reduces @s@ to @a@, the result is @CBool False@.
evalsTo :: (Ord (Pred a), Eq a, Eq (PredTerm a), Eq (Pred a), Predicated a) => a -> PredTerm a -> Constraint e (Pred a)
evalsTo s a = case createPredicateTable s (const True) of
   Left x   -> CBool (x == a)
   Right xs -> orConstraints <| fmap andPredicates
                             <| fmap oraclePredicates
                             <| fmap fst
                             <| filter ((== a) . snd)
                             <| xs
   where
      -- Empty conjunction: the assignment needs no predicate
      andPredicates [] = CBool True
      andPredicates xs = And (concatMap makePred xs)

      -- Empty disjunction: no predicate assignment reduces @s@ to @a@,
      -- so the constraint can never hold
      orConstraints []  = CBool False
      orConstraints [x] = x
      orConstraints xs  = Or xs

      makePred (p, UnsetPred)   = [IsValid p, Not (Predicate p)]
      makePred (p, SetPred)     = [IsValid p, Predicate p]
      makePred (p, InvalidPred) = [Not (IsValid p)]
      -- Undefined predicates impose no constraint. 'oraclePredicates' already
      -- filters them out; this equation only makes the function total.
      makePred (_, UndefPred)   = []
-- | Values that contain predicates and can be reduced to a terminal value
-- given a predicate oracle.
class (Ord (Pred a), Ord (PredTerm a)) => Predicated a where
   -- | Error type
   type PredErr a :: *

   -- | Predicate type
   type Pred a :: *

   -- | Terminal type
   type PredTerm a :: *

   -- | Wrap a terminal value back into the predicated type
   liftTerminal :: PredTerm a -> a

   -- | Reduce to a terminal, if possible, using the given oracle
   reducePredicates :: PredOracle (Pred a) -> a -> MatchResult (PredErr a) a (PredTerm a)

   -- | Simplify using the given oracle
   simplifyPredicates :: PredOracle (Pred a) -> a -> a

   -- | Set of terminal values the value may reduce to
   getTerminals :: a -> Set (PredTerm a)

   -- | Set of predicates the value mentions
   getPredicates :: a -> Set (Pred a)
-- | Rules reduce to their terminal values.
instance (Ord a, Ord p, Eq e, Eq a, Eq p) => Predicated (Rule e p a) where
   type PredErr  (Rule e p a) = e
   type Pred     (Rule e p a) = p
   type PredTerm (Rule e p a) = a

   liftTerminal                = Terminal
   getTerminals rule           = getRuleTerminals rule
   getPredicates rule          = getRulePredicates rule
   simplifyPredicates oracle r = ruleSimplify oracle r
   reducePredicates oracle r   = ruleReduce oracle r
-- | Constraints reduce to Boolean values.
instance (Ord p, Eq e, Eq p) => Predicated (Constraint e p) where
   type PredErr  (Constraint e p) = e
   type Pred     (Constraint e p) = p
   type PredTerm (Constraint e p) = Bool

   liftTerminal       = CBool
   simplifyPredicates = constraintSimplify
   getTerminals       = getConstraintTerminals
   getPredicates      = getConstraintPredicates

   -- A constraint matches once it simplifies to a constant; any other
   -- result (including an error) is returned as undecided
   reducePredicates oracle c =
      let simplified = constraintSimplify oracle c
      in case simplified of
            CBool v -> Match v
            _       -> DontMatch simplified
-- | A pair is reduced component-wise (via 'initP' / 'applyP' / 'resultP');
-- its terminals are all pairs of the components' terminals.
instance forall x y.
   ( Predicated x
   , Predicated y
   , PredErr x ~ PredErr y
   , Pred x ~ Pred y
   ) => Predicated (x,y)
   where
      type PredErr (x,y)  = PredErr x
      type Pred (x,y)     = Pred x
      type PredTerm (x,y) = (PredTerm x, PredTerm y)

      liftTerminal (tx,ty) = (liftTerminal tx, liftTerminal ty)

      simplifyPredicates oracle (x,y) =
         (simplifyPredicates oracle x, simplifyPredicates oracle y)

      getPredicates (x,y) = getPredicates x `Set.union` getPredicates y

      getTerminals (x,y) = Set.fromList
         ((,) <$> Set.toList (getTerminals x) <*> Set.toList (getTerminals y))

      reducePredicates oracle (x,y) =
         resultP ((initP (,) (,) `applyP` reducePredicates oracle x)
                                 `applyP` reducePredicates oracle y)
-- | Result of reducing a predicated value of (non-terminal) type @nt@ to a
-- terminal of type @t@, with errors of type @e@.
data MatchResult e nt t
   = NoMatch            -- ^ Nothing matches
   | Match t            -- ^ Reduced to a single terminal
   | DontMatch nt       -- ^ Undecided yet; carries the (simplified) value
   | MatchFail [e]      -- ^ Failures
   | MatchDiverge [nt]  -- ^ Several results are possible
   deriving (Show,Eq,Ord)
-- | Map over the terminal carried by 'Match'; other results are unchanged.
instance Functor (MatchResult e nt) where
   fmap f (Match a)         = Match (f a)
   fmap _ NoMatch           = NoMatch
   fmap _ (DontMatch nt)    = DontMatch nt
   fmap _ (MatchFail es)    = MatchFail es
   fmap _ (MatchDiverge xs) = MatchDiverge xs
-- | Feed the reduction result of one more argument into a partially applied
-- result (applicative style; see 'initP' and 'resultP').
--
-- In an accumulated 'Match', the first function builds the non-terminal
-- value and the second builds the terminal value. A terminal argument is
-- turned back into a non-terminal with 'liftTerminal' when needed.
--
-- Precedence follows equation order:
--
--   * 'NoMatch' on either side wins;
--   * then 'MatchFail' (failures of both sides are concatenated);
--   * then 'MatchDiverge' (all combinations);
--   * then 'DontMatch'.
applyP ::
   ( Predicated ntb
   ) => MatchResult e (ntb -> nt) (ntb -> nt, PredTerm ntb -> t) -> MatchResult e ntb (PredTerm ntb) -> MatchResult e nt (nt,t)
applyP NoMatch _ = NoMatch
applyP _ NoMatch = NoMatch
applyP (MatchFail xs) (MatchFail ys) = MatchFail (xs++ys)
applyP (MatchFail xs) _ = MatchFail xs
applyP _ (MatchFail ys) = MatchFail ys
applyP (MatchDiverge fs) (MatchDiverge ys) = MatchDiverge [f y | f <- fs, y <- ys]
applyP (MatchDiverge fs) (Match b) = MatchDiverge [f (liftTerminal b) | f <- fs]
applyP (MatchDiverge fs) (DontMatch b) = MatchDiverge [f b | f <- fs]
applyP (DontMatch f) (MatchDiverge ys) = MatchDiverge [f y | y <- ys]
applyP (DontMatch f) (DontMatch b) = DontMatch (f b)
applyP (DontMatch f) (Match b) = DontMatch (f (liftTerminal b))
applyP (Match (fnt,_)) (MatchDiverge ys) = MatchDiverge [fnt y | y <- ys]
applyP (Match (fnt,_)) (DontMatch b) = DontMatch (fnt b)
applyP (Match (fnt,ft)) (Match b) = Match (fnt (liftTerminal b), ft b)
-- | Start an applicative-style reduction by wrapping a non-terminal builder
-- and a terminal builder in a 'Match'.
initP :: nt -> t -> MatchResult e nt (nt,t)
initP nt = Match . (,) nt

-- | Finish an applicative-style reduction, keeping only the terminal part.
resultP :: MatchResult e nt (nt,t) -> MatchResult e nt t
resultP r = snd <$> r
-- | Enumerate the predicate assignments under which a value reduces to a terminal.
--
-- If the value already reduces with the empty oracle, its terminal is
-- returned as 'Left'. Otherwise each predicate used by the (simplified) value
-- is given every state among 'SetPred', 'UnsetPred' and 'UndefPred'
-- (3^n oracles for n predicates). Each oracle accepted by @oracleChecker@
-- under which the value reduces to a 'Match' is returned, paired with the
-- resulting terminal, in 'Right'.
createPredicateTable ::
   ( Ord (Pred a)
   , Eq (Pred a)
   , Eq a
   , Predicated a
   ) => a -> (PredOracle (Pred a) -> Bool) -> Either (PredTerm a) [(PredOracle (Pred a),PredTerm a)]
createPredicateTable s oracleChecker =
   case reducePredicates emptyOracle s of
      Match x -> Left x
      _       -> Right (mapMaybe matching oracles)
   where
      -- Keep an oracle only if the value fully reduces under it
      matching oracle = case reducePredicates oracle s of
         Match x -> Just (oracle,x)
         _       -> Nothing

      oracles  = filter oracleChecker (fmap makeOracle predSets)
      preds    = Set.toList (getPredicates (simplifyPredicates emptyOracle s))
      predSets = makeSets preds [[]]

      -- Cartesian product of the candidate states of every predicate
      makeSets [] os = os
      makeSets (p:ps) os = let ns = [(p,SetPred),(p,UnsetPred),(p,UndefPred)]
                           in makeSets ps [(n:o) | o <- os, n <- ns]