{-# LANGUAGE OverloadedStrings #-}
module Clash.Core.EqSolver where
import Data.List.Extra (zipEqual)
import Data.Maybe (catMaybes, mapMaybe)
import Clash.Core.Name (Name(nameOcc))
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Var
import Clash.Core.VarEnv (VarSet, elemVarSet, emptyVarSet, mkVarSet)
-- | Outcome of trying to solve a single type equation.
data TypeEqSolution
  = Solution (TyVar, Type)
  -- ^ The equation binds the given type variable to the given type.
  | AbsurdSolution
  -- ^ The equation can never hold. For example, two different constants or
  -- literals are asserted equal, or an addition would need a negative
  -- natural.
  | NoSolution
  -- ^ The equation did not have a recognised shape, or it did not constrain
  -- a variable that is being solved for.
  deriving (Show, Eq)
-- | Keep only the concrete @(variable, type)@ bindings from a list of solver
-- results. 'AbsurdSolution' and 'NoSolution' entries are dropped; the
-- relative order of the remaining bindings is preserved.
catSolutions :: [TypeEqSolution] -> [(TyVar, Type)]
catSolutions = mapMaybe toBinding
 where
  toBinding sol = case sol of
    Solution binding -> Just binding
    _ -> Nothing
-- | Solve every equation and collect all concrete bindings found. Absurd
-- and unsolved results are discarded.
--
-- Each equation is given to both 'solveAdd' (addition-shaped equations) and
-- 'solveEq' (structural equations). Results come back in equation order.
-- Within one equation, the 'solveAdd' result precedes the 'solveEq' results.
solveNonAbsurds :: TyConMap -> VarSet -> [(Type, Type)] -> [(TyVar, Type)]
solveNonAbsurds tcm solveSet = concatMap solveOne
 where
  solveOne eq = catSolutions (solveAdd solveSet eq : solveEq tcm solveSet eq)
-- | Solve simple structural equalities between two types.
--
-- Both sides are first normalised with 'coreView'. Then:
--
--   * @a ~ T@ or @T ~ a@, where @a@ is in @solveSet@ and @T@ is a constant
--     ('ConstTy') or a type-level literal ('LitTy'), yields a 'Solution'
--     binding @a@ to @T@.
--   * Two different constants, or two different literals, yield
--     'AbsurdSolution'.
--   * Two applications of the same type constructor are solved argument by
--     argument. Applications of different type constructors are absurd.
--   * Anything else, including a side that is still a type family
--     application, yields no results.
solveEq :: TyConMap -> VarSet -> (Type, Type) -> [TypeEqSolution]
solveEq tcm solveSet (left0, right0) =
  case (left, right) of
    (VarTy tyVar, _) | solvable tyVar right ->
      -- a ~ Int, a ~ 3
      [Solution (tyVar, right)]
    (_, VarTy tyVar) | solvable tyVar left ->
      -- Int ~ a, 3 ~ a
      [Solution (tyVar, left)]
    (ConstTy {}, ConstTy {}) ->
      -- Int ~ Char
      absurdUnless (left == right)
    (LitTy {}, LitTy {}) ->
      -- 3 ~ 5
      absurdUnless (left == right)
    _ | any (isTypeFamilyApplication tcm) [left, right] ->
      -- 'coreView' has already reduced every type family it could. One that
      -- is still present is stuck, so it must not be compared structurally.
      []
    _ ->
      case (tyView left, tyView right) of
        (TyConApp leftNm leftTys, TyConApp rightNm rightTys)
          | leftNm == rightNm ->
              -- SomeType a b ~ SomeType 3 5: solve argument-wise
              concatMap (solveEq tcm solveSet) (zipEqual leftTys rightTys)
          | otherwise ->
              [AbsurdSolution]
        _ ->
          []
 where
  left = coreView tcm left0
  right = coreView tcm right0

  -- A variable can be bound when we are solving for it and the other side
  -- is a constant or a type-level literal. Before this fix only constants
  -- were accepted, so equations like @a ~ 3@ were never solved.
  solvable tyVar other = elemVarSet tyVar solveSet && isAtom other

  isAtom ty = case ty of
    ConstTy {} -> True
    LitTy {} -> True
    _ -> False

  absurdUnless same = if same then [] else [AbsurdSolution]
-- | Solve an equation of a shape recognised by 'normalizeAdd', that is
-- @n ~ m + a@ in any orientation, for a variable @a@ in @solveSet@.
--
-- Returns:
--
--   * @'Solution' (a, n - m)@ when @n@, @m@ and @n - m@ are all non-negative.
--   * 'AbsurdSolution' when any of them is negative.
--   * 'NoSolution' when the equation has another shape, or @a@ is not a
--     variable being solved for.
solveAdd
  :: VarSet
  -> (Type, Type)
  -> TypeEqSolution
solveAdd solveSet ab
  | Just (n, m, VarTy tyVar) <- normalizeAdd ab
  , elemVarSet tyVar solveSet
  = let diff = n - m
    in if all (>= 0) [n, m, diff]
         then Solution (tyVar, LitTy (NumTy diff))
         else AbsurdSolution
  | otherwise
  = NoSolution
-- | Normalise an addition equation. Given the two sides of an equation in
-- one of the forms
--
--   * @n ~ m + a@ or @n ~ a + m@
--   * @m + a ~ n@ or @a + m ~ n@
--
-- where @n@ and @m@ are numeric literals, return @(n, m, a)@. Any other
-- shape gives 'Nothing'.
normalizeAdd
  :: (Type, Type)
  -> Maybe (Integer, Integer, Type)
normalizeAdd (a, b) =
  splitLit a b >>= \(n, other) ->
    case tyView other of
      TyConApp tcNm [x, y] | nameOcc tcNm == "GHC.TypeNats.+" ->
        fmap (\(m, o) -> (n, m, o)) (splitLit x y)
      _ ->
        Nothing
 where
  -- Pick a numeric literal out of a pair of types and return it together
  -- with the other element. The second element is checked first.
  splitLit x (LitTy (NumTy k)) = Just (k, x)
  splitLit (LitTy (NumTy k)) y = Just (k, y)
  splitLit _ _ = Nothing
-- | Test whether a case alternative can never be taken because one of the
-- type equalities it brings into scope (see 'altEqs') is absurd (see
-- 'isAbsurdEq').
--
-- For a 'DataPat', the type variables bound by the pattern are the ones
-- solved for. Any other pattern solves for none.
isAbsurdAlt
  :: TyConMap
  -> Alt
  -> Bool
isAbsurdAlt tcm alt@(pat, _) =
  any (isAbsurdEq tcm solvable) (altEqs tcm alt)
 where
  solvable = case pat of
    DataPat _ tvs _ -> mkVarSet tvs
    _ -> emptyVarSet
-- | Test whether a type equality (as produced by 'altEqs' or 'typeEq') is
-- absurd. It is absurd when 'solveAdd' finds that it needs a negative
-- natural, or when 'solveEq' finds two types that can never be equal.
--
-- Only variables in @exts@ are treated as solvable. Both sides are
-- normalised with 'coreView' before solving.
isAbsurdEq
  :: TyConMap
  -> VarSet
  -> (Type, Type)
  -> Bool
isAbsurdEq tcm exts (left0, right0) =
  solveAdd exts lr == AbsurdSolution
    || AbsurdSolution `elem` solveEq tcm exts lr
 where
  lr = (coreView tcm left0, coreView tcm right0)
-- | Collect the type equalities brought into scope by a case alternative.
-- For every 'Id' bound by the pattern whose type is an equality (see
-- 'typeEq'), the two sides of that equality are returned, in binder order.
altEqs
  :: TyConMap
  -> Alt
  -> [(Type, Type)]
altEqs tcm (pat, _term) =
  catMaybes [typeEq tcm (varType i) | i <- ids]
 where
  (_tvs, ids) = patIds pat
-- | If a type, after 'coreView', is an application of @GHC.Prim.~#@ to
-- exactly four arguments, return the last two (the compared sides), each
-- normalised with 'coreView'. The first two arguments are ignored; they are
-- presumably the kinds of the compared types (TODO confirm). Otherwise
-- returns 'Nothing'.
typeEq
  :: TyConMap
  -> Type
  -> Maybe (Type, Type)
typeEq tcm ty
  | TyConApp tcNm [_, _, lhs, rhs] <- tyView (coreView tcm ty)
  , nameOcc tcNm == "GHC.Prim.~#"
  = Just (coreView tcm lhs, coreView tcm rhs)
  | otherwise
  = Nothing