{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric      #-}
{-# LANGUAGE DerivingVia        #-}

{-# OPTIONS_GHC -Wno-incomplete-patterns #-} -- TODO(#1918): Only needed for GHC <9.0.1.
{-# OPTIONS_GHC -Wno-orphans #-}

module Language.Haskell.Liquid.Types.Variance (
  Variance(..), VarianceInfo, makeTyConVariance, flipVariance
  ) where

import Prelude hiding (error)
import Control.DeepSeq
import Data.Typeable hiding (TyCon)
import Data.Data     hiding (TyCon)
import GHC.Generics
import Data.Binary
import Data.Hashable
import Text.PrettyPrint.HughesPJ

import           Data.Maybe                (fromJust)
import qualified Data.List               as L
import qualified Data.HashSet            as S

import qualified Language.Fixpoint.Types as F

import qualified Language.Haskell.Liquid.GHC.Misc as GM
import           Liquid.GHC.API        as Ghc hiding (text)

-- | Variances of a type constructor's parameters, one entry per type variable
-- returned by 'GM.tyConTyVarsDef', in that order (see 'makeTyConVariance').
type VarianceInfo = [Variance]

-- | How a type constructor uses one of its parameters.
--
-- Each constructor is described by how the occurrence analysis in
-- 'makeTyConVariance' (its local @goTyConApp@) treats a type argument that
-- sits in a parameter position of that variance.
--
-- NOTE(review): 'Invariant' here means "occurrences are ignored" and
-- 'Bivariant' means "explored with both polarities". That looks swapped
-- relative to the usual textbook naming; confirm before relying on the names.
data Variance
  = Invariant
    -- ^ The argument contributes no occurrences. 'makeTyConVariance' assigns
    -- this to parameters for which no occurrence was recorded.
  | Bivariant
    -- ^ The argument is explored both with its polarity flipped and with its
    -- polarity unchanged.
  | Contravariant
    -- ^ The argument is explored with its polarity flipped.
  | Covariant
    -- ^ The argument is explored with its polarity unchanged.
  deriving (Eq, Show, Data, Typeable, Generic)
  deriving Hashable via Generically Variance

-- | Swap 'Covariant' and 'Contravariant'. 'Invariant' and 'Bivariant' map to
-- themselves.
flipVariance :: Variance -> Variance
flipVariance v = case v of
  Covariant     -> Contravariant
  Contravariant -> Covariant
  _             -> v

-- | Combine two variances.
--
-- * 'Bivariant' on either side wins.
-- * Otherwise, 'Invariant' on either side yields the other operand.
-- * Otherwise, a left 'Covariant' keeps the right operand unchanged.
-- * Otherwise, a left 'Contravariant' flips the right operand.
instance Semigroup Variance where
  x <> y = case (x, y) of
    (Bivariant, _)     -> Bivariant
    (_, Bivariant)     -> Bivariant
    (Invariant, _)     -> y
    (_, Invariant)     -> x
    (Covariant, _)     -> y
    (Contravariant, _) -> flipVariance y

-- | 'Invariant' is the identity of '(<>)': @Invariant <> v == v@ and
-- @v <> Invariant == v@ for every @v@. That makes it the lawful 'mempty'.
--
-- 'Bivariant' must not be used here. It absorbs every other variance under
-- '(<>)', which would break @mempty <> v == v@. It would also make the
-- default 'mconcat' / 'foldMap' (a right fold ending in 'mempty') return
-- 'Bivariant' for every input.
instance Monoid Variance where
  mempty = Invariant

-- Deep-forcing and binary serialisation use the classes' 'Generic'-based
-- default methods (the instance bodies are empty).
instance NFData Variance
instance Binary Variance

-- | Pretty-print a variance as its constructor name, i.e. its 'Show' output.
-- The tidying level is ignored.
instance F.PPrint Variance where
  pprintTidy _ v = text (show v)



-- | Compute the variance of each type parameter of a type constructor, in the
-- order given by 'GM.tyConTyVarsDef'.
--
-- Occurrences of type variables are collected together with a polarity:
-- 'True' for a positive position, 'False' for a negative one. The polarity
-- flips on the argument side of a function arrow. Occurrences are taken from:
--
--   * the right-hand side, for a type synonym;
--   * the original argument types of every data constructor, otherwise.
--
-- Each parameter is then classified by the /distinct/ polarities it occurs
-- with:
--
--   * none gives 'Invariant';
--   * only positive gives 'Covariant';
--   * only negative gives 'Contravariant';
--   * both gives 'Bivariant'.
--
-- Parameters are matched to occurrences by pretty-printed name (via
-- 'GM.showPpr'), not by 'Var' identity.
--
-- Inside an applied type constructor @c'@:
--
--   * a direct self-application (@c' == tyCon@) contributes nothing;
--   * if @c'@ is mutually recursive with @tyCon@, every argument is treated
--     as 'Bivariant';
--   * otherwise the arguments are weighted by @makeTyConVariance c'@.
makeTyConVariance :: TyCon -> VarianceInfo
makeTyConVariance tyCon = varSignToVariance <$> tvs
  where
    tvs :: [TyVar]
    tvs = GM.tyConTyVarsDef tyCon

    -- Every (variable, polarity) occurrence found in the definition.
    -- 'fromJust' relies on 'Ghc.synTyConRhs_maybe' returning 'Just' for every
    -- tycon accepted by 'Ghc.isTypeSynonymTyCon'.
    varsigns :: [(TyVar, Bool)]
    varsigns = if Ghc.isTypeSynonymTyCon tyCon
                  then go True (fromJust $ Ghc.synTyConRhs_maybe tyCon)
                  else concatMap goDCon $ Ghc.tyConDataCons tyCon

    -- Occurrences keyed by printed name. This is computed once and shared by
    -- every parameter, instead of re-printing each occurrence per comparison.
    namedSigns :: [(String, Bool)]
    namedSigns = [ (GM.showPpr x, pol) | (x, pol) <- varsigns ]

    -- Classify a parameter by the set of DISTINCT polarities it occurs with.
    -- Counting raw occurrences would make e.g. @type F a = a -> a -> Int@
    -- Bivariant in @a@ although @a@ only occurs negatively. It would also
    -- conflate repeated, same-named tyvars coming from different data
    -- constructors.
    varSignToVariance :: TyVar -> Variance
    varSignToVariance v =
      let name = GM.showPpr v
      in case L.nub [ pol | (x, pol) <- namedSigns, x == name ] of
           []      -> Invariant
           [True]  -> Covariant
           [False] -> Contravariant
           _       -> Bivariant

    -- Occurrences in the (multiplicity-stripped) fields of one data constructor.
    goDCon :: DataCon -> [(TyVar, Bool)]
    goDCon dc = concatMap (go True . irrelevantMult) (Ghc.dataConOrigArgTys dc)

    -- @go pos t@ lists the type-variable occurrences in @t@. @pos@ is the
    -- current polarity ('True' = positive).
    go :: Bool -> Type -> [(TyVar, Bool)]
    go pos (FunTy _ _ t1 t2) = go (not pos) t1 ++ go pos t2
    go pos (ForAllTy _ t)    = go pos t
    go pos (TyVarTy v)       = [(v, pos)]
    go pos (AppTy t1 t2)     = go pos t1 ++ go pos t2
    go pos (TyConApp c' ts)
       -- A direct self-reference contributes no occurrences.
       | tyCon == c'
       = []
       -- NV fix that: what happens if we have mutually recursive data types?
       -- now just provide "default" Bivariant for mutually rec types.
       -- but there should be a finer solution
       | mutuallyRecursive tyCon c'
       = concatMap (goTyConApp pos Bivariant) ts
       | otherwise
       = concat $ zipWith (goTyConApp pos) (makeTyConVariance c') ts
    go _   (LitTy _)         = []
    go _   (CoercionTy _)    = []
    go pos (CastTy t _)      = go pos t

    -- @goTyConApp pos var t@ lists occurrences in a type argument @t@. The
    -- argument sits in a parameter position of variance @var@, under the
    -- current polarity @pos@.
    goTyConApp :: Bool -> Variance -> Type -> [(TyVar, Bool)]
    goTyConApp _   Invariant     _ = []
    goTyConApp pos Bivariant     t = goTyConApp pos Contravariant t ++ goTyConApp pos Covariant t
    goTyConApp pos Covariant     t = go pos       t
    goTyConApp pos Contravariant t = go (not pos) t

    -- @mutuallyRecursive c c'@ holds when @c@ is among the type constructors
    -- that 'dataConsOfTyCon' collects for @c'@.
    mutuallyRecursive :: TyCon -> TyCon -> Bool
    mutuallyRecursive c c' = c `S.member` dataConsOfTyCon c'


-- | The type constructors reachable from the field types
-- ('Ghc.dataConOrigArgTys') of the given type constructor's data
-- constructors. Each newly found type constructor's own data constructors are
-- followed in turn.
--
-- Despite its name, the result holds 'TyCon's, not data constructors. The
-- starting type constructor appears in the result only if it is reached again
-- from its own fields. 'makeTyConVariance' uses this in its
-- mutual-recursion check.
--
-- Cycles are cut by the set of type constructors currently being expanded on
-- the exploration path. When an application of one of those is met, neither
-- the constructor nor its type arguments are inspected again.
--
-- NOTE(review): that set is per path, not shared between sibling branches, so
-- the same type constructor may be expanded many times. This could be costly
-- on large, densely connected type graphs; confirm before using in hot paths.
dataConsOfTyCon :: TyCon -> S.HashSet TyCon
dataConsOfTyCon = reachableFrom S.empty
  where
    -- Union over every field of every data constructor of @tc@. @onPath@ holds
    -- the type constructors whose data constructors are being expanded.
    reachableFrom :: S.HashSet TyCon -> TyCon -> S.HashSet TyCon
    reachableFrom onPath tc =
      S.unions
        [ tyConsIn onPath (irrelevantMult field)
        | dcon  <- Ghc.tyConDataCons tc
        , field <- Ghc.dataConOrigArgTys dcon
        ]

    -- Type constructors mentioned in one type, expanding each new one found.
    tyConsIn :: S.HashSet TyCon -> Type -> S.HashSet TyCon
    tyConsIn onPath ty = case ty of
      FunTy _ _ arg res -> tyConsIn onPath arg `S.union` tyConsIn onPath res
      ForAllTy _ body   -> tyConsIn onPath body
      TyVarTy _         -> S.empty
      AppTy fun arg     -> tyConsIn onPath fun `S.union` tyConsIn onPath arg
      TyConApp c args
        | c `S.member` onPath -> S.empty
        | otherwise ->
            S.insert c (S.unions (tyConsIn onPath <$> args))
              `S.union` reachableFrom (S.insert c onPath) c
      LitTy _           -> S.empty
      CoercionTy _      -> S.empty
      CastTy inner _    -> tyConsIn onPath inner