{-# LANGUAGE OverloadedStrings #-}
module Clash.Core.EqSolver where
import Data.Maybe (catMaybes, mapMaybe)
import Clash.Core.Name (Name(nameOcc))
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Var
-- | Result of attempting to solve a single type equation.
data TypeEqSolution
  = Solution (TyVar, Type)
  -- ^ The type variable must be equal to the given type.
  | AbsurdSolution
  -- ^ The equation can never hold.
  | NoSolution
  -- ^ Nothing could be concluded from the equation.
  deriving (Eq, Show)
-- | Keep only the payloads of 'Solution's, in their original order,
-- discarding every 'AbsurdSolution' and 'NoSolution'.
catSolutions :: [TypeEqSolution] -> [(TyVar, Type)]
catSolutions sols = mapMaybe fromSolution sols
  where
    fromSolution sol =
      case sol of
        Solution s -> Just s
        _ -> Nothing
-- | Collect every variable binding that can be derived from a list of type
-- equations. Absurd or unsolvable equations contribute nothing. For each
-- equation, the result of 'solveAdd' precedes the results of 'solveEq'.
solveNonAbsurds :: TyConMap -> [(Type, Type)] -> [(TyVar, Type)]
solveNonAbsurds tcm = concatMap solveOne
  where
    solveOne eq = catSolutions (solveAdd eq : solveEq tcm eq)
-- | Solve simple type equalities, for example:
--
--   * @a ~ Int@ and @Int ~ a@ bind @a@ to @Int@
--   * @a ~ 3@ and @3 ~ a@ bind @a@ to @3@
--   * @T a b ~ T Int 3@ recurses into the constructor arguments
--
-- Distinct type constants, distinct literals, or applications of distinct
-- type constructors yield 'AbsurdSolution'. Equations where either side is
-- a type family application are not reasoned about and yield no results.
-- Both sides are normalised with 'coreView' before inspection.
solveEq :: TyConMap -> (Type, Type) -> [TypeEqSolution]
solveEq tcm (coreView tcm -> left, coreView tcm -> right) =
  case (left, right) of
    (VarTy tyVar, ConstTy {}) ->
      [Solution (tyVar, right)]
    (ConstTy {}, VarTy tyVar) ->
      [Solution (tyVar, left)]
    (VarTy tyVar, LitTy {}) ->
      -- A literal has no free type variables, so binding to it is as safe
      -- as binding to a type constant.
      [Solution (tyVar, right)]
    (LitTy {}, VarTy tyVar) ->
      [Solution (tyVar, left)]
    (ConstTy {}, ConstTy {}) ->
      if left /= right then [AbsurdSolution] else []
    (LitTy {}, LitTy {}) ->
      if left /= right then [AbsurdSolution] else []
    _
      -- Type family applications still present here are treated as
      -- irreducible; we don't attempt to reason about them.
      | any (isTypeFamilyApplication tcm) [left, right] ->
        []
      | TyConApp leftNm leftTys <- tyView left
      , TyConApp rightNm rightTys <- tyView right ->
        if leftNm == rightNm
          then concat (zipWith (curry (solveEq tcm)) leftTys rightTys)
          else [AbsurdSolution]
      | otherwise ->
        []
-- | Solve equations of the form @n ~ m + rest@ recognised by 'normalizeAdd',
-- where @n@ and @m@ are number literals:
--
--   * if @m > n@, no natural @rest@ can satisfy the equation, so the result
--     is 'AbsurdSolution' whatever @rest@ is;
--   * if @rest@ is a type variable, it is bound to the literal @n - m@;
--   * if @rest@ is a number literal other than @n - m@, the result is
--     'AbsurdSolution';
--   * otherwise (including equations of any other shape) 'NoSolution'.
solveAdd
  :: (Type, Type)
  -> TypeEqSolution
solveAdd ab =
  case normalizeAdd ab of
    Just (n, m, rest)
      | n < 0 || m < 0 || n - m < 0 ->
        AbsurdSolution
      | VarTy tyVar <- rest ->
        Solution (tyVar, LitTy (NumTy (n - m)))
      | LitTy (NumTy k) <- rest
      , k /= n - m ->
        AbsurdSolution
    -- Falls through here when none of the guards above hold as well.
    _ ->
      NoSolution
-- | Recognise an equation between a number literal @n@ and an addition
-- @m + o@ (the @GHC.TypeNats.+@ family) where @m@ is a number literal,
-- returning @(n, m, o)@. The literal may sit on either side of the equation
-- and on either side of the @+@. When both candidates are literals, the
-- right-hand one is taken as the literal.
normalizeAdd
  :: (Type, Type)
  -> Maybe (Integer, Integer, Type)
normalizeAdd (a, b) =
  splitLit a b >>= \(n, sumTy) ->
    case tyView sumTy of
      TyConApp (nameOcc -> "GHC.TypeNats.+") [x, y] ->
        fmap (\(m, o) -> (n, m, o)) (splitLit x y)
      _ ->
        Nothing
  where
    -- Pick a number literal out of two types, returning it together with
    -- the other type. The second argument is checked first.
    splitLit other (LitTy (NumTy k)) = Just (k, other)
    splitLit (LitTy (NumTy k)) other = Just (k, other)
    splitLit _ _ = Nothing
-- | Whether a case alternative is unreachable: at least one of the type
-- equalities carried by the variables its pattern binds can never hold.
isAbsurdAlt
  :: TyConMap
  -> Alt
  -> Bool
isAbsurdAlt tcm = any (isAbsurdEq tcm) . altEqs tcm
-- | Whether a type equation can never hold, according to 'solveAdd' or any
-- result of 'solveEq'. Both sides are normalised with 'coreView' first.
isAbsurdEq
  :: TyConMap
  -> (Type, Type)
  -> Bool
isAbsurdEq tcm (left0, right0)
  | solveAdd eq == AbsurdSolution = True
  | otherwise = AbsurdSolution `elem` solveEq tcm eq
  where
    eq = (coreView tcm left0, coreView tcm right0)
-- | The type equalities found among the types of the term variables bound by
-- an alternative's pattern (see 'typeEq'). The alternative's body is ignored.
altEqs
  :: TyConMap
  -> Alt
  -> [(Type, Type)]
altEqs tcm (pat, _) =
  catMaybes [typeEq tcm (varType boundId) | boundId <- boundIds]
  where
    (_, boundIds) = patIds pat
-- | If a type (after 'coreView') is an application of @GHC.Prim.~#@ to four
-- arguments, return its last two arguments, each normalised with 'coreView'.
-- The first two arguments (presumably the kinds — TODO confirm) are ignored.
typeEq
  :: TyConMap
  -> Type
  -> Maybe (Type, Type)
typeEq tcm ty
  | TyConApp (nameOcc -> "GHC.Prim.~#") [_, _, lhs, rhs] <- tyShape
  = Just (coreView tcm lhs, coreView tcm rhs)
  | otherwise
  = Nothing
  where
    tyShape = tyView (coreView tcm ty)