{-# LANGUAGE MagicHash #-}
module Bayes.Factor
(
Factor
, Dimensions, dimensions, fromDimensions, hasVarD, mergesD
, Dimensional(..)
, makeFactor
, values
, multiply, sumout, sumouts, condition, conditions, normalize
, eliminate, eliminateSplit, eliminateList
) where
import Data.List
import Data.Maybe
import Data.Semigroup
import GHC.Exts
import qualified Data.Set as S
-- | A factor: the variables it ranges over ('dimensions') together with the
-- tree that stores its values ('getTree').
data Factor = F
  { dimensions :: Dimensions
  , getTree :: Tree
  }
-- | Renders as @name[size],name[size],...: tree@.
instance Show Factor where
  show (F dims tree) = header ++ ": " ++ show tree
    where
      header =
        intercalate
          ","
          [ name ++ "[" ++ show card ++ "]" | (name, card) <- fromDimensions dims ]
-- | Combining two factors multiplies them.
instance Semigroup Factor where
  a <> b = multiply a b
-- | The identity is the factor over no variables holding the single value 1.
-- 'mconcat' is routed through 'multiplyList'.
instance Monoid Factor where
  mempty = makeFactor [] [1]
  mappend a b = a <> b
  mconcat factors = multiplyList factors
-- | Build a factor from @(variable, size)@ pairs and a flat list of values.
--
-- The tree is first built in the order the variables were given and then
-- rearranged so that the stored dimensions are in sorted order.
makeFactor :: [(String, Int)] -> [Double] -> Factor
makeFactor d bs = F (mkD sorted) tree
  where
    sorted = sort d
    givenOrderTree = makeTree (map snd d) bs
    tree = reorder sorted d givenOrderTree
-- | Assemble a flat list of values into a nested tree.
--
-- The sizes are processed from last to first: the leaves are first grouped by
-- the last size, the resulting nodes by the size before it, and so on. Exactly
-- one tree must remain at the end, otherwise @error "makeFactor"@ is raised.
makeTree :: [Int] -> [Double] -> Tree
makeTree is vals = go (reverse is) (map leaf vals)
  where
    go [] [t] = t
    go [] _ = error "makeFactor"
    go (n:ns) ts = go ns (chunk n ts)

    -- Sizes 2 and 3 use the specialised node constructors.
    chunk 2 = groups2
    chunk 3 = groups3
    chunk n = groupsN n
{-# INLINE groups2 #-}
-- | Pair up consecutive trees into 'Two' nodes; a leftover single tree is an
-- error.
groups2 :: [Tree] -> [Tree]
groups2 ts =
  case ts of
    [] -> []
    a : b : more -> Two a b : groups2 more
    _ -> error "groups2"
{-# INLINE groups3 #-}
-- | Group consecutive trees three at a time into 'Three' nodes; one or two
-- leftover trees are an error.
groups3 :: [Tree] -> [Tree]
groups3 ts =
  case ts of
    [] -> []
    a : b : c : more -> Three a b c : groups3 more
    _ -> error "groups3"
{-# INLINE groupsN #-}
-- | Group consecutive trees @n@ at a time into 'Bin' nodes.
--
-- A final chunk that does not contain exactly @n@ trees is rejected with
-- @error "groupsN"@, matching how 'groups2' and 'groups3' treat leftovers.
-- Without this check a value list of the wrong length would silently yield a
-- node with too few children.
groupsN :: Int -> [Tree] -> [Tree]
groupsN n = go
  where
    go [] = []
    go xs
      | length chunk /= n = error "groupsN"
      | otherwise = Bin chunk : go rest
      where
        (chunk, rest) = splitAt n xs
-- | Fix variable @s@ to index @i@: the variable is removed from the
-- dimensions and the tree is restricted to the chosen branch.
condition :: String -> Int -> Factor -> Factor
condition s i (F d t) = F remaining restricted
  where
    remaining = deleteD s d
    restricted = set s i d t
-- | Fix several variables at once, given as @(variable, index)@ pairs.
conditions :: [(String, Int)] -> Factor -> Factor
conditions xs (F d t) = F remaining restricted
  where
    fixedNames = map fst xs
    remaining = deletesD fixedNames d
    restricted = choose xs d t
-- | Multiply a list of factors together.
--
-- For three or more factors the smallest one (by 'size') is multiplied with
-- each of the others in turn; the pairing whose product is smallest replaces
-- the two inputs, and the procedure repeats on the shorter list.
multiplyList :: [Factor] -> Factor
multiplyList factors =
  case factors of
    [] -> mempty
    [a] -> a
    [a, b] -> multiply a b
    _ -> multiplyList (best : rest)
  where
    -- These bindings are only demanded in the three-or-more case.
    smallest : others = sortBy (\x y -> compare (size x) (size y)) factors
    candidates = [ pairWith i | i <- [0 .. length others - 1] ]
    (best, rest) =
      minimumBy (\x y -> compare (size (fst x)) (size (fst y))) candidates
    -- Product of the smallest factor with the i-th other one, plus the
    -- factors that were not involved.
    pairWith i = (multiply smallest partner, before ++ after)
      where
        (before, partner : after) = splitAt i others
-- | Product of two factors: the dimensions are merged and corresponding
-- values are multiplied.
multiply :: Factor -> Factor -> Factor
multiply (F d1 t1) (F d2 t2) = F dims tree
  where
    dims = mergeD d1 d2
    tree = mergeWith (*##) d1 d2 t1 t2
-- | Sum the given variables out of a factor.
--
-- Names that do not occur in the factor are ignored. The tree is walked one
-- level per dimension; at a level whose variable is being summed out, the
-- (already processed) subtrees are added together element-wise.
sumouts :: [String] -> Factor -> Factor
sumouts vs0 (F dt t0) = F (deletesD vs0 dt) (go (trimTrailing flags) t0)
  where
    names = map fst (fromDimensions dt)

    -- One flag per tree level, True where that level is summed out.
    flags = mark (filter (`elem` names) vs0) names

    mark [] _ = []
    mark _ [] = []
    mark pending (name:more)
      | name `elem` pending = True : mark (delete name pending) more
      | otherwise = False : mark pending more

    go [] t = t
    go (summed:deeper) t
      | summed = foldr1 (zipTree (+##)) (subtrees below)
      | otherwise = below
      where
        below = mapSubtrees (go deeper) t

    -- Levels after the last summed-out one need not be visited.
    trimTrailing = reverse . dropWhile not . reverse
-- | Sum a single variable out of a factor.
sumout :: String -> Factor -> Factor
sumout v = sumouts [v]
-- | All values stored in the factor's tree, flattened left to right.
values :: Factor -> [Double]
values factor = treeToList (getTree factor)
-- | Divide every value by the sum of all values in the factor.
normalize :: Factor -> Factor
normalize factor =
  case sum (values factor) of
    D# total -> factor { getTree = mapTree (\v -> v /## total) tree }
  where
    tree = getTree factor
-- | Eliminate one variable from a list of factors.
--
-- The factors mentioning the variable are combined and the variable is summed
-- out (see 'eliminateSplit'); the resulting factor is kept only if its size
-- is greater than 1.
eliminate :: [Factor] -> String -> [Factor]
eliminate fs s
  | size summed <= 1 = untouched
  | otherwise = summed : untouched
  where
    (summed, untouched) = eliminateSplit fs s
-- | Eliminate the given variables one after another, left to right.
--
-- Uses the strict left fold so the intermediate factor lists are evaluated as
-- the fold proceeds rather than accumulating as a chain of thunks.
eliminateList :: [Factor] -> [String] -> [Factor]
eliminateList = foldl' eliminate
-- | Split the factors into those that mention variable @s@ and those that do
-- not. The first group is multiplied together and @s@ is summed out of the
-- product; the second group is returned unchanged.
eliminateSplit :: [Factor] -> String -> (Factor, [Factor])
eliminateSplit fs s = (sumout s (mconcat mentioning), untouched)
  where
    (mentioning, untouched) = partition mentions fs
    mentions f = S.member s (varSet f)
-- | Things that range over a set of named variables.
class Dimensional a where
  -- | Size of the value (for 'Dimensions', the stored element count).
  size :: a -> Int

  -- | The set of variable names involved.
  varSet :: a -> S.Set String

  -- | The variable names as a list; defaults to the elements of 'varSet'.
  vars :: a -> [String]
  vars x = S.toList (varSet x)
-- | A factor's size and variables are those of its 'dimensions'.
instance Dimensional Factor where
  size f = size (dimensions f)
  varSet f = varSet (dimensions f)
-- | Size is the cached count; the variables are the names in the entry list.
instance Dimensional Dimensions where
  size (D total _) = total
  varSet d = S.fromList [ name | (name, _) <- fromDimensions d ]
-- | For a list, sizes are added up and variable sets are unioned.
instance Dimensional a => Dimensional [a] where
  size xs = sum [ size x | x <- xs ]
  varSet xs = S.unions [ varSet x | x <- xs ]
-- | A list of @(variable, size)@ entries together with an 'Int' that 'mkD'
-- fills in with the product of the sizes.
data Dimensions
  = D Int [(String, Int)]
-- | The @(variable, size)@ entries of a 'Dimensions' value.
fromDimensions :: Dimensions -> [(String, Int)]
fromDimensions d =
  case d of
    D _ entries -> entries
-- | Wrap a list of entries, caching the product of their sizes.
mkD :: [(String, Int)] -> Dimensions
mkD xs = D total xs
  where
    total = product [ n | (_, n) <- xs ]
-- | Remove the entry for one variable name, if present.
deleteD :: String -> Dimensions -> Dimensions
deleteD s d = filterD (\name -> name /= s) d
-- | Remove the entries for all of the given variable names.
deletesD :: [String] -> Dimensions -> Dimensions
deletesD xs d = filterD keep d
  where
    keep name = not (name `elem` xs)
-- | Keep only the entries whose variable name satisfies the predicate; the
-- cached size is recomputed.
filterD :: (String -> Bool) -> Dimensions -> Dimensions
filterD p (D _ m) = mkD [ entry | entry@(name, _) <- m, p name ]
-- | Whether the variable occurs in the dimensions.
--
-- The search skips names smaller than @s@ and decides on the first name that
-- is not smaller, so it relies on the entries being in ascending name order.
hasVarD :: String -> Dimensions -> Bool
hasVarD s (D _ xs) =
  case dropWhile (\(x, _) -> x < s) xs of
    (x, _) : _ -> x == s
    [] -> False
-- | Merge two ordered entry lists into one, keeping a single copy of entries
-- that are equal (same name and same size) in both.
mergeD :: Dimensions -> Dimensions -> Dimensions
mergeD (D _ m1) (D _ m2) = mkD (go m1 m2)
  where
    go [] ys = ys
    go xs [] = xs
    go lx@(x:xs) ly@(y:ys)
      | x < y = x : go xs ly
      | x == y = x : go xs ys
      | otherwise = y : go lx ys
-- | Merge a list of dimensions with 'mergeD'.
--
-- An empty list yields the empty dimensions (no variables, size 1) instead of
-- failing; merging with the empty dimensions leaves the other side's entries
-- unchanged, so non-empty inputs give the same result as a right fold without
-- a seed.
mergesD :: [Dimensions] -> Dimensions
mergesD = foldr mergeD (mkD [])
-- | Value tree of a factor. Leaves hold an unboxed double; inner nodes hold
-- their children either in a list ('Bin') or in the strict two- and
-- three-child forms.
data Tree
  = Bin [Tree]
  | Leaf Double#
  | Two !Tree !Tree
  | Three !Tree !Tree !Tree
-- | Build an inner node, choosing 'Two' or 'Three' when the child count
-- allows and 'Bin' otherwise.
bin :: [Tree] -> Tree
bin ts =
  case ts of
    [x, y] -> Two x y
    [x, y, z] -> Three x y z
    _ -> Bin ts
-- | Wrap a boxed 'Double' as a leaf holding the unboxed value.
leaf :: Double -> Tree
leaf d =
  case d of
    D# x -> Leaf x
-- | Leaves show as their number; inner nodes as a parenthesised,
-- comma-separated list of their children.
instance Show Tree where
  show tree =
    case tree of
      Leaf a -> show (D# a)
      _ -> "(" ++ intercalate "," [ show sub | sub <- subtrees tree ] ++ ")"
{-# INLINE mapTree #-}
-- | Apply a function to every leaf value, keeping the tree shape.
mapTree :: (Double# -> Double#) -> Tree -> Tree
mapTree f = go
  where
    go t =
      case t of
        Leaf a -> Leaf (f a)
        _ -> mapSubtrees go t
-- | Flatten a tree into the list of its leaf values, left to right.
treeToList :: Tree -> [Double]
treeToList tree = collect tree []
  where
    -- Prepend the leaves of a tree to an already-collected tail.
    collect (Leaf a) acc = D# a : acc
    collect (Two x y) acc = collect x (collect y acc)
    collect (Three x y z) acc = collect x (collect y (collect z acc))
    collect (Bin ts) acc = foldr collect acc ts
{-# INLINE zipTree #-}
-- | Combine two trees of the same shape leaf by leaf.
zipTree :: (Double# -> Double# -> Double#) -> Tree -> Tree -> Tree
zipTree f = go
  where
    go t1 t2 =
      case (t1, t2) of
        (Leaf a, Leaf b) -> Leaf (f a b)
        _ -> zipSubtrees go t1 t2
{-# INLINE subtrees #-}
-- | The children of an inner node; calling this on a leaf is an error.
subtrees :: Tree -> [Tree]
subtrees t =
  case t of
    Bin xs -> xs
    Two x y -> [x, y]
    Three x y z -> [x, y, z]
    _ -> error "subtrees"
{-# INLINE subtree #-}
-- | The child at the given zero-based index of an inner node. A leaf, or an
-- index outside a 'Two'/'Three' node, raises @error "subtree"@.
subtree :: Int -> Tree -> Tree
subtree i t =
  case t of
    Bin xs -> xs !! i
    Two x y
      | i == 0 -> x
      | i == 1 -> y
    Three x y z
      | i == 0 -> x
      | i == 1 -> y
      | i == 2 -> z
    -- Reached for leaves and when no guard above matched.
    _ -> error "subtree"
{-# INLINE mapSubtrees #-}
-- | Apply a function to each direct child of an inner node, keeping the node
-- kind; calling this on a leaf is an error.
mapSubtrees :: (Tree -> Tree) -> Tree -> Tree
mapSubtrees f t =
  case t of
    Bin xs -> Bin (map f xs)
    Two x y -> Two (f x) (f y)
    Three x y z -> Three (f x) (f y) (f z)
    _ -> error "mapSubtrees"
{-# INLINE zipSubtrees #-}
-- | Combine the direct children of two inner nodes of the same kind pairwise;
-- any other combination of nodes is an error.
zipSubtrees :: (Tree -> Tree -> Tree) -> Tree -> Tree -> Tree
zipSubtrees f t1 t2 =
  case (t1, t2) of
    (Bin xs, Bin ys) -> Bin (zipWith f xs ys)
    (Two x1 x2, Two y1 y2) -> Two (f x1 y1) (f x2 y2)
    (Three x1 x2 x3, Three y1 y2 y3) -> Three (f x1 y1) (f x2 y2) (f x3 y3)
    _ -> error "zipSubtrees"
{-# INLINE mergeWith #-}
-- | Combine two trees whose levels are described by the two dimension lists.
--
-- The plan from 'merges' says, level by level, whether to descend into the
-- left tree only, the right tree only, or both together; once the plan is
-- exhausted the remaining trees are combined with 'zipTree'.
mergeWith :: (Double# -> Double# -> Double#) -> Dimensions -> Dimensions -> Tree -> Tree -> Tree
mergeWith f da db = go plan
  where
    plan = merges (fromDimensions da) (fromDimensions db)

    go steps t1 t2 =
      case steps of
        [] -> zipTree f t1 t2
        TakeLeft : more -> mapSubtrees (\sub -> go more sub t2) t1
        TakeRight : more -> mapSubtrees (\sub -> go more t1 sub) t2
        Merge : more -> zipSubtrees (go more) t1 t2
-- | Walk two ordered entry lists in step and record, for each level of the
-- merged result, which side the level comes from.
merges :: [(String, Int)] -> [(String, Int)] -> [Merge]
merges l1 l2 =
  case (l1, l2) of
    ([], []) -> []
    (_ : rest1, []) -> TakeLeft : merges rest1 []
    ([], _ : rest2) -> TakeRight : merges [] rest2
    (d1 : rest1, d2 : rest2)
      | d1 < d2 -> TakeLeft : merges rest1 l2
      | d1 == d2 -> Merge : merges rest1 rest2
      | otherwise -> TakeRight : merges l1 rest2
-- | One step of a merge plan: the level exists only in the left input, only
-- in the right input, or in both.
data Merge
  = TakeLeft
  | TakeRight
  | Merge
-- | Restrict a tree by fixing a single variable to an index; a one-entry
-- special case of 'choose'.
set :: String -> Int -> Dimensions -> Tree -> Tree
set s i d t = choose [(s, i)] d t
-- | Restrict a tree according to an environment of @(variable, index)@ pairs.
--
-- Each level of the tree corresponds to one entry of the dimensions. Levels
-- whose variable is in the environment are replaced by the selected child;
-- other levels are kept and processed recursively.
choose :: [(String, Int)] -> Dimensions -> Tree -> Tree
choose env d = go (trim choices)
  where
    choices = [ lookup name env | (name, _) <- fromDimensions d ]

    -- Levels after the last fixed variable need not be visited.
    trim = reverse . dropWhile isNothing . reverse

    go [] t = t
    go (c:cs) t =
      case c of
        Nothing -> mapSubtrees (go cs) t
        Just i -> go cs (subtree i t)
-- | Select child @i@ at depth @lev@ (0 is the root level) throughout the
-- tree, removing that level.
setAtLevel :: Int -> Int -> Tree -> Tree
setAtLevel lev i t
  | lev == 0 = subtree i t
  | otherwise = mapSubtrees (setAtLevel (lev - 1) i) t
-- | Rearrange a tree whose levels follow the order @bs@ so that they follow
-- the order @as@.
--
-- When the leading entries agree the level is kept and the rest is reordered
-- below it. Otherwise the wanted variable is located in the old order and one
-- child is built per index by fixing that variable at its old level. A wanted
-- variable missing from the old order is an error.
reorder :: [(String, Int)] -> [(String, Int)] -> Tree -> Tree
reorder as bs t
  | as == bs = t
  | otherwise =
      case (as, bs) of
        (p : rest, o : old)
          | p == o -> mapSubtrees (reorder rest old) t
        ((s, n) : rest, _) ->
          case findIndex ((== s) . fst) bs of
            Just l ->
              let remaining = filter ((/= s) . fst) bs
               in bin [ reorder rest remaining (setAtLevel l i t) | i <- [0 .. n - 1] ]
            Nothing -> error "invalid reordering"
        _ -> error "invalid reordering"