{-# LANGUAGE OverloadedStrings #-}

module Clash.Core.EqSolver where

import Data.Maybe (catMaybes, mapMaybe)

import Clash.Core.Name (Name(nameOcc))
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Var

-- | Data type that indicates what kind of solution (if any) was found when
-- trying to solve a single type equation.
data TypeEqSolution
  = Solution (TyVar, Type)
  -- ^ Solution was found: the type variable equals the given type. 'solveAdd'
  -- produces a numeric literal here, 'solveEq' the type constant the variable
  -- is equated with.
  | AbsurdSolution
  -- ^ The equation cannot hold. 'solveAdd' reports this when one of the
  -- literals or their difference is negative; 'solveEq' reports it when two
  -- type constants, literals, or type constructor names differ.
  | NoSolution
  -- ^ Given type wasn't an equation of a supported form, or it was
  -- unsolvable. Returned by 'solveAdd' when 'normalizeAdd' fails or when the
  -- remaining operand is not a type variable.
    deriving (Show, Eq)

-- | Keep only the variable bindings carried by 'Solution' results; every
-- 'AbsurdSolution' and 'NoSolution' is dropped.
catSolutions :: [TypeEqSolution] -> [(TyVar, Type)]
catSolutions = mapMaybe bindingOf
 where
  bindingOf result =
    case result of
      Solution binding -> Just binding
      _                -> Nothing

-- | Run every solver ('solveAdd' first, then 'solveEq') on each equation, in
-- order, and collect the variable bindings of all 'Solution' results.
-- Absurd and missing solutions are silently discarded.
solveNonAbsurds :: TyConMap -> [(Type, Type)] -> [(TyVar, Type)]
solveNonAbsurds tcm equations = concatMap bindingsFor equations
 where
  -- 'solveAdd' yields exactly one result, 'solveEq' zero or more.
  bindingsFor equation =
    catSolutions (solveAdd equation : solveEq tcm equation)

-- | Solve simple equalities. Both sides are first reduced with 'coreView',
-- after which the following shapes are recognised:
--
--   * a ~ T                           (type variable against a type constant)
--   * T ~ a
--   * SomeType a b ~ SomeType X Y     (solved argument by argument)
--   * SomeType X Y ~ SomeType a b
--   * SomeType a Y ~ SomeType X b
--
-- Two different type constants, two different literals, or two applications
-- of differently named type constructors yield 'AbsurdSolution'. Anything
-- else yields no results at all.
--
-- NOTE(review): the variable cases only match 'ConstTy'. Numeric literals
-- are represented as 'LitTy' elsewhere in this module, so an equation such
-- as @a ~ 3@ does not appear to produce a 'Solution' here -- confirm whether
-- that is intended.
solveEq :: TyConMap -> (Type, Type) -> [TypeEqSolution]
solveEq tcm (lhs0, rhs0) = go (coreView tcm lhs0) (coreView tcm rhs0)
 where
  -- a ~ T
  go (VarTy tyVar) rhs@(ConstTy {}) = [Solution (tyVar, rhs)]
  -- T ~ a
  go lhs@(ConstTy {}) (VarTy tyVar) = [Solution (tyVar, lhs)]
  -- Int /= Char
  go lhs@(ConstTy {}) rhs@(ConstTy {}) = absurdUnlessEqual lhs rhs
  -- 3 /= 5
  go lhs@(LitTy {}) rhs@(LitTy {}) = absurdUnlessEqual lhs rhs
  go lhs rhs
    -- 'coreView' should already have reduced every solvable type family. One
    -- that is still present is stuck, and must not be compared to anything.
    | isTypeFamilyApplication tcm lhs || isTypeFamilyApplication tcm rhs
    = []
    -- SomeType a b ~ SomeType X Y (or the other way around)
    | TyConApp lhsNm lhsArgs <- tyView lhs
    , TyConApp rhsNm rhsArgs <- tyView rhs
    = if lhsNm == rhsNm
        then concatMap (solveEq tcm) (zip lhsArgs rhsArgs)
        else [AbsurdSolution]
    | otherwise
    = []

  -- Two closed types that differ can never be equal.
  absurdUnlessEqual x y
    | x /= y    = [AbsurdSolution]
    | otherwise = []

-- | Solve equations supported by 'normalizeAdd'.
--
-- When the equation normalizes to @(n, m, a)@ with @a@ a type variable, the
-- result is @'Solution' (a, n - m)@ provided @n@, @m@ and @n - m@ are all
-- non-negative, and 'AbsurdSolution' otherwise. Any equation that does not
-- normalize, or whose unknown is not a type variable, gives 'NoSolution'.
solveAdd
  :: (Type, Type)
  -> TypeEqSolution
solveAdd equation
  | Just (total, addend, VarTy tyVar) <- normalizeAdd equation =
      let remainder = total - addend
      in  if all (>= 0) [total, addend, remainder]
            then Solution (tyVar, LitTy (NumTy remainder))
            else AbsurdSolution
  | otherwise =
      NoSolution

-- | Bring an addition equation into a canonical shape. Given both sides of
-- an equation of one of these forms:
--
--     * 5     ~ n + 2
--     * 5     ~ 2 + n
--     * n + 2 ~ 5
--     * 2 + n ~ 5
--
-- the result is @Just (5, 2, n)@: the literal side, the literal operand of
-- the addition, and the other operand. Any other shape gives 'Nothing'.
normalizeAdd
  :: (Type, Type)
  -> Maybe (Integer, Integer, Type)
normalizeAdd (a, b) =
  case splitLit a b of
    Just (total, sumTy)
      | TyConApp (nameOcc -> "GHC.TypeNats.+") [x, y] <- tyView sumTy
      , Just (addend, unknown) <- splitLit x y
      -> Just (total, addend, unknown)
    _ ->
      Nothing
 where
  -- Pick the numeric literal out of a pair of types, returning its value
  -- together with the other type. The second position is inspected first,
  -- so it wins when both are literals.
  splitLit x y =
    case (x, y) of
      (_, LitTy (NumTy k)) -> Just (k, x)
      (LitTy (NumTy k), _) -> Just (k, y)
      _                    -> Nothing

-- | Tests whether an alternative is unreachable because at least one of the
-- equations returned by 'altEqs' for it is absurd. See 'isAbsurdEq' for what
-- counts as absurd.
isAbsurdAlt
  :: TyConMap
  -> Alt
  -> Bool
isAbsurdAlt tcm alt =
  or [isAbsurdEq tcm equation | equation <- altEqs tcm alt]

-- | Determines if an "equation" obtained through 'altEqs' or 'typeEq' is
-- absurd: either two types that are definitely not equal are asserted to be
-- equal, or computing the solution yields an absurd (intermediate) result
-- such as -1.
--
-- Both sides are reduced with 'coreView' before 'solveAdd' is consulted;
-- 'solveEq' is only tried when 'solveAdd' does not already report absurdity.
isAbsurdEq
  :: TyConMap
  -> (Type, Type)
  -> Bool
isAbsurdEq tcm (left0, right0)
  | AbsurdSolution <- solveAdd reduced = True
  | otherwise = AbsurdSolution `elem` solveEq tcm reduced
 where
  reduced = (coreView tcm left0, coreView tcm right0)

-- | Get the constraint equations of an alternative: the types of the
-- identifiers bound by its pattern that 'typeEq' recognises as equations.
-- The alternative's term is not inspected.
altEqs
  :: TyConMap
  -> Alt
  -> [(Type, Type)]
altEqs tcm (pat, _term) =
  catMaybes [typeEq tcm (varType boundId) | boundId <- boundIds]
 where
  (_, boundIds) = patIds pat

-- | If the type (after 'coreView') is an application of @GHC.Prim.~#@ to
-- four arguments, return its last two arguments -- the LHS and RHS of the
-- equation -- each reduced with 'coreView'. Otherwise return 'Nothing'.
typeEq
  :: TyConMap
  -> Type
  -> Maybe (Type, Type)
typeEq tcm ty
  | TyConApp (nameOcc -> "GHC.Prim.~#") [_, _, lhs, rhs] <- tyView (coreView tcm ty)
  = Just (coreView tcm lhs, coreView tcm rhs)
  | otherwise
  = Nothing