module Idris.Elab.AsPat(desugarAs) where
import Idris.AbsSyntax
import Idris.Core.TT
import Control.Monad.State.Strict
import Data.Generics.Uniplate.Data (transformM, universe)
desugarAs :: PTerm -> PTerm -> [PDecl] -> (PTerm, PTerm, [PDecl])
desugarAs lhs rhs whereblock
= (lhs', bindOnRight rhs pats, map (bindInWhereDecl pats) whereblock)
where
(lhs', pats) = runState (collectAs (replaceUnderscore lhs)) []
bindOnRight :: PTerm -> [(Name, FC, PTerm)] -> PTerm
bindOnRight = foldr (\(n, fc, tm) -> PLet fc RigW n NoFC Placeholder tm)
bindInWhereDecl :: [(Name, FC, PTerm)] -> PDecl -> PDecl
bindInWhereDecl pats (PClauses fc opts n clauses)
= PClauses fc opts n $ map (bindInWhereClause pats) clauses
bindInWhereDecl _ d = d
bindInWhereClause :: [(Name, FC, PTerm)] -> PClause -> PClause
bindInWhereClause pats (PClause fc n lhs ws rhs wb)
= let bound = [ n | (PRef _ _ n) <- universe lhs ]
pats' = filter (not . (`elem` bound) . \(n,_,_) -> n) pats
rhs' = bindOnRight rhs pats' in
PClause fc n lhs ws rhs' $ map (bindInWhereDecl pats') wb
bindInWhereClause _ c = c
collectAs :: PTerm -> State [(Name, FC, PTerm)] PTerm
collectAs (PAs fc n tm) = do tm' <- collectAs tm
pats <- get
put (pats ++ [(n, fc, tm')])
return tm'
collectAs (PApp fc t as)
= do as_tm <- mapM collectAs (map getTm as)
let as' = zipWith (\a tm -> a { getTm = tm }) as as_tm
return (PApp fc t as')
collectAs tm@(PAlternative ns (ExactlyOne d) (a : as))
= do a' <- collectAs a
pats <- get
as' <- mapM collectAs as
put pats
return (PAlternative ns (ExactlyOne d) (a' : as'))
collectAs x = return x
replaceUnderscore :: PTerm -> PTerm
replaceUnderscore tm = evalState (transformM (underAs replaceUnderscore') tm) 0
where
underAs :: (PTerm -> State Int PTerm) -> PTerm -> State Int PTerm
underAs f (PAs fc n tm) = PAs fc n <$> transformM f tm
underAs f x = return x
fresh :: State Int Name
fresh = modify (+1) >> flip sMN "underscorePatVar" <$> get
replaceUnderscore' :: PTerm -> State Int PTerm
replaceUnderscore' Placeholder = PRef emptyFC [] <$> fresh
replaceUnderscore' tm = return tm