-- | Compiles the branches of a case expression into a 'Match' decision tree.
--
-- The entry point is 'caseToMatch'. 'matchSubst' renames variables inside an
-- already-built tree.
module Cases (caseToMatch, Match (..), Clause (..), matchSubst) where

import Control.Arrow (first)
import Control.Monad (liftM,foldM)
import Data.List (groupBy,sortBy,lookup)
import Data.Maybe (fromMaybe)

import Ast
import Guid
import Substitute

-- | Build the decision tree for a list of @(pattern, body)@ branches.
--
-- A fresh variable name is generated for the value being inspected, every
-- pattern is wrapped in a singleton list (one column per variable), and the
-- default for the whole tree is 'Fail'.
caseToMatch :: [(Pattern, Expr)] -> GuidCounter Match
caseToMatch patterns = do
    v <- newVar
    match [v] (map (first (:[])) patterns) Fail

-- | Generate a fresh variable name of the form @"case" ++ show n@, where @n@
-- comes from 'guid'.
newVar :: GuidCounter String
newVar = do
    n <- guid
    return $ "case" ++ show n

-- | A decision tree produced by 'match'.
data Match
    = Match String [Clause] Match
      -- ^ Variable being inspected, one 'Clause' per constructor, and the
      -- default tree (built by 'matchCon').
    | Break
      -- ^ Default used inside constructor clauses (see 'matchCon').
    | Fail
      -- ^ Default used for the whole tree (see 'caseToMatch').
    | Other Expr
      -- ^ A branch body.
    | Seq [Match]
      -- ^ Several trees in order; 'match' ends the list with the default.
    deriving Show

-- | One constructor alternative: constructor name, the fresh variable names
-- generated for its sub-patterns, and the tree for the remaining patterns.
data Clause = Clause String [String] Match
    deriving Show

-- | Rename variables in a decision tree according to @(old, new)@ pairs.
--
-- Inspected variables and clause variables are looked up in the pair list
-- and left alone when absent. Branch bodies get @subst old (Var new)@ applied
-- for every pair, the last pair being applied first.
matchSubst :: [(String,String)] -> Match -> Match
matchSubst _ Break = Break
matchSubst _ Fail = Fail
matchSubst pairs (Seq ms) = Seq (map (matchSubst pairs) ms)
matchSubst pairs (Other e) = Other (foldr rename e pairs)
  where
    rename (x, y) = subst x (Var y)
matchSubst pairs (Match n cs def) =
    Match (varSubst n) (map clauseSubst cs) (matchSubst pairs def)
  where
    varSubst v = fromMaybe v (lookup v pairs)
    clauseSubst (Clause c vs sub) =
        Clause c (map varSubst vs) (matchSubst pairs sub)

-- | True when the first pattern of a row is a constructor pattern.
isCon :: ([Pattern], a) -> Bool
isCon (PData _ _ : _, _) = True
isCon _ = False

-- | True when the row does not start with a constructor pattern. This
-- includes rows whose pattern list is empty.
isVar :: ([Pattern], a) -> Bool
isVar p = not (isCon p)

-- | Compile rows of patterns against a list of variables, one pattern column
-- per variable, falling back to the given default tree.
--
-- With no variables left the rows are turned into 'Other' bodies. Otherwise
-- the first column decides which strategy is used: all variables/wildcards
-- ('matchVar'), all constructors ('matchCon'), or a mixture ('matchMix').
match :: [String] -> [([Pattern],Expr)] -> Match -> GuidCounter Match
match [] [] def = return def
match [] [([],e)] Fail = return $ Other e
match [] [([],e)] Break = return $ Other e
match [] cs def = return $ Seq (map (Other . snd) cs ++ [def])
match vs cs def
    | all isVar cs = matchVar vs cs def
    | all isCon cs = matchCon vs cs def
    | otherwise = matchMix vs cs def

-- | First column holds only variables and wildcards: drop the column,
-- substituting the inspected variable for each bound pattern variable in the
-- row's body.
--
-- Only called from 'match', which guarantees a non-empty variable list.
matchVar :: [String] -> [([Pattern],Expr)] -> Match -> GuidCounter Match
matchVar (v:vs) cs def = match vs (map subVar cs) def
  where
    subVar (PVar x : ps, e) = (ps, subst x (Var v) e)
    subVar (PAnything : ps, e) = (ps, e)
    subVar _ =
        error "Cases.matchVar: row does not start with a variable or wildcard pattern"
matchVar [] _ _ = error "Cases.matchVar: no variable left to match against"

-- | First column holds only constructor patterns: build one 'Clause' per
-- distinct constructor name.
--
-- 'groupBy' only merges adjacent rows, so the rows are sorted by constructor
-- name first; 'sortBy' is stable, so rows sharing a constructor keep their
-- relative order. Each clause is compiled with 'Break' as its default.
--
-- Only called from 'match', which guarantees a non-empty variable list and
-- that every row starts with a constructor pattern.
matchCon :: [String] -> [([Pattern],Expr)] -> Match -> GuidCounter Match
matchCon (v:vs) cs def =
    liftM (\clauses -> Match v clauses def) (mapM toClause groups)
  where
    groups = groupBy (withName (==)) $ sortBy (withName compare) cs

    -- Compare two rows by the constructor name of their first pattern.
    withName f (PData n1 _ : _, _) (PData n2 _ : _, _) = f n1 n2
    withName _ _ _ =
        error "Cases.matchCon: row does not start with a constructor pattern"

    -- Every group produced by groupBy is non-empty; its first row names the
    -- constructor shared by the whole group.
    toClause rows@((PData name _ : _, _) : _) =
        matchClause name (v:vs) rows Break
    toClause _ =
        error "Cases.matchCon: group does not start with a constructor pattern"
matchCon [] _ _ = error "Cases.matchCon: no variable left to match against"

-- | Build the 'Clause' for one constructor.
--
-- One fresh variable is generated per sub-pattern of the first row's
-- constructor pattern; those variables replace the head of the variable
-- list, and every row has its constructor's sub-patterns spliced in front of
-- its remaining patterns.
--
-- NOTE(review): the number of fresh variables is taken from the first row
-- only; presumably every row for the same constructor has the same number of
-- sub-patterns -- confirm against the producer of 'Pattern' values.
matchClause :: String -> [String] -> [([Pattern],Expr)] -> Match -> GuidCounter Clause
matchClause c (_:vs) cs def = do
    vs' <- mapM (const newVar) subPatterns
    liftM (Clause c vs') (match (vs' ++ vs) (map flatten cs) def)
  where
    subPatterns = case cs of
        (PData _ ps : _, _) : _ -> ps
        _ -> error "Cases.matchClause: first row does not start with a constructor pattern"

    flatten (PData _ ps' : ps, e) = (ps' ++ ps, e)
    flatten _ =
        error "Cases.matchClause: row does not start with a constructor pattern"
matchClause _ [] _ _ = error "Cases.matchClause: no variable left to match against"

-- | First column mixes constructor and non-constructor patterns.
--
-- The rows are split into maximal runs of consecutive rows that agree on
-- 'isCon'. The runs are compiled from the last to the first, each one using
-- the tree built for the runs after it as its default.
matchMix :: [String] -> [([Pattern],Expr)] -> Match -> GuidCounter Match
matchMix vs cs def = foldM compileRun def (reverse runs)
  where
    runs = groupBy (\p1 p2 -> isCon p1 == isCon p2) cs
    compileRun fallback rows = match vs rows fallback