module LambdaCube.Compiler.InferMonad where
import Data.Monoid
import Data.List
import qualified Data.Set as Set
import qualified Data.Map as Map
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.Writer
import Control.Arrow hiding ((<+>))
import LambdaCube.Compiler.DeBruijn
import LambdaCube.Compiler.Pretty hiding (braces, parens)
import LambdaCube.Compiler.DesugaredSource hiding (getList)
import LambdaCube.Compiler.Parser (ParseWarning (..))
import LambdaCube.Compiler.Core
-- | Errors reported while checking a module.
--
-- Every constructor except 'ErrorMsg' carries the source location(s) it is
-- about; 'errorRange_' extracts them.
data ErrorMsg
    = ErrorMsg Doc
      -- ^ Pre-rendered free-form message with no source location.
    | ECantFind SName SI
      -- ^ A name that is not bound in the global environment, with the
      -- location supplied by the caller of the lookup.
    | ETypeError Doc SI
      -- ^ A type error: rendered details plus the offending location.
    | ERedefined SName SI SI
      -- ^ A name defined twice: the name, the new definition site and the
      -- previously recorded definition site.
-- | The source locations an error refers to, in constructor order.
-- A plain 'ErrorMsg' has none.
errorRange_ :: ErrorMsg -> [SI]
errorRange_ (ErrorMsg _) = []
errorRange_ (ECantFind _ loc) = [loc]
errorRange_ (ETypeError _ loc) = [loc]
errorRange_ (ERedefined _ here there) = [here, there]
-- | Human-readable rendering of checker errors.
instance PShow ErrorMsg where
    pShow (ErrorMsg doc) = doc
    pShow (ECantFind name loc) =
        "can't find:" <+> text name <+> "in" <+> pShow loc
    pShow (ETypeError details loc) =
        "type error:" <+> details <$$> "in" <+> pShow loc
    pShow (ERedefined name here there) =
        "already defined" <+> text name <+> "at" <+> pShow here
            <$$> "and at" <+> pShow there
-- | An entry in the log written while checking a module.
data Info
    = Info Range Doc
      -- ^ A document attached to a source range; these are grouped per range
      -- and listed under "tooltips".
    | IType SIName Exp
      -- ^ The type recorded for a newly added top-level name.
    | ITrace String String
      -- ^ A labelled trace message (label, text).
    | IError ErrorMsg
      -- ^ An error that was logged before being thrown.
    | ParseWarning ParseWarning
      -- ^ A warning of the imported 'ParseWarning' type.
-- | Human-readable rendering of a single log entry.
instance PShow Info where
    pShow (Info range doc) = nest 4 $ shortForm (pShow range) <$$> doc
    pShow (IType name ty) = shAnn (pShow name) (pShow ty)
    -- NOTE(review): if '<+>' inserts a separating space, the label and the
    -- text end up separated by two spaces; kept verbatim since output may be
    -- compared textually.
    pShow (ITrace label msg) = text label <> ": " <+> text msg
    pShow (IError err) = "!" <> pShow err
    pShow (ParseWarning warning) = pShow warning
-- | Ranges of all logged errors; locations that are not ranges are skipped.
errorRange infos = do
    IError err <- infos
    RangeSI range <- errorRange_ err
    return range
-- | The log accumulated by the writer layer of 'IM'.
type Infos = [Info]
-- | Log an error as an 'IError' entry, then throw it.
--
-- In 'IM' the writer sits beneath 'ExceptT', so the logged entry survives
-- even though the computation aborts.
throwError' err = do
    tell [IError err]
    throwError err
-- | A one-element 'Info' log for a location that has a range; an empty log
-- for any other location.
mkInfoItem loc doc = case loc of
    RangeSI range -> [Info range doc]
    _ -> mempty
-- | All log entries as documents: a "trace" section, followed by the
-- "tooltips" and "warnings" sections produced by 'listAllInfos''.
--
-- Tooltips are only listed for ranges in file @f@ (none when @f@ is
-- 'Nothing').
listAllInfos f m = traceSection ++ listAllInfos' f m
  where
    traceSection = titled "trace" (listTraceInfos m)

    -- An empty section is dropped entirely, header included.
    titled _ [] = []
    titled title docs = ("------------" <+> title) : docs
