{-# LANGUAGE DataKinds, TypeFamilies, TemplateHaskell, PolyKinds, ExistentialQuantification, ScopedTypeVariables, UndecidableInstances, TypeOperators, UndecidableSuperClasses, GADTs, PartialTypeSignatures, RankNTypes, FlexibleInstances, MultiParamTypeClasses, FlexibleContexts #-} module Data.HappyTree where import Data.Singletons.Prelude import Data.Singletons.TH import qualified Generics.SOP as SOP import Data.Constraint import Data.List import Data.Ord import Data.Maybe import Safe $(singletons [d| revAppend :: [a] -> [a] -> [a] revAppend [] x = x revAppend (x:xs) n = revAppend xs (x:n) selElemAux :: [a] -> [a] -> [(a, [a])] selElemAux l [] = [] selElemAux l (r:rs) = (r, revAppend l rs) : selElemAux (r:l) rs selElem :: [a] -> [(a, [a])] selElem = selElemAux [] |]) takeElemAux :: [a] -> [a] -> [([a], a, [a])] takeElemAux l [] = [] takeElemAux l (r:rs) = (reverse l, r, rs) : takeElemAux (r:l) rs takeElem :: [a] -> [([a], a, [a])] takeElem = takeElemAux [] data SelElemTypeAux a = SelElemTypeAux (Fst a) (SOP.NP SOP.I (Snd a)) type SelElemAuxType a b = SOP.NP SelElemTypeAux (SelElemAux a b) type SelElemType a = SelElemAuxType '[] a npRevAppend :: SOP.NP f a -> SOP.NP f b -> SOP.NP f (RevAppend a b) npRevAppend SOP.Nil x = x npRevAppend (x SOP.:* y) z = npRevAppend y (x SOP.:* z) npSelElemAux :: SOP.NP SOP.I a -> SOP.NP SOP.I b -> SelElemAuxType a b npSelElemAux _ SOP.Nil = SOP.Nil npSelElemAux x (SOP.I y SOP.:* z) = SelElemTypeAux y (npRevAppend x z) SOP.:* npSelElemAux (SOP.I y SOP.:* x) z npSelElem :: SOP.NP SOP.I a -> SelElemAuxType '[] a npSelElem = npSelElemAux SOP.Nil dictSList :: SOP.SList a -> Dict (SOP.SListI a) dictSList SOP.SNil = Dict dictSList SOP.SCons = Dict sListCons :: Proxy a -> SOP.SList b -> SOP.SList (a:b) sListCons _ x = dictSList x `withDict` SOP.SCons npAppend :: SOP.NP f a -> SOP.NP f b -> SOP.NP f (a :++ b) npAppend SOP.Nil x = x npAppend (x SOP.:* y) z = x SOP.:* npAppend y z npToSList :: SOP.NP f a -> SOP.SList a npToSList SOP.Nil = SOP.SNil npToSList (_ SOP.:* x) = sListCons Proxy (npToSList x) newtype SplitOnAux a b c = SplitOnAux { runSplitOnAux :: DecisionTree (c :++ a) b } data SplitOn (b :: *) (x :: (*, [*])) = forall (c :: [[*]]) . SOP.SListI c => SplitOn (Fst x -> SOP.SOP SOP.I c) (SOP.NP (SplitOnAux (Snd x) b) c) data DecisionTree (a :: [*]) (b :: *) = Leaf (SOP.NP SOP.I a -> b) | Split (SOP.NS (SplitOn b) (SelElem a)) eval :: DecisionTree a b -> SOP.NP SOP.I a -> b eval (Leaf f) x = f x eval (Split f) x = dictSList (npToSList sex) `withDict` SOP.hcollapse (SOP.hzipWith onTree sex f) where sex = npSelElem x onTree :: SelElemTypeAux c -> SplitOn b c -> SOP.K b _ onTree (SelElemTypeAux a b) (SplitOn c d) = let SOP.SOP e = c a in SOP.K $ SOP.hcollapse $ SOP.hzipWith (\(SplitOnAux f) g -> SOP.K $ eval f (npAppend g b)) d e entropy :: Ord a => [a] -> Double entropy x = sum $ map (\y -> let py = fromIntegral (length y) / lenx in -py * log py) $ group $ sort x where lenx = fromIntegral $ length x :: Double data IndexAux (l :: k) (r :: k) = l ~ r => IndexAux newtype Index (l :: [k]) (x :: k) = Index { runIndex :: SOP.NS (IndexAux x) l } fromIndex :: SOP.SListI l => SOP.NP f l -> Index l x -> f x fromIndex l (Index r) = SOP.hcollapse $ SOP.hzipWith (\x IndexAux -> SOP.K x) l r class GetIndex l x where getIndex :: Proxy l -> Proxy x -> Index l x instance {-# OVERLAPPABLE #-} GetIndex r x => GetIndex (l:r) x where getIndex _ _ = Index $ SOP.S $ runIndex $ getIndex Proxy Proxy instance {-# OVERLAPPING #-} GetIndex (l:r) l where getIndex _ _ = Index $ SOP.Z IndexAux newtype SplitFunAuxAux b d = SplitFunAuxAux { runSplitFunAuxAux :: [(SOP.NP SOP.I d, b)] } data SplitFun env a b = forall (c :: [[*]]) . (SOP.SListI2 c, SOP.All2 (GetIndex env) c) => SplitFun (a -> SOP.SOP SOP.I c) (SOP.NP (SplitFunAuxAux b) c) unitSplitFun :: [b] -> SplitFun env a b unitSplitFun b = SplitFun (const $ SOP.SOP (SOP.Z SOP.Nil)) (SplitFunAuxAux (map (\x -> (SOP.Nil, x)) b) SOP.:* SOP.Nil) splitFunP :: (SOP.SListI2 c, SOP.All2 (GetIndex env) c) => Proxy c -> (a -> SOP.SOP SOP.I c) -> SOP.NP (SplitFunAuxAux b) c -> SplitFun env a b splitFunP _ = SplitFun data SplitStrat (env :: [*]) a = SplitStrat (forall b . [(a, b)] -> [SplitFun env a b]) runSplitFun (SplitStrat f) x = f x type SplitStrats cur env = SOP.NP (SplitStrat env) cur instance Monoid (SplitStrat env a) where mempty = SplitStrat (const []) SplitStrat l `mappend` SplitStrat r = SplitStrat (\b -> l b ++ r b) splitStaticAux :: SplitStrat env a -> Proxy (GetIndex env) splitStaticAux _ = Proxy splitStatic :: (SOP.SListI2 c, SOP.All2 (GetIndex env) c) => (a -> SOP.SOP SOP.I c) -> SplitStrat env a splitStatic split = res where res = SplitStrat $ \x -> [SplitFun split (foldl join def $ map (\(a, b) -> SOP.hexpand (SplitFunAuxAux []) $ SOP.hmap (\c -> SplitFunAuxAux [(c, b)]) $ SOP.unSOP $ split a) x)] join :: SOP.SListI2 c => SOP.NP (SplitFunAuxAux b) c -> SOP.NP (SplitFunAuxAux b) c -> SOP.NP (SplitFunAuxAux b) c join = SOP.hzipWith (\(SplitFunAuxAux l) (SplitFunAuxAux r) -> SplitFunAuxAux $ l ++ r) def :: SOP.SListI2 c => SOP.NP (SplitFunAuxAux b) c def = SOP.hpure $ SplitFunAuxAux [] splitOrd :: (Ord a, GetIndex env a) => SplitStrat env a splitOrd = SplitStrat $ map (\(x, y, z) -> splitFunP (Proxy :: Proxy ['[a], '[], '[a]]) (\a -> SOP.SOP $ case a `compare` fst (head y) of LT -> SOP.Z $ SOP.I a SOP.:* SOP.Nil EQ -> SOP.S $ SOP.Z SOP.Nil GT -> SOP.S $ SOP.S $ SOP.Z $ SOP.I a SOP.:* SOP.Nil) (SplitFunAuxAux (map func $ concat x) SOP.:* SplitFunAuxAux (map (\(_, a) -> (SOP.Nil, a)) y) SOP.:* SplitFunAuxAux (map func $ concat z) SOP.:* SOP.Nil)) . takeElem . groupBy (\(l, _) (r, _) -> l == r) . sortBy (comparing fst) where func (a, b) = (SOP.I a SOP.:* SOP.Nil, b) splitStructure :: (SOP.Generic a, SOP.SListI2 (SOP.Code a), SOP.All2 (GetIndex env) (SOP.Code a)) => SplitStrat env a splitStructure = splitStatic SOP.from getIndex2 :: SOP.All (GetIndex l) r => SOP.SList r -> SOP.NP (Index l) r getIndex2 SOP.SNil = SOP.Nil getIndex2 SOP.SCons = getIndex Proxy Proxy SOP.:* getIndex2 SOP.sList mode :: Ord a => [a] -> a mode = head . maximumBy (comparing length) . group . sort nMinOnAux :: Ord b => (forall x . f x -> b) -> SOP.NP f a -> Maybe (b, SOP.NS f a) nMinOnAux fb SOP.Nil = Nothing nMinOnAux fb (l SOP.:* r) = let lb = fb l in case nMinOnAux fb r of Nothing -> Just (lb, SOP.Z l) Just (rb, rs) -> case lb `compare` rb of GT -> Just (rb, SOP.S rs) _ -> Just (lb, SOP.Z l) nMinOn :: Ord b => (forall x . f x -> b) -> SOP.NP f a -> Maybe (SOP.NS f a) nMinOn f = fmap snd . nMinOnAux f data SelElemTypeAuxIndex env a = SelElemTypeAuxIndex (Index env (Fst a)) (SOP.NP (Index env) (Snd a)) selElemTypeAuxIndex :: SOP.NP (Index env) a -> SOP.NP (Index env) b -> SOP.NP (SelElemTypeAuxIndex env) (SelElemAux a b) selElemTypeAuxIndex _ SOP.Nil = SOP.Nil selElemTypeAuxIndex x (y SOP.:* z) = SelElemTypeAuxIndex y (npRevAppend x z) SOP.:* selElemTypeAuxIndex (y SOP.:* x) z fromSF :: SplitFun (env :: [*]) a b -> Proxy (SOP.All (GetIndex env)) fromSF _ = Proxy newtype BuildAuxAux a b = BuildAuxAux { runBuildAuxAux :: [(a, Fst b, SOP.NP SOP.I (Snd b))] } data Score = Destructing | Deciding Double deriving Eq instance Ord Score where Destructing `compare` Destructing = EQ Destructing `compare` _ = LT _ `compare` Destructing = GT Deciding l `compare` Deciding r = l `compare` r newtype WithScore b x = WithScore { runWithScore :: (Score, (SplitOn b x)) } buildTree :: (SOP.SListI env, Ord b) => SplitStrats env env -> SOP.NP (Index env) (Snd a1) -> b -> SplitFun env (Fst a1) (b, SOP.NP SOP.I (Snd a1)) -> (Score, SplitOn b a1) buildTree sf i def sfa@(SplitFun x y) = (if (<=1) $ length $ filter not $ SOP.hcollapse $ SOP.hmap (\(SplitFunAuxAux z) -> SOP.K $ null z) y then Destructing else Deciding $ sum $ SOP.hcollapse $ SOP.hmap (\(SplitFunAuxAux z) -> SOP.K $ (fromIntegral (length z)*) $ entropy $ map (fst . snd) z) y, SplitOn x (SOP.hcmap (fromSF sfa) (\(SplitFunAuxAux z) -> if length z == 0 then SplitOnAux $ Leaf $ const def else let j = fst $ head z in SplitOnAux $ buildAux (getIndex2 (npToSList j) `npAppend` i) sf (map (\(c, (d, e)) -> (c `npAppend` e, d)) z) (mode $ map (fst . snd) z)) y)) buildAux :: (SOP.SListI env, Ord b) => SOP.NP (Index env) a -> SplitStrats env env -> [(SOP.NP SOP.I a, b)] -> b -> DecisionTree a b buildAux _ sf [] def = Leaf $ const def buildAux i sf x@((l, r):_) def = case dictSList $ npToSList $ npSelElem $ l of Dict -> if length (group (map snd x)) == 1 then Leaf $ const $ r else let a = map (\(l, r) -> SOP.hmap (\(SelElemTypeAux a b) -> BuildAuxAux [(r, a, b)]) $ npSelElem l) x b = foldl (SOP.hzipWith (\(BuildAuxAux l) (BuildAuxAux r) -> BuildAuxAux (l ++ r))) (SOP.hpure (BuildAuxAux [])) a in fromMaybe (Leaf $ const def) $ fmap (Split . SOP.hmap (\(WithScore (_, t)) -> t)) $ nMinOn (\(WithScore (s, _)) -> s) $ SOP.hzipWith (\(SelElemTypeAuxIndex c d) (BuildAuxAux e) -> WithScore $ minimumByDef (buildTree sf d def $ unitSplitFun $ map (\(f, g, h) -> (f, h)) e) (comparing fst) $ map (buildTree sf d def) $ runSplitFun (fromIndex sf c) $ map (\(f, g, h) -> (g, (f, h))) e) (selElemTypeAuxIndex SOP.Nil i) b build :: (SOP.All (GetIndex env) a, SOP.SListI env, Ord b) => SplitStrats env env -> [(SOP.NP SOP.I a, b)] -> b -> DecisionTree a b build sf [] def = Leaf $ const def build sf x@((np, _):_) _ = buildAux (getIndex2 $ npToSList $ np) sf (sortBy (comparing snd) x) (mode $ map snd x)