-- |
-- Module      :  Cryptol.TypeCheck.Unify
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable

{-# LANGUAGE Safe #-}
{-# LANGUAGE PatternGuards, ViewPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
module Cryptol.TypeCheck.Unify where

import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst
import Cryptol.Utils.RecordMap

import Control.Monad.Writer (Writer, writer, runWriter)
import qualified Data.Set as Set

import Prelude ()
import Prelude.Compat

-- | The most general unifier: a substitution together with a list of
-- constraints.  In this module the constraints are equations (built with
-- '=#=') that are returned as-is instead of being turned into a
-- substitution, for example for bound variables of kind 'KNum'.
type MGU = (Subst,[Prop])

-- | Unification computations.  They run in a 'Writer' that collects
-- 'UnificationError's alongside the computed value.
type Result a = Writer [UnificationError] a

-- | Run a unification computation, returning its value together with
-- every error that was logged along the way.
runResult :: Result a -> (a, [UnificationError])
runResult action = runWriter action

-- | The ways in which unification can fail.
data UnificationError
  = UniTypeMismatch Type Type
    -- ^ The two types have the same kind, but no unification rule applies.
  | UniKindMismatch Kind Kind
    -- ^ The two sides being unified have different kinds.
  | UniTypeLenMismatch Int Int
    -- ^ Two lists of types that are unified pairwise do not have the same
    -- number of elements.
  | UniRecursive TVar Type
    -- ^ Binding the variable to the type was rejected as recursive
    -- (reported only for variables of kind 'KType').
  | UniNonPolyDepends TVar [TParam]
    -- ^ Binding the variable was rejected because the listed type
    -- parameters would escape ('SubstEscaped').
  | UniNonPoly TVar Type
    -- ^ A bound variable whose kind is not 'KNum' was matched against a
    -- type of the same kind.

-- | Log a unification error.  The computation still produces a value,
-- namely 'emptyMGU', so the caller can keep going after the failure.
uniError :: UnificationError -> Result MGU
uniError err = writer (emptyMGU, [err])


-- | The trivial unifier: 'emptySubst' paired with no constraints.
emptyMGU :: MGU
emptyMGU = (emptySubst, [])

-- | Compute a most general unifier for two types.
--
-- The equations are tried from top to bottom, and an equation whose guard
-- fails falls through to the next one, so their order is significant.
-- Failures are reported with 'uniError' and end up in the 'Result' log.
mgu :: Type -> Type -> Result MGU

-- The same 'TUser' name applied to equal arguments: nothing to do.
mgu (TUser name1 args1 _) (TUser name2 args2 _)
  | name1 == name2 && args1 == args2 = return emptyMGU

-- A variable on either side is handled by 'bindVar'.
mgu (TVar v) ty = bindVar v ty
mgu ty (TVar v) = bindVar v ty

-- Otherwise look through a 'TUser' to the type kept in its last field.
mgu (TUser _ _ body) ty = mgu body ty
mgu ty (TUser _ _ body) = mgu ty body

-- The same type constructor: unify the arguments pairwise.
mgu (TCon (TC con1) args1) (TCon (TC con2) args2)
  | con1 == con2 = mguMany args1 args2

-- The same type function applied to equal arguments.
mgu (TCon (TF fun1) args1) (TCon (TF fun2) args2)
  | fun1 == fun2 && args1 == args2 = return emptyMGU

-- A type-function application of kind 'KNum' against a type of the same
-- kind is not solved here; the equation is handed back as a constraint.
mgu lhs rhs
  | isTFun lhs || isTFun rhs, lhsKind == KNum, lhsKind == rhsKind =
    return (emptySubst, [lhs =#= rhs])
  where
  lhsKind = kindOf lhs
  rhsKind = kindOf rhs

  isTFun (TCon (TF _) _) = True
  isTFun _               = False

-- Records with the same field set: unify the field types pairwise.
mgu (TRec flds1) (TRec flds2)
  | fieldSet flds1 == fieldSet flds2 =
    mguMany (recordElements flds1) (recordElements flds2)

-- No rule applied.  Differing kinds are reported as a kind mismatch,
-- everything else as a type mismatch.
mgu lhs rhs
  | lhsKind /= rhsKind = uniError (UniKindMismatch lhsKind rhsKind)
  | otherwise          = uniError (UniTypeMismatch lhs rhs)
  where
  lhsKind = kindOf lhs
  rhsKind = kindOf rhs


-- | Unify two lists of types position by position.
--
-- The substitution obtained for each pair is applied to the remaining
-- elements of both lists before they are unified, and the resulting
-- substitutions are composed with '@@'.  The constraints of all pairs are
-- concatenated in order.
--
-- If the lists differ in length, a 'UniTypeLenMismatch' carrying the
-- lengths of the two lists as passed by the caller is logged (after the
-- common prefix has been unified).
mguMany :: [Type] -> [Type] -> Result MGU
mguMany allTs1 allTs2 = go allTs1 allTs2
  where
  go [] [] = return emptyMGU
  go (t1 : ts1) (t2 : ts2) =
    do (su1,ps1) <- mgu t1 t2
       (su2,ps2) <- go (apSubst su1 ts1) (apSubst su1 ts2)
       return (su2 @@ su1, ps1 ++ ps2)
  -- One list ran out before the other.  Measure the original arguments:
  -- the lists in scope at this point of the recursion are only the
  -- unprocessed suffixes, whose lengths would be misleading in the error.
  go _ _ = uniError $ UniTypeLenMismatch (length allTs1) (length allTs2)


-- | Unify a type variable with a type.
--
-- As with 'mgu', the equations are tried in order and a failed guard falls
-- through to the next equation.
bindVar :: TVar -> Type -> Result MGU

-- A variable unifies with itself without any work.
bindVar var (tNoUser -> TVar other)
  | var == other = return emptyMGU

-- Bound variable against a free one: swap the roles, so that the free
-- variable is the one that gets bound.
bindVar bound@(TVBound {}) (tNoUser -> TVar free@(TVFree {})) =
  bindVar free (TVar bound)

-- A bound variable is never substituted.  When the kinds agree and the
-- kind is 'KNum', the equation is returned as a constraint; for any other
-- kind it is an error.
bindVar bound@(TVBound {}) ty
  | boundKind /= tyKind = uniError (UniKindMismatch boundKind tyKind)
  | boundKind == KNum   = return (emptySubst, [TVar bound =#= ty])
  | otherwise           = uniError (UniNonPoly bound ty)
  where
  boundKind = kindOf bound
  tyKind    = kindOf ty

-- Two free variables of the same kind, where the scope of the first is a
-- proper subset of the scope of the second.  Here the reverse binding
-- (second ~> first) is added to the substitution; instantiating the first
-- with the second would be forbidden, because it would let the second
-- variable escape from its scope.
bindVar var@(TVFree _ varKind varScope _)
        (tNoUser -> TVar other@(TVFree _ otherKind otherScope _))
  | varScope `Set.isProperSubsetOf` otherScope, varKind == otherKind =
    return (uncheckedSingleSubst other (TVar var), [])

-- General case: let 'singleSubst' build the binding and translate the
-- ways in which it can refuse.
bindVar var ty =
  case singleSubst var ty of
    Right su -> return (su, [])
    Left SubstRecursive
      -- A recursive binding is an error at kind 'KType'; at any other
      -- kind the equation is kept as a constraint instead.
      | kindOf var == KType -> uniError (UniRecursive var ty)
      | otherwise           -> return (emptySubst, [TVar var =#= ty])
    Left (SubstEscaped params) ->
      uniError (UniNonPolyDepends var params)
    Left (SubstKindMismatch ka kb) ->
      uniError (UniKindMismatch ka kb)