-- | The "tooltips" and "warnings" sections of the info listing.
--
-- Tooltips are grouped per source range and only shown for ranges whose
-- file equals @f@. With @f = Nothing@ no tooltips are produced. An empty
-- section is dropped, header included.
listAllInfos' f m = tooltips ++ warnings
  where
    tooltips = titled "tooltips"
        [ nest 4 $ shortForm $ showRangeWithoutFileName r <$$> hsep (intersperse "|" items)
        | (r, items) <- listTypeInfos m
        , any (rangeFile r ==) f
        ]

    warnings = titled "warnings" [pShow w | ParseWarning w <- m]

    titled _ [] = []
    titled title docs = ("------------" <+> title) : docs
-- | Every entry except range documents ('Info') and parser warnings,
-- rendered and wrapped in 'DResetFreshNames'.
listTraceInfos m = map (DResetFreshNames . pShow) (filter isTrace m)
  where
    isTrace Info{} = False
    isTrace ParseWarning{} = False
    isTrace _ = True
-- | Range documents grouped by range, in ascending range order. Within a
-- range, documents keep the order in which they were logged.
listTypeInfos m = Map.toList $ Map.fromListWith (flip (<>))
    -- fromListWith passes (new, old); flipping keeps earlier entries first.
    [ (r, [DResetFreshNames doc]) | Info r doc <- m ]
-- | Rendered errors grouped by every range they refer to, in ascending range
-- order. Within a range, errors keep their logging order.
listErrors m = Map.toList $ Map.fromListWith (flip (<>))
    -- fromListWith passes (new, old); flipping keeps earlier entries first.
    [ (r, [DResetFreshNames (pShow err)])
    | IError err <- m
    , RangeSI r <- errorRange_ err
    ]
-- | Parser warnings that have a usable range, grouped by range in ascending
-- order. Within a range, warnings keep their logging order. Warnings of
-- other shapes, or without a range, are dropped.
listWarnings m = Map.toList $ Map.fromListWith (flip (<>))
    [ (r, [DResetFreshNames msg])
    | ParseWarning (getRangeAndMsg -> Just (r, msg)) <- m
    ]
  where
    getRangeAndMsg (Unreachable r) = Just (r, "Unreachable")
    getRangeAndMsg w@(Uncovered (getRange . sourceInfo -> Just r) _) = Just (r, pShow w)
    getRangeAndMsg _ = Nothing
-- | Log @t@ (rendered and wrapped in @DTypeNamespace True@) as a range
-- document at @si@'s location. Nothing is logged if that location has no
-- range.
tellType si t = tell (mkInfoItem (sourceInfo si) (DTypeNamespace True (pShow t)))
-- | Top-level environment: for each name, its expression, its type and the
-- source location of its definition.
type GlobalEnv = Map.Map SName (Exp, Type, SI)
-- | The starting environment. It binds only @'Type@, whose expression and
-- type are both 'TType'.
initEnv :: GlobalEnv
initEnv = Map.singleton "'Type" (TType, TType, debugSI "source-of-Type")
-- | The inference monad: 'ErrorMsg' failures, over a read-only pair of the
-- 'Extensions' value and the global environment, over an 'Infos' log, over
-- a base monad @m@.
-- Because 'WriterT' is below 'ExceptT', entries logged before an error is
-- thrown are kept.
type IM m = ExceptT ErrorMsg (ReaderT (Extensions, GlobalEnv) (WriterT Infos m))
-- | Turn an environment entry into an 'ET'. The name and the definition
-- site are ignored.
expAndType _ (e, t, _) = ET e t
-- | Look a name up in an environment.
--
-- For a 'Ticked' name, the name is first tried as given and then as the
-- inner name bound by the 'Ticked' pattern.
lookupName s m = expAndType s <$> found
  where
    found = case s of
        Ticked s' -> Map.lookup s m `mplus` Map.lookup s' m
        _ -> Map.lookup s m
-- | Resolve @s@ in the global environment from the reader context.
--
-- If @s@ is unbound, 'ECantFind' at @si@ is logged and thrown. The first
-- argument is unused.
getDef _ si s = do
    env <- asks snd
    case lookupName s env of
        Just found -> return found
        Nothing -> throwError' (ECantFind s si)
-- | Register a new top-level definition.
--
-- The type is always logged as an 'IType' entry first. The result is then
-- either a singleton environment for the new name, or an 'ERedefined'
-- error (new site, previous site) if the name is already bound.
addToEnv :: Monad m => SIName -> ExpType -> IM m GlobalEnv
addToEnv sn@(SIName si s) (ET x t) = do
    tell [IType sn t]
    previous <- asks (Map.lookup s . snd)
    maybe (return (Map.singleton s (x, t, si)))
          (\(_, _, si') -> throwError' (ERedefined s si si'))
          previous
-- | Walk the spine of Pi binders and drop every hidden binder that meets two
-- conditions: its domain reduces (via 'hnf') to 'Unit', and @down 0@
-- succeeds on its body.
-- NOTE(review): @down 0@ presumably succeeds only when the bound variable
-- is unused; confirm.
removeHiddenUnit ty = case ty of
    Pi Hidden (hnf -> Unit) (down 0 -> Just t) -> removeHiddenUnit t
    Pi h a b -> Pi h a (removeHiddenUnit b)
    _ -> ty
-- | Wrap @t@ in one Pi binder per (hiding, domain) pair. The first pair
-- becomes the outermost binder.
addParams ps t = foldr (\ ~(h, a) body -> Pi h a body) t ps
-- | Wrap @t@ in one 'Lam' per element of @ps@. Only the number of elements
-- matters.
addLams ps t = foldr (\_ body -> Lam body) t ps
-- | A lambda with one binder per parameter of @t@ (per 'getParams').
--
-- The body is @x@ applied to the bound variables, listed by 'downTo' from
-- index @arity t - 1@ down to @0@.
lamify t x = addLams params (x (downTo 0 (arity t)))
  where
    params = fst (getParams t)
-- | Like 'lamify', but the bound variables are listed by 'downTo'' from
-- index @0@ up to @arity t - 1@.
lamify' t x = addLams params (x (downTo' 0 (arity t)))
  where
    params = fst (getParams t)
-- | Number of parameters of a type, as reported by 'getParams'.
arity :: Exp -> Int
arity t = length (fst (getParams t))
-- | @downTo n m@ is the @m@ variables @Var (n + m - 1), ..., Var n@, in
-- descending index order. It is empty when @m <= 0@, because a descending
-- range that starts below @n@ yields nothing.
downTo n m = map Var [n + m - 1, n + m - 2 .. n]
-- | @downTo' n m@ is the @m@ variables @Var n, ..., Var (n + m - 1)@, in
-- ascending index order. It is empty when @m <= 0@.
downTo' n m = map Var [n .. n + m - 1]
-- | Run a computation with the environment half of the reader context
-- extended by @e@.
--
-- For a 'GlobalEnv', '<>' is 'Map.Map''s left-biased union, so bindings
-- already in scope win over same-named bindings in @e@.
withEnv e = local (\ ~(exts, env) -> (exts, env <> e))
-- | Abstract a typed expression over one parameter of type @t@ with hiding
-- @h@. The term gains a 'Lam' and its type a matching 'Pi'.
lamPi h t et = case et of
    ET x y -> ET (Lam x) (Pi h t y)
-- | 'Nothing' when @ty@ has no ambiguous hidden parameters (see
-- 'ambigVars').
--
-- Otherwise, a message naming @s@ that shows the type and the offending
-- (index, type) pairs.
ambiguityCheck :: String -> Exp -> Maybe String
ambiguityCheck s ty
    | null problems = Nothing
    | otherwise = Just $ s ++ " has ambiguous type:\n" ++ ppShow ty ++ "\nproblematic vars:\n" ++ ppShow problems
  where
    problems = ambigVars ty
-- | Hidden parameters of @ty@ where neither the parameter's own index nor
-- any free variable of its type is in the set computed by 'compDefined'
-- (without shifting).
ambigVars :: Exp -> [(Int, Exp)]
ambigVars ty = filter isAmbiguous hid
  where
    (defined, hid, _) = compDefined False ty
    isAmbiguous (n, c) = all (`Set.notMember` defined) (Set.insert n (free c))
-- | Whether the outermost hidden parameter of @ty@ is among the variables
-- that 'compDefined' (with shifting enabled) reports as defined.
--
-- 'compDefined' numbers the @i@ leading hidden parameters so that the
-- outermost one gets index @i - 1@; that is the index queried here.
-- NOTE(review): assumes variable indices are non-negative, so a type with no
-- hidden parameters (query index -1) yields False. Judging by the name,
-- this decides whether a let-bound metavariable may be floated; confirm
-- with the callers.
floatLetMeta :: Exp -> Bool
floatLetMeta ty = (i - 1) `Set.member` defined
  where
    (defined, _, i) = compDefined True ty
-- | Number the leading hidden parameters of @ty@ and compute the set of
-- variables closed under their dependencies.
--
-- Returns @(defined, hid, i)@:
--
-- * @i@ is the number of leading hidden Pi binders (see 'hiddenVars').
--
-- * @hid@ pairs each hidden parameter's index with its type, shifted by
--   @up (k + 1)@. The outermost parameter gets index @i - 1@ and the
--   innermost gets @0@.
--
-- * @defined@ is the 'dependentVars' closure of the free variables of the
--   whole of @ty@, each shifted up by @i@ when @b@ is 'True'.
compDefined b ty = (defined, hid, i)
  where
    defined = dependentVars hid $ Set.map (if b then (+ i) else id) $ free ty
    i = length hiddenTys
    hid = zipWith (\k t -> (k, up (k + 1) t)) (reverse [0 .. i - 1]) hiddenTys
    hiddenTys = fst (hiddenVars ty)
-- | Free variable indices of a term (as reported by 'getFreeVars' and
-- 'freeVars'), collected into a 'Set.Set'. Used by the ambiguity check,
-- 'compDefined' and 'dependentVars'.
free = Set.fromList . freeVars . getFreeVars
-- | Split off the leading hidden Pi binders of a type.
--
-- Returns their domains (outermost first) and the remaining type.
hiddenVars ty = case ty of
    Pi Hidden a b -> first (a :) (hiddenVars b)
    _ -> ([], ty)
-- | Close a set of variable indices under the links induced by the hidden
-- parameters @ie@, given as (parameter index, parameter type) pairs.
--
-- Links are bidirectional. Each parameter index is linked to the free
-- variables of its type. A type that reduces (via 'hnf') to
-- @CW (CstrT _ ty f)@ also links the free variables of @ty@ and @f@.
-- A type that reduces to @CSplit a b c@ also links those of @a@ with those
-- of @b@ and @c@ together.
--
-- A link fires when one side meets the current frontier, and then
-- contributes the other side.
dependentVars :: [(Int, Exp)] -> Set.Set Int -> Set.Set Int
dependentVars ie = go mempty
  where
    -- @done@ holds everything collected so far; @todo@ is the frontier.
    -- The next frontier excludes @done@ as it was *before* this step.
    go done todo
        | Set.null todo = done
        | otherwise = go (done <> todo) (expand todo Set.\\ done)

    -- One propagation step, built as a pointwise union of
    -- @Set Int -> Set Int@ functions.
    expand = foldMap edges ie

    edges (n, t) = (Set.singleton n <-> free t) <> extra t

    extra t = case t of
        (hnf -> CW (hnf -> CstrT _ ty f)) -> free ty <-> free f
        (hnf -> CSplit a b c) -> free a <-> (free b <> free c)
        _ -> mempty

    a --> b = \s -> if Set.null (a `Set.intersection` s) then mempty else b
    a <-> b = (a --> b) <> (b --> a)