{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}

-- | Implementation of unification and other core type system building
-- blocks.
module Language.Futhark.TypeChecker.Unify
  ( Constraint (..),
    Usage,
    mkUsage,
    mkUsage',
    Level,
    Constraints,
    MonadUnify (..),
    Rigidity (..),
    RigidSource (..),
    BreadCrumbs,
    noBreadCrumbs,
    hasNoBreadCrumbs,
    dimNotes,
    mkTypeVarName,
    zeroOrderType,
    arrayElemType,
    mustHaveConstr,
    mustHaveField,
    mustBeOneOf,
    equalityType,
    normType,
    normPatternType,
    normTypeFully,
    instantiateEmptyArrayDims,
    unify,
    expect,
    unifyMostCommon,
    anyDimOnMismatch,
    doUnification,
  )
where

import Control.Monad.Except
import Control.Monad.RWS.Strict hiding (Sum)
import Control.Monad.State
import Control.Monad.Writer hiding (Sum)
import Data.Bifoldable (biany)
import Data.List (intersect)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.Util.Pretty hiding (empty)
import Language.Futhark hiding (unifyDims)
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Types

-- | A piece of information describing what the type checker is
-- currently doing.  It is used to give better error messages for
-- unification errors.
data BreadCrumb
  = -- | Comparing two types against each other.
    MatchingTypes StructType StructType
  | -- | Matching the types found along a (possibly nested) record
    -- field path.
    MatchingFields [Name]
  | -- | Matching the types of the given constructor.
    MatchingConstructor Name
  | -- | An arbitrary, already rendered description.
    Matching Doc

-- Each breadcrumb is rendered as a sentence fragment describing what
-- was being matched when the error occurred.
instance Pretty BreadCrumb where
  ppr (MatchingTypes t1 t2) =
    "When matching type"
      </> indent 2 (ppr t1)
      </> "with"
      </> indent 2 (ppr t2)
  ppr (MatchingFields fields) =
    -- The field path is rendered dot-separated, e.g. a.b.c.
    "When matching types of record field"
      <+> pquote (mconcat $ punctuate "." $ map ppr fields) <> dot
  ppr (MatchingConstructor c) =
    "When matching types of constructor" <+> pquote (ppr c) <> dot
  ppr (Matching s) = s

-- | Unification failures can occur deep down inside complicated types
-- (consider nested records).  We leave breadcrumbs behind us so we
-- can report the path we took to find the mismatch.  The most
-- recently dropped breadcrumb is kept at the head of the list.
newtype BreadCrumbs
  = BreadCrumbs [BreadCrumb]

-- | An empty path.
noBreadCrumbs :: BreadCrumbs
noBreadCrumbs = BreadCrumbs []

-- | Is the path empty?
hasNoBreadCrumbs :: BreadCrumbs -> Bool
hasNoBreadCrumbs (BreadCrumbs bcs) = null bcs

-- | Drop a breadcrumb on the path behind you.
--
-- Dropping a 'MatchingFields' crumb directly on top of another
-- 'MatchingFields' crumb merges the two into a single field path.
-- The new field names are appended after the existing ones, so
-- nested record accesses read outermost-first.  Any other crumb is
-- simply pushed onto the front of the path.
breadCrumb :: BreadCrumb -> BreadCrumbs -> BreadCrumbs
breadCrumb (MatchingFields inner) (BreadCrumbs (MatchingFields outer : rest)) =
  BreadCrumbs $ MatchingFields (outer ++ inner) : rest
breadCrumb bc (BreadCrumbs rest) =
  BreadCrumbs $ bc : rest

-- An empty path renders as nothing.  Otherwise the output is a line
-- break followed by the rendered breadcrumbs, most recently dropped
-- first, combined with 'stack'.
instance Pretty BreadCrumbs where
  ppr (BreadCrumbs []) = mempty
  ppr (BreadCrumbs bcs) = line <> stack (map ppr bcs)

-- | A usage that caused a type constraint: an optional description of
-- the usage, and its source location.
data Usage = Usage (Maybe String) SrcLoc
  deriving (Show)

-- | Construct a 'Usage' from a location and a description.
mkUsage :: SrcLoc -> String -> Usage
mkUsage loc desc = Usage (Just desc) loc

-- | Construct a 'Usage' that has just a location, but no particular
-- description.
mkUsage' :: SrcLoc -> Usage
mkUsage' = Usage Nothing

-- A usage without a description renders as "use at <loc>".  A usage
-- with one renders as "<description> at <loc>".
instance Pretty Usage where
  ppr (Usage Nothing loc) = "use at " <> textwrap (locStr loc)
  ppr (Usage (Just s) loc) = textwrap s <+/> "at" <+> textwrap (locStr loc)

-- The location of a usage is the source location it carries.
instance Located Usage where
  locOf (Usage _ loc) = locOf loc

-- | The level at which a type variable is bound.  Higher means
-- deeper.  We can only unify a type variable at level @i@ with a type
-- @t@ if all type names that occur in @t@ are at most at level @i@.
-- Each entry in 'Constraints' records the level of its variable.
type Level = Int

-- | A constraint on a yet-ambiguous type variable.
data Constraint
  = NoConstraint Liftedness Usage
  | -- | A rigid type parameter, which unification may not substitute.
    ParamType Liftedness SrcLoc
  | -- | The variable has been instantiated to this type.
    Constraint StructType Usage
  | -- | Substituted as a primitive type ('PrimSubst').
    Overloaded [PrimType] Usage
  | HasFields (M.Map Name StructType) Usage
  | Equality Usage
  | HasConstrs (M.Map Name [StructType]) Usage
  | -- | A rigid size parameter.
    ParamSize SrcLoc
  | -- | Is not actually a type, but a term-level size,
    -- possibly already set to something specific.
    Size (Maybe (DimDecl VName)) Usage
  | -- | A size that does not unify with anything; it is
    -- created from the result of applying a function
    -- whose return size is existential, or otherwise
    -- hiding a size.
    UnknowableSize SrcLoc RigidSource
  deriving (Show)

-- A constraint is located at its 'Usage'.  Constructors that carry a
-- 'SrcLoc' directly are located there instead.
instance Located Constraint where
  locOf (NoConstraint _ usage) = locOf usage
  locOf (ParamType _ loc) = locOf loc
  locOf (Constraint _ usage) = locOf usage
  locOf (Overloaded _ usage) = locOf usage
  locOf (HasFields _ usage) = locOf usage
  locOf (Equality usage) = locOf usage
  locOf (HasConstrs _ usage) = locOf usage
  locOf (ParamSize loc) = locOf loc
  locOf (Size _ usage) = locOf usage
  locOf (UnknowableSize loc _) = locOf loc

-- | Mapping from fresh type variables, instantiated from the type
-- schemes of polymorphic functions, to (possibly) specific types as
-- determined on application and the location of that application, or
-- a partial constraint on their type.  Each entry also records the
-- 'Level' at which the variable is bound.
type Constraints = M.Map VName (Level, Constraint)

-- | The substitution for a variable, as determined by the constraints.
--
-- * A variable instantiated to a type maps to that type.
-- * An 'Overloaded' variable maps to 'PrimSubst'.
-- * A size already set to a specific dimension maps to that
--   dimension, with the constraints themselves applied to it.
--
-- Anything else yields 'Nothing'.
lookupSubst :: VName -> Constraints -> Maybe (Subst StructType)
lookupSubst v constraints =
  case snd <$> M.lookup v constraints of
    Just (Constraint t _) -> Just $ Subst t
    Just Overloaded {} -> Just PrimSubst
    Just (Size (Just d) _) ->
      Just $ SizeSubst $ applySubst (`lookupSubst` constraints) d
    _ -> Nothing

-- | The source of a rigid size.
data RigidSource
  = -- | A function argument that is not a constant or variable name.
    RigidArg (Maybe (QualName VName)) String
  | -- | An existential return size.
    RigidRet (Maybe (QualName VName))
  | -- | The unknown size of a value returned at some location.  The
    -- name suggests a loop; TODO confirm at construction sites.
    RigidLoop
  | -- | Produced by a complicated slice expression.
    RigidSlice (Maybe (DimDecl VName)) String
  | -- | Produced by a complicated range expression.
    RigidRange
  | -- | Produced by a range expression with this bound.
    RigidBound String
  | -- | Mismatch in branches.
    RigidCond StructType StructType
  | -- | Invented during unification.
    RigidUnify
  | -- | A size arising from the given name going out of scope.  The
    -- 'SrcLoc' is where that name was originally bound.
    RigidOutOfScope SrcLoc VName
  deriving (Eq, Ord, Show)

-- | The rigidity of a size variable.  All rigid sizes are tagged with
-- information about how they were generated.
data Rigidity
  = Rigid RigidSource
  | Nonrigid
  deriving (Eq, Ord, Show)

-- | Describe where a rigid size came from, as the continuation of a
-- sentence that starts with the size's name.  The first location is
-- the current context.  Every other location mentioned is rendered
-- relative to it via 'locStrRel'.  The second location is the place
-- the size originated.
prettySource :: SrcLoc -> SrcLoc -> RigidSource -> Doc
prettySource ctx loc src =
  case src of
    RigidRet Nothing ->
      "is unknown size returned by function at" <+> here <> "."
    RigidRet (Just fname) ->
      "is unknown size returned by" <+> pquote (ppr fname)
        <+> "at"
        <+> here <> "."
    RigidArg fname arg ->
      "is value of argument"
        </> indent 2 (shorten arg)
        </> "passed to" <+> maybe "function" (pquote . ppr) fname <+> "at" <+> here <> "."
    RigidSlice d slice ->
      "is size produced by slice"
        </> indent 2 (shorten slice)
        </> sliceDim d <> "at" <+> here <> "."
    RigidLoop ->
      "is unknown size of value returned at" <+> here <> "."
    RigidRange ->
      "is unknown length of range at" <+> here <> "."
    RigidBound bound ->
      "generated from expression"
        </> indent 2 (shorten bound)
        </> "used in range at " <> here <> "."
    RigidOutOfScope boundloc v ->
      "is an unknown size arising from " <> pquote (pprName v)
        <> " going out of scope at "
        <> here
        <> "."
        </> "Originally bound at "
        <> relLoc boundloc
        <> "."
    RigidUnify ->
      "is an artificial size invented during unification of functions with anonymous sizes"
    RigidCond t1 t2 ->
      "is unknown due to conditional expression at "
        <> here
        <> "."
        </> "One branch returns array of type: "
        <> align (ppr t1)
        </> "The other an array of type:       "
        <> align (ppr t2)
  where
    -- Render a location relative to the current context.
    relLoc :: SrcLoc -> Doc
    relLoc l = text (locStrRel ctx l)

    -- The origin location, relative to the current context.
    here :: Doc
    here = relLoc loc

    -- Optional mention of the dimension a slice was taken from.
    sliceDim :: Maybe (DimDecl VName) -> Doc
    sliceDim (Just d') = "of dimension of size " <> pquote (ppr d') <> " "
    sliceDim Nothing = mempty

-- | Retrieve notes describing the purpose or origin of the given
-- 'DimDecl'.  The location is used as the *current* location, for the
-- purpose of reporting relative locations.
--
-- Only a named dimension whose variable is constrained as an
-- 'UnknowableSize' produces a note.  Every other dimension yields no
-- notes.
dimNotes :: (Located a, MonadUnify m) => a -> DimDecl VName -> m Notes
dimNotes ctx (NamedDim d) = do
  c <- M.lookup (qualLeaf d) <$> getConstraints
  case c of
    Just (_, UnknowableSize loc rsrc) ->
      return . aNote . pretty $
        pquote (ppr d) <+> prettySource (srclocOf ctx) loc rsrc
    _ -> return mempty
dimNotes _ _ = return mempty

-- | Collect the 'dimNotes' for every size name that occurs in the
-- given type.  The location is passed on to 'dimNotes' as the current
-- location.
typeNotes :: (Located a, MonadUnify m) => a -> StructType -> m Notes
typeNotes ctx t =
  mconcat <$> mapM noteFor (S.toList (typeDimNames t))
  where
    noteFor = dimNotes ctx . NamedDim . qualName

-- | Monads that wish to perform unification must implement this type
-- class.
class Monad m => MonadUnify m where
  -- | The current constraints.
  getConstraints :: m Constraints

  -- | Replace the current constraints.
  putConstraints :: Constraints -> m ()

  -- | Apply a function to the current constraints.  The default
  -- implementation reads them with 'getConstraints' and writes the
  -- result back with 'putConstraints'.
  modifyConstraints :: (Constraints -> Constraints) -> m ()
  modifyConstraints f = do
    x <- getConstraints
    putConstraints $ f x

  -- | Create a new type variable.
  newTypeVar :: Monoid als => SrcLoc -> String -> m (TypeBase dim als)

  -- | Create a new size variable with the given rigidity.
  newDimVar :: SrcLoc -> Rigidity -> String -> m VName

  -- | The current 'Level'.
  curLevel :: m Level

  -- | Report that two types could not be matched.
  matchError ::
    Located loc =>
    loc ->
    Notes ->
    BreadCrumbs ->
    StructType ->
    StructType ->
    m a

  -- | Report a unification error with a custom message.
  unifyError ::
    Located loc =>
    loc ->
    Notes ->
    BreadCrumbs ->
    Doc ->
    m a

-- | Replace all type variables with their substitution, as determined
-- by the current constraints (see 'lookupSubst').
normTypeFully :: (Substitutable a, MonadUnify m) => a -> m a
normTypeFully t = do
  constraints <- getConstraints
  return $ applySubst (`lookupSubst` constraints) t

-- | Replace any top-level type variable with its substitution.
--
-- Only a bare type variable (unqualified, without type arguments) is
-- looked up.  If it has been instantiated to a type, that type is
-- normalised in turn, so chains of instantiated variables are
-- followed.  Any other type is returned unchanged.
normType :: MonadUnify m => StructType -> m StructType
normType t@(Scalar (TypeVar _ _ (TypeName [] v) [])) = do
  constraints <- getConstraints
  case snd <$> M.lookup v constraints of
    Just (Constraint t' _) -> normType t'
    _ -> return t
normType t = return t

-- | Replace any top-level type variable with its substitution.
--
-- This is the 'PatternType' counterpart of 'normType'.  The
-- substituted type takes over the uniqueness and aliases of the type
-- variable it replaces.
normPatternType :: MonadUnify m => PatternType -> m PatternType
normPatternType t@(Scalar (TypeVar als u (TypeName [] v) [])) = do
  constraints <- getConstraints
  case snd <$> M.lookup v constraints of
    Just (Constraint t' _) ->
      normPatternType $ (t' `setUniqueness` u) `setAliases` als
    _ -> return t
normPatternType t = return t

-- | Does this constraint describe a rigid variable, i.e. one that
-- unification may not substitute?  These are type and size
-- parameters and unknowable sizes.
rigidConstraint :: Constraint -> Bool
rigidConstraint c =
  case c of
    ParamType {} -> True
    ParamSize {} -> True
    UnknowableSize {} -> True
    _ -> False

-- | Replace 'AnyDim' dimensions that occur as 'PosImmediate' or
-- 'PosParam' with a fresh 'NamedDim'.
--
-- Each new size variable is created with 'newDimVar' using the given
-- location, rigidity and description.  The names of all created size
-- variables are returned alongside the new type.
instantiateEmptyArrayDims ::
  MonadUnify m =>
  SrcLoc ->
  String ->
  Rigidity ->
  TypeBase (DimDecl VName) als ->
  m (TypeBase (DimDecl VName) als, [VName])
instantiateEmptyArrayDims tloc desc r = runWriterT . traverseDims onDim
  where
    -- Only anonymous dimensions in these positions are instantiated.
    onDim _ PosImmediate AnyDim = inst
    onDim _ PosParam AnyDim = inst
    onDim _ _ d = return d

    -- Create a new size variable and record its name.
    inst = do
      dim <- lift $ newDimVar tloc r desc
      tell [dim]
      return $ NamedDim $ qualName dim

-- | Is the given type variable the name of an abstract type or type
-- parameter, which we cannot substitute?  A name with no entry in the
-- constraints is also considered rigid.
isRigid :: VName -> Constraints -> Bool
isRigid v constraints =
  maybe True (rigidConstraint . snd) $ M.lookup v constraints

-- | If the given type variable is nonrigid, what is its level?
-- Yields 'Nothing' for names without an entry in the constraints and
-- for names whose constraint is rigid (see 'rigidConstraint').
isNonRigid :: VName -> Constraints -> Maybe Level
isNonRigid v constraints = do
  (lvl, c) <- M.lookup v constraints
  guard $ not $ rigidConstraint c
  return lvl

-- | A strategy for unifying two dimensions, as invoked by 'unifyWith'.
-- Arguments, in order: the breadcrumbs for error reporting; the names
-- bound by enclosing function types (together with any dimensions that
-- were instantiated purely for unification); a lookup giving the level
-- of a nonrigid type variable; and the two dimensions themselves.
type UnifyDims m =
  BreadCrumbs ->
  [VName] ->
  (VName -> Maybe Int) ->
  DimDecl VName ->
  DimDecl VName ->
  m ()

-- | Turn a dimension-unification strategy into one that receives its
-- two dimensions in the opposite order.  All other arguments are
-- passed through unchanged.
flipUnifyDims :: UnifyDims m -> UnifyDims m
flipUnifyDims onDims bcs bound nonrigid =
  flip (onDims bcs bound nonrigid)

-- | Unify two types, delegating every pair of dimensions encountered to
-- the given 'UnifyDims' strategy.
--
-- Both types are normalised with 'normType' before being compared.  A
-- nonrigid type variable on either side is linked to the type on the
-- other side (via 'linkVarToType').  When both sides are nonrigid type
-- variables, the one with the lower level is linked; on a tie, the left
-- one is linked.
--
-- Records and sum types must have exactly the same field/constructor
-- names.  Applied type constructors must agree in name and in number of
-- arguments.  Function parameter types are unified with the two sides
-- flipped.  Any other pair of types unifies only if they are equal;
-- otherwise a 'matchError' is raised.
unifyWith ::
  MonadUnify m =>
  UnifyDims m ->
  Usage ->
  BreadCrumbs ->
  StructType ->
  StructType ->
  m ()
unifyWith onDims usage = subunify False mempty
  where
    -- Put a pair back in its original order when 'ord' says the two
    -- types currently being unified have been flipped.
    swap :: Bool -> a -> a -> (a, a)
    swap True x y = (y, x)
    swap False x y = (x, y)

    -- Keys present in only one of the two maps.  Those missing from
    -- the second map come first, then those missing from the first;
    -- each group is in ascending key order.
    unsharedKeys :: Ord k => M.Map k a -> M.Map k b -> [k]
    unsharedKeys xs ys =
      M.keys (M.difference xs ys) <> M.keys (M.difference ys xs)

    -- 'ord': whether t1 and t2 are flipped relative to the top-level
    -- call.  'bound': parameter names of enclosing function types, plus
    -- dimensions instantiated only for the purpose of unification.
    subunify ord bound bcs t1 t2 = do
      constraints <- getConstraints

      t1' <- normType t1
      t2' <- normType t2

      let nonrigid v = isNonRigid v constraints

          failure = matchError (srclocOf usage) mempty bcs t1' t2'

          -- Remove any of the intermediate dimensions we added just
          -- for unification purposes.
          unbound = applySubst f
            where
              f d
                | d `elem` bound = Just $ SizeSubst AnyDim
                | otherwise = Nothing

          link ord' v lvl =
            linkVarToType linkDims usage bcs v lvl . unbound
            where
              -- We may have to flip the order of future calls to
              -- onDims inside linkVarToType.
              linkDims
                | ord' = flipUnifyDims onDims
                | otherwise = onDims

          unifyTypeArg bcs' (TypeArgDim d1 _) (TypeArgDim d2 _) =
            onDims' bcs' (swap ord d1 d2)
          unifyTypeArg bcs' (TypeArgType t _) (TypeArgType arg_t _) =
            subunify ord bound bcs' t arg_t
          unifyTypeArg bcs' _ _ =
            unifyError
              usage
              mempty
              bcs'
              "Cannot unify a type argument with a dimension argument (or vice versa)."

          -- Resolve both dimensions through the current substitution
          -- before handing them to the strategy.
          onDims' bcs' (d1, d2) =
            onDims
              bcs'
              bound
              nonrigid
              (applySubst (`lookupSubst` constraints) d1)
              (applySubst (`lookupSubst` constraints) d2)

      case (t1', t2') of
        ( Scalar (Record fs),
          Scalar (Record arg_fs)
          )
            | M.keys fs == M.keys arg_fs ->
              forM_ (M.toList $ M.intersectionWith (,) fs arg_fs) $ \(k, (k_t1, k_t2)) -> do
                let bcs' = breadCrumb (MatchingFields [k]) bcs
                subunify ord bound bcs' k_t1 k_t2
            | otherwise -> do
              let missing = unsharedKeys fs arg_fs
              unifyError usage mempty bcs $
                "Unshared fields:" <+> commasep (map ppr missing) <> "."
        ( Scalar (TypeVar _ _ (TypeName _ tn) targs),
          Scalar (TypeVar _ _ (TypeName _ arg_tn) arg_targs)
          )
            | tn == arg_tn,
              length targs == length arg_targs -> do
              let bcs' = breadCrumb (Matching "When matching type arguments.") bcs
              zipWithM_ (unifyTypeArg bcs') targs arg_targs
        ( Scalar (TypeVar _ _ (TypeName [] v1) []),
          Scalar (TypeVar _ _ (TypeName [] v2) [])
          ) ->
            case (nonrigid v1, nonrigid v2) of
              (Nothing, Nothing) -> failure
              (Just lvl1, Nothing) -> link ord v1 lvl1 t2'
              (Nothing, Just lvl2) -> link (not ord) v2 lvl2 t1'
              (Just lvl1, Just lvl2)
                | lvl1 <= lvl2 -> link ord v1 lvl1 t2'
                | otherwise -> link (not ord) v2 lvl2 t1'
        (Scalar (TypeVar _ _ (TypeName [] v1) []), _)
          | Just lvl <- nonrigid v1 ->
            link ord v1 lvl t2'
        (_, Scalar (TypeVar _ _ (TypeName [] v2) []))
          | Just lvl <- nonrigid v2 ->
            link (not ord) v2 lvl t1'
        ( Scalar (Arrow _ p1 a1 b1),
          Scalar (Arrow _ p2 a2 b2)
          ) -> do
            let (r1, r2) = swap ord (Rigid RigidUnify) Nonrigid
            (a1', a1_dims) <- instantiateEmptyArrayDims (srclocOf usage) "anonymous" r1 a1
            (a2', a2_dims) <- instantiateEmptyArrayDims (srclocOf usage) "anonymous" r2 a2
            let bound' = bound <> mapMaybe pname [p1, p2] <> a1_dims <> a2_dims
            -- Parameters are unified with the sides flipped.
            subunify
              (not ord)
              bound
              (breadCrumb (Matching "When matching parameter types.") bcs)
              a1'
              a2'
            subunify
              ord
              bound'
              (breadCrumb (Matching "When matching return types.") bcs)
              b1'
              b2'
            where
              (b1', b2') =
                -- Replace one parameter name with the other in the
                -- return type, in case of dependent types.  I.e.,
                -- we want type '(n: i32) -> [n]i32' to unify with
                -- type '(x: i32) -> [x]i32'.
                case (p1, p2) of
                  (Named p1', Named p2') ->
                    let f v
                          | v == p2' = Just $ SizeSubst $ NamedDim $ qualName p1'
                          | otherwise = Nothing
                     in (b1, applySubst f b2)
                  (_, _) ->
                    (b1, b2)

              pname (Named x) = Just x
              pname Unnamed = Nothing
        (Array {}, Array {})
          | ShapeDecl (t1_d : _) <- arrayShape t1',
            ShapeDecl (t2_d : _) <- arrayShape t2',
            Just t1'' <- peelArray 1 t1',
            Just t2'' <- peelArray 1 t2' -> do
            -- Unify the outermost dimensions, then the row types.
            onDims' bcs (swap ord t1_d t2_d)
            subunify ord bound bcs t1'' t2''
        ( Scalar (Sum cs),
          Scalar (Sum arg_cs)
          )
            | M.keys cs == M.keys arg_cs ->
              unifySharedConstructors
                onDims
                usage
                bcs
                (map unbound <$> cs)
                (map unbound <$> arg_cs)
            | otherwise -> do
              let missing = unsharedKeys cs arg_cs
              unifyError usage mempty bcs $
                "Unshared constructors:" <+> commasep (map (("#" <>) . ppr) missing) <> "."
        _
          | t1' == t2' -> return ()
          | otherwise -> failure

-- | The dimension-unification strategy used by 'unify'.
--
-- * Equal dimensions succeed immediately.
-- * If the first dimension names a nonrigid size variable, that
--   variable is linked to the second dimension.
-- * Otherwise, if the second dimension names a nonrigid size variable,
--   that variable is linked to the first dimension.
-- * Anything else is reported as a mismatch, annotated with the notes
--   'dimNotes' produces for both dimensions.
--
-- The list of bound names is ignored by this strategy.
unifyDims :: MonadUnify m => Usage -> UnifyDims m
unifyDims _ _ _ _ d1 d2
  | d1 == d2 = return ()
unifyDims usage bcs _ nonrigid (NamedDim (QualName _ d1)) d2
  | Just lvl1 <- nonrigid d1 =
    linkVarToDim usage bcs d1 lvl1 d2
unifyDims usage bcs _ nonrigid d1 (NamedDim (QualName _ d2))
  | Just lvl2 <- nonrigid d2 =
    linkVarToDim usage bcs d2 lvl2 d1
unifyDims usage bcs _ _ d1 d2 = do
  notes <- (<>) <$> dimNotes usage d1 <*> dimNotes usage d2
  unifyError usage notes bcs $
    "Dimensions" <+> pquote (ppr d1)
      <+> "and"
      <+> pquote (ppr d2)
      <+> "do not match."

-- | Unifies two types, using 'unifyDims' to unify any dimensions that
-- are encountered and starting with no breadcrumbs.
unify :: MonadUnify m => Usage -> StructType -> StructType -> m ()
unify usage = unifyWith (unifyDims usage) usage noBreadCrumbs

-- | @expect super sub@ checks that @sub@ is a subtype of @super@.
--
-- This is 'unifyWith' with a more permissive dimension strategy.  An
-- 'AnyDim' in the first position accepts any dimension.  A nonrigid
-- size variable is linked to the other dimension only if that
-- dimension is not a name bound by an enclosing function type.  In
-- addition, a nonrigid variable in the first position is never linked
-- to 'AnyDim'.
expect :: MonadUnify m => Usage -> StructType -> StructType -> m ()
expect usage = unifyWith onDims usage noBreadCrumbs
  where
    onDims _ _ _ AnyDim _ = return ()
    onDims _ _ _ d1 d2
      | d1 == d2 = return ()
    onDims bcs bound nonrigid (NamedDim (QualName _ d1)) d2
      | Just lvl1 <- nonrigid d1,
        d2 /= AnyDim,
        not $ boundParam bound d2 =
        linkVarToDim usage bcs d1 lvl1 d2
    onDims bcs bound nonrigid d1 (NamedDim (QualName _ d2))
      | Just lvl2 <- nonrigid d2,
        not $ boundParam bound d1 =
        linkVarToDim usage bcs d2 lvl2 d1
    onDims bcs _ _ d1 d2 = do
      notes <- (<>) <$> dimNotes usage d1 <*> dimNotes usage d2
      unifyError usage notes bcs $
        "Dimensions" <+> pquote (ppr d1)
          <+> "and"
          <+> pquote (ppr d2)
          <+> "do not match."

    -- Is this dimension a reference to one of the bound names?
    boundParam bound (NamedDim (QualName _ d)) = d `elem` bound
    boundParam _ _ = False

-- | Does the type contain any 'AnyDim' dimension anywhere?  Aliasing
-- information is ignored.
hasEmptyDims :: StructType -> Bool
hasEmptyDims = biany isAnyDim (const False)
  where
    isAnyDim AnyDim = True
    isAnyDim _ = False

-- | Occurs check: raise a unification error if the type variable @vn@
-- is among the type variables of @tp@, because binding @vn@ to a type
-- that mentions @vn@ itself would be circular.
occursCheck ::
  MonadUnify m =>
  Usage ->
  BreadCrumbs ->
  VName ->
  StructType ->
  m ()
occursCheck usage bcs vn tp
  | vn `S.member` typeVars tp =
    unifyError usage mempty bcs $
      "Occurs check: cannot instantiate"
        <+> pprName vn
        <+> "with"
        <+> ppr tp <> "."
  | otherwise = return ()

-- | Check that every type variable and size name mentioned in @tp@ may be
-- referenced at level @max_lvl@ (in 'linkVarToType' this is the level at
-- which @vn@ is being bound).
--
-- A name whose recorded level is greater than @max_lvl@ is either
-- rejected with a scope-violation error (if its constraint is rigid, per
-- 'rigidConstraint') or has its level lowered to @max_lvl@.
scopeCheck ::
  MonadUnify m =>
  Usage ->
  BreadCrumbs ->
  VName ->
  Level ->
  StructType ->
  m ()
scopeCheck usage bcs vn max_lvl tp = do
  constraints <- getConstraints
  -- Every lookup uses this one snapshot; level changes made while
  -- iterating are not observed by later iterations.
  mapM_ (checkName constraints) (typeVars tp <> typeDimNames tp)
  where
    checkName constraints v =
      case M.lookup v constraints of
        Just (lvl, c)
          | lvl > max_lvl ->
            if rigidConstraint c
              then scopeViolation v
              else modifyConstraints $ M.insert v (max_lvl, c)
        _ -> return ()

    scopeViolation v = do
      notes <- typeNotes usage tp
      unifyError usage notes bcs $
        "Cannot unify type"
          </> indent 2 (ppr tp)
          </> "with"
          <+> pquote (pprName vn)
          <+> "(scope violation)."
          </> "This is because"
          <+> pquote (pprName v)
          <+> "is rigidly bound in a deeper scope."

-- | Bind the type variable @vn@ at level @lvl@ to the type @tp@.
--
-- After an occurs check and a scope check, @vn@ is recorded as a
-- 'Constraint' on @tp@ with uniqueness removed.  The constraint @vn@ had
-- /before/ this call is then discharged against @tp@:
--
-- * 'NoConstraint' below 'Lifted': @tp@ must be usable as an array element
--   type, and when 'Unlifted' it must not contain anonymous sizes.
-- * 'Equality': @tp@ must satisfy 'equalityType'.
-- * 'Overloaded': @tp@ must be one of the allowed primitive types, or a
--   non-rigid type variable, which is passed to 'linkVarToTypes'.
-- * 'HasFields': @tp@ must be a record containing the required fields
--   (each unified with the corresponding record field), or a non-rigid
--   type variable that takes over the requirement.
-- * 'HasConstrs': @tp@ must be a sum type containing the required
--   constructors, or a non-rigid type variable whose constructor
--   requirements are unified and merged with these.
--
-- Any other prior constraint, or none at all, triggers no further check.
linkVarToType ::
  MonadUnify m =>
  UnifyDims m ->
  Usage ->
  BreadCrumbs ->
  VName ->
  Level ->
  StructType ->
  m ()
linkVarToType onDims usage bcs vn lvl tp = do
  occursCheck usage bcs vn tp
  scopeCheck usage bcs vn lvl tp

  -- Snapshot taken before vn is rebound, so the lookup below sees the
  -- constraint vn had prior to this call.
  constraints <- getConstraints
  let plainTp = removeUniqueness tp
  modifyConstraints $ M.insert vn (lvl, Constraint plainTp usage)

  case fmap snd (M.lookup vn constraints) of
    Just (NoConstraint l unlift_usage)
      | l < Lifted -> do
        let liftBcs =
              breadCrumb
                ( Matching $
                    "When verifying that" <+> pquote (pprName vn)
                      <+> textwrap "is not instantiated with a function type, due to"
                      <+> ppr unlift_usage
                )
                bcs
        arrayElemTypeWith usage liftBcs plainTp

        when (l == Unlifted && hasEmptyDims plainTp) $
          unifyError usage mempty bcs $
            "Type variable" <+> pprName vn
              <+> "cannot be instantiated with type containing anonymous sizes:"
              </> indent 2 (ppr tp)
              </> textwrap "This is usually because the size of an array returned by a higher-order function argument cannot be determined statically.  This can also be due to the return size being a value parameter.  Add type annotation to clarify."
    Just (Equality _) ->
      equalityType usage plainTp
    Just (Overloaded ts old_usage)
      | tp `notElem` map (Scalar . Prim) ts ->
        case plainTp of
          Scalar (TypeVar _ _ (TypeName [] v) [])
            | not $ isRigid v constraints ->
              linkVarToTypes usage v ts
          _ ->
            unifyError usage mempty bcs $
              "Cannot instantiate" <+> pquote (pprName vn)
                <+> "with type" </> indent 2 (ppr tp) </> "as"
                <+> pquote (pprName vn)
                <+> "must be one of"
                <+> commasep (map ppr ts)
                <+/> "due to"
                <+/> ppr old_usage <> "."
    Just (HasFields required_fields old_usage) ->
      case tp of
        Scalar (Record tp_fields)
          | M.keysSet required_fields `S.isSubsetOf` M.keysSet tp_fields -> do
            required_fields' <- mapM normTypeFully required_fields
            let fieldBcs =
                  breadCrumb
                    ( Matching $
                        pprName vn
                          <+> "must be a record with at least the fields:"
                          </> indent 2 (ppr (Record required_fields'))
                          </> "due to"
                          <+> ppr old_usage <> "."
                    )
                    bcs
            -- Unify every required field with the record's field of the
            -- same name, in ascending field-name order.
            sequence_ $
              M.intersectionWith (unifyWith onDims usage fieldBcs) required_fields tp_fields
        Scalar (TypeVar _ _ (TypeName [] v) [])
          | not $ isRigid v constraints ->
            modifyConstraints $
              M.insert v (lvl, HasFields required_fields old_usage)
        _ ->
          unifyError usage mempty bcs $
            "Cannot instantiate" <+> pquote (pprName vn) <+> "with type"
              </> indent 2 (ppr tp)
              </> "as" <+> pquote (pprName vn) <+> "must be a record with fields"
              </> indent 2 (ppr (Record required_fields))
              </> "due to" <+> ppr old_usage <> "."
    Just (HasConstrs required_cs old_usage) ->
      case tp of
        Scalar (Sum ts)
          | M.keysSet required_cs `S.isSubsetOf` M.keysSet ts ->
            unifySharedConstructors onDims usage bcs required_cs ts
        Scalar (TypeVar _ _ (TypeName [] v) [])
          | not $ isRigid v constraints -> do
            case M.lookup v constraints of
              Just (_, HasConstrs v_cs _) ->
                unifySharedConstructors onDims usage bcs required_cs v_cs
              _ -> return ()
            modifyConstraints $
              M.insertWith combineConstrs v (lvl, HasConstrs required_cs old_usage)
        _ -> noSumType
    _ -> return ()
  where
    -- Merge the new 'HasConstrs' requirement (first argument) with an
    -- existing one (second argument); if the existing constraint is not
    -- 'HasConstrs', the new requirement replaces it.
    combineConstrs (_, HasConstrs cs1 usage1) (_, HasConstrs cs2 _) =
      (lvl, HasConstrs (M.union cs1 cs2) usage1)
    combineConstrs fresh _ = fresh

    noSumType =
      unifyError
        usage
        mempty
        bcs
        "Cannot unify a sum type with a non-sum type"

-- | Bind the size variable @vn@ at level @lvl@ to the size @dim@.
--
-- If @dim@ is a named size whose recorded level is greater than @lvl@,
-- that name is either rejected with a scope-violation error (when it is a
-- 'ParamSize') or has its level lowered to @lvl@.  In all non-error cases
-- @vn@ is then recorded as a 'Size' equal to @dim@.
linkVarToDim ::
  MonadUnify m =>
  Usage ->
  BreadCrumbs ->
  VName ->
  Level ->
  DimDecl VName ->
  m ()
linkVarToDim usage bcs vn lvl dim = do
  constraints <- getConstraints
  case dim of
    NamedDim dim' -> checkScope constraints dim'
    _ -> return ()
  modifyConstraints $ M.insert vn (lvl, Size (Just dim) usage)
  where
    checkScope constraints dim' =
      case M.lookup (qualLeaf dim') constraints of
        Just (dim_lvl, ParamSize {})
          | dim_lvl > lvl -> do
            notes <- dimNotes usage dim
            unifyError usage notes bcs $
              "Cannot unify size variable" <+> pquote (ppr dim')
                <+> "with"
                <+> pquote (pprName vn)
                <+> "(scope violation)."
                </> "This is because"
                <+> pquote (ppr dim')
                <+> "is rigidly bound in a deeper scope."
        Just (dim_lvl, c)
          | dim_lvl > lvl ->
            modifyConstraints $ M.insert (qualLeaf dim') (lvl, c)
        _ -> return ()

-- | Strip uniqueness from a type.  The component types of records,
-- function arrows and sum types are processed recursively; every other
-- type is returned with its uniqueness set to 'Nonunique'.
removeUniqueness :: TypeBase dim as -> TypeBase dim as
removeUniqueness t =
  case t of
    Scalar (Record fields) ->
      Scalar $ Record $ removeUniqueness <$> fields
    Scalar (Arrow als p t1 t2) ->
      Scalar $ Arrow als p (removeUniqueness t1) (removeUniqueness t2)
    Scalar (Sum constrs) ->
      Scalar $ Sum $ map removeUniqueness <$> constrs
    _ ->
      t `setUniqueness` Nonunique

-- | Assert that this type must be one of the given primitive types.
--
-- With a single allowed type this is plain unification.  Otherwise the
-- type is first normalised with 'normType'.  A non-rigid type variable is
-- passed to 'linkVarToTypes', a primitive type must be in the list, and
-- anything else is a unification error.
mustBeOneOf :: MonadUnify m => [PrimType] -> Usage -> StructType -> m ()
mustBeOneOf [req_t] usage t = unify usage (Scalar (Prim req_t)) t
mustBeOneOf ts usage t = do
  t' <- normType t
  constraints <- getConstraints
  let isRigid' v = isRigid v constraints

  case t' of
    Scalar (TypeVar _ _ (TypeName [] v) [])
      | not $ isRigid' v -> linkVarToTypes usage v ts
    Scalar (Prim pt) | pt `elem` ts -> return ()
    _ -> failure t'
  where
    -- Report the normalised type that was actually rejected.  The
    -- caller's @t@ may be a type variable that 'normType' resolved, and
    -- printing it would hide the offending type (NOTE(review): assumes
    -- normType resolves bound variables, as its use above suggests).
    failure shown =
      unifyError usage mempty noBreadCrumbs $
        text "Cannot unify type" <+> pquote (ppr shown)
          <+> "with any of " <> commasep (map ppr ts) <> "."

-- | Restrict the type variable @vn@ so that it may only be instantiated
-- with one of the primitive types in @ts@.
--
-- * If @vn@ is already 'Overloaded', its candidate list is narrowed to
--   the intersection with @ts@; an empty intersection is a type error.
-- * If @vn@ is already required to be a sum type ('HasConstrs') or a
--   record ('HasFields'), this is a type error.
-- * Any other existing constraint on @vn@ is replaced by an
--   'Overloaded' constraint on @ts@.
-- * If @vn@ has no entry in the constraint map, this is a type error.
linkVarToTypes :: MonadUnify m => Usage -> VName -> [PrimType] -> m ()
linkVarToTypes usage vn ts =
  getConstraints >>= restrictTo . M.lookup vn
  where
    candidates = commasep (map ppr ts)

    overload lvl pts =
      modifyConstraints $ M.insert vn (lvl, Overloaded pts usage)

    restrictTo (Just (lvl, Overloaded vn_ts vn_usage))
      | null common =
        unifyError usage mempty noBreadCrumbs $
          "Type constrained to one of"
            <+> candidates
            <+> "but also one of"
            <+> commasep (map ppr vn_ts)
            <+> "due to"
            <+> ppr vn_usage <> "."
      | otherwise = overload lvl common
      where
        common = ts `intersect` vn_ts
    restrictTo (Just (_, HasConstrs _ vn_usage)) =
      unifyError usage mempty noBreadCrumbs $
        "Type constrained to one of" <+> candidates
          <> ", but also inferred to be sum type due to" <+> ppr vn_usage
          <> "."
    restrictTo (Just (_, HasFields _ vn_usage)) =
      unifyError usage mempty noBreadCrumbs $
        "Type constrained to one of" <+> candidates
          <> ", but also inferred to be record due to" <+> ppr vn_usage
          <> "."
    restrictTo (Just (lvl, _)) = overload lvl ts
    restrictTo Nothing =
      unifyError usage mempty noBreadCrumbs $
        "Cannot constrain type to one of" <+> candidates

-- | Assert that this type must support equality.
--
-- The type itself must be zero-order. Every type variable occurring
-- in it is then checked against the current constraints:
--
-- * a variable bound to another type variable forwards the check to
--   that variable;
-- * a variable bound to a higher-order type is an error;
-- * an unconstrained variable gains an 'Equality' constraint;
-- * 'Overloaded' and 'Equality' variables are accepted as-is;
-- * a variable required to have constructors has all constructor
--   payload types checked recursively;
-- * any other constraint is an error.
equalityType ::
  (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
  Usage ->
  TypeBase dim as ->
  m ()
equalityType usage t = do
  unless (orderZero t) $
    unifyError usage mempty noBreadCrumbs $
      "Type" <+> pquote (ppr t) <+> "does not support equality (is higher-order)."
  mapM_ mustBeEquality $ typeVars t
  where
    mustBeEquality vn = do
      constraints <- getConstraints
      case M.lookup vn constraints of
        -- The variable was linked to another type variable: the
        -- requirement transfers to that variable.
        Just (_, Constraint (Scalar (TypeVar _ _ (TypeName [] vn') [])) _) ->
          mustBeEquality vn'
        Just (_, Constraint vn_t cusage)
          | not $ orderZero vn_t ->
            unifyError usage mempty noBreadCrumbs $
              "Type" <+> pquote (ppr t) <+> "does not support equality."
                </> "Constrained to be higher-order due to" <+> ppr cusage <> "."
          | otherwise -> return ()
        Just (lvl, NoConstraint _ _) ->
          modifyConstraints $ M.insert vn (lvl, Equality usage)
        Just (_, Overloaded _ _) ->
          return () -- All primtypes support equality.
        Just (_, Equality {}) ->
          return ()
        Just (_, HasConstrs cs _) ->
          mapM_ (equalityType usage) $ concat $ M.elems cs
        _ ->
          unifyError usage mempty noBreadCrumbs $
            "Type" <+> pprName vn <+> "does not support equality."

-- | Check that a type is zero-order (contains no function types),
-- reporting any error with the given breadcrumbs.
--
-- Each type variable in the type is also restricted:
--
-- * an unconstrained variable is re-recorded as
--   @'NoConstraint' 'Unlifted' usage@;
-- * a type parameter declared 'Lifted' is rejected, as it may be
--   instantiated with a function;
-- * every other variable is left unchanged.
zeroOrderTypeWith ::
  (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
  Usage ->
  BreadCrumbs ->
  TypeBase dim as ->
  m ()
zeroOrderTypeWith usage bcs t = do
  unless (orderZero t) functional
  forM_ (S.toList (typeVars t)) restrict
  where
    functional =
      unifyError usage mempty bcs $
        "Type" </> indent 2 (ppr t) </> "found to be functional."

    restrict vn = do
      found <- M.lookup vn <$> getConstraints
      case found of
        Just (lvl, NoConstraint _ _) ->
          modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage)
        Just (_, ParamType Lifted ploc) ->
          unifyError usage mempty bcs $
            "Type parameter" <+> pquote (pprName vn) <+> "at"
              <+> text (locStr ploc)
              <+> "may be a function."
        _ -> return ()

-- | Assert that this type must be zero-order.
--
-- @desc@ describes what is being checked. It is wrapped in a
-- "When checking ..." breadcrumb, which is passed along to any error
-- raised by the check.
zeroOrderType ::
  (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
  Usage ->
  String ->
  TypeBase dim as ->
  m ()
zeroOrderType usage desc t =
  zeroOrderTypeWith usage checking t
  where
    checking = breadCrumb (Matching $ "When checking" <+> textwrap desc) noBreadCrumbs

-- | Check that a type may be used as an array element, reporting any
-- error with the given breadcrumbs.
--
-- The type must be zero-order. Each type variable in it is also
-- restricted:
--
-- * an unconstrained variable whose bound is 'Lifted' is tightened
--   to 'SizeLifted';
-- * a type parameter declared 'Lifted' or 'SizeLifted' is rejected;
-- * every other variable is left unchanged.
arrayElemTypeWith ::
  (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
  Usage ->
  BreadCrumbs ->
  TypeBase dim as ->
  m ()
arrayElemTypeWith usage bcs t = do
  unless (orderZero t) $
    unifyError usage mempty bcs $
      "Type" </> indent 2 (ppr t) </> "found to be functional."
  mapM_ mustBeZeroOrder . S.toList . typeVars $ t
  where
    mustBeZeroOrder vn = do
      constraints <- getConstraints
      case M.lookup vn constraints of
        -- Only ever tighten the liftedness bound. A variable that is
        -- already restricted (e.g. to 'Unlifted') must keep that
        -- restriction; overwriting it with 'SizeLifted' would weaken it.
        Just (lvl, NoConstraint Lifted _) ->
          modifyConstraints $ M.insert vn (lvl, NoConstraint SizeLifted usage)
        Just (_, ParamType l ploc)
          | l `elem` [Lifted, SizeLifted] ->
            unifyError usage mempty bcs $
              "Type parameter"
                <+> pquote (pprName vn)
                <+> "bound at"
                <+> text (locStr ploc)
                <+> "is lifted and cannot be an array element."
        _ -> return ()

-- | Assert that this type must be valid as an array element.
--
-- @desc@ describes what is being checked. It is wrapped in a
-- "When checking ..." breadcrumb, which is passed along to any error
-- raised by the check.
arrayElemType ::
  (MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
  Usage ->
  String ->
  TypeBase dim as ->
  m ()
arrayElemType usage desc t =
  arrayElemTypeWith usage checking t
  where
    checking = breadCrumb (Matching $ "When checking" <+> textwrap desc) noBreadCrumbs

-- | Unify the payload types of every constructor that appears in both
-- constructor maps. Constructors present in only one map are ignored.
--
-- Payloads are unified pairwise with the given dimension-unification
-- strategy, under a 'MatchingConstructor' breadcrumb. A constructor
-- whose payloads have different arities in the two maps is a type
-- error.
unifySharedConstructors ::
  MonadUnify m =>
  UnifyDims m ->
  Usage ->
  BreadCrumbs ->
  M.Map Name [StructType] ->
  M.Map Name [StructType] ->
  m ()
unifySharedConstructors onDims usage bcs cs1 cs2 =
  mapM_ (uncurry unifyPayloads) . M.toList $ M.intersectionWith (,) cs1 cs2
  where
    unifyPayloads c (f1, f2)
      | length f1 /= length f2 =
        unifyError usage mempty bcs $
          "Cannot unify constructor" <+> pquote (pprName c) <> "."
      | otherwise =
        zipWithM_ (unifyWith onDims usage (breadCrumb (MatchingConstructor c) bcs)) f1 f2

-- | In @mustHaveConstr usage c t fs@, the type @t@ must have a
-- constructor named @c@ that takes arguments of types @fs@.
--
-- * If @t@ is an unconstrained type variable, it becomes constrained
--   to have (at least) the constructor @c@. The types @fs@ are first
--   scope-checked against the variable's level.
-- * If @t@ is a type variable already constrained to have some
--   constructors, @c@ is added to them; if @c@ is already present,
--   its known payload is unified with @fs@.
-- * If @t@ is a sum type, @c@ must be one of its constructors, and its
--   payload is unified with @fs@.
-- * Otherwise @t@ is unified with the sum type consisting only of @c@.
mustHaveConstr ::
  MonadUnify m =>
  Usage ->
  Name ->
  StructType ->
  [StructType] ->
  m ()
mustHaveConstr usage c t fs = do
  constraints <- getConstraints
  case t of
    Scalar (TypeVar _ _ (TypeName _ tn) [])
      | Just (lvl, NoConstraint {}) <- M.lookup tn constraints -> do
        mapM_ (scopeCheck usage noBreadCrumbs tn lvl) fs
        modifyConstraints $ M.insert tn (lvl, HasConstrs (M.singleton c fs) usage)
      | Just (lvl, HasConstrs cs _) <- M.lookup tn constraints ->
        case M.lookup c cs of
          Nothing ->
            modifyConstraints $ M.insert tn (lvl, HasConstrs (M.insert c fs cs) usage)
          Just fs' -> unifyPayload fs'
    Scalar (Sum cs) ->
      case M.lookup c cs of
        Nothing ->
          unifyError usage mempty noBreadCrumbs $
            "Constructor" <+> pquote (ppr c) <+> "not present in type."
        Just fs' -> unifyPayload fs'
    _ ->
      unify usage t $ Scalar $ Sum $ M.singleton c fs
  where
    -- Unify the expected payload types with those already known for
    -- the constructor; the arities must agree.
    unifyPayload fs'
      | length fs == length fs' = zipWithM_ (unify usage) fs fs'
      | otherwise =
        unifyError usage mempty noBreadCrumbs $
          "Different arity for constructor" <+> pquote (ppr c) <> "."

-- | Require that @t@ has a field named @l@ and return the type of that
-- field. Uses the given dimension-unification strategy and
-- breadcrumbs.
--
-- A fresh type variable is created to stand for the field type.
--
-- * If @t@ is an unconstrained type variable, the fresh variable is
--   scope-checked, and @t@ becomes constrained to have (at least) the
--   field @l@.
-- * If @t@ is a type variable already constrained to have fields, @l@
--   is added to them; if @l@ is already present, its known type is
--   unified with the fresh variable.
-- * If @t@ is a record type, @l@ must be one of its fields. The
--   field's type is unified with the fresh variable and returned.
-- * Otherwise @t@ is unified with a record type having the single
--   field @l@.
mustHaveFieldWith ::
  MonadUnify m =>
  UnifyDims m ->
  Usage ->
  BreadCrumbs ->
  Name ->
  PatternType ->
  m PatternType
mustHaveFieldWith onDims usage bcs l t = do
  constraints <- getConstraints
  l_type <- newTypeVar (srclocOf usage) "t"
  let l_type' = toStruct l_type
  case t of
    Scalar (TypeVar _ _ (TypeName _ tn) [])
      | Just (lvl, NoConstraint {}) <- M.lookup tn constraints -> do
        scopeCheck usage bcs tn lvl l_type'
        modifyConstraints $ M.insert tn (lvl, HasFields (M.singleton l l_type') usage)
        return l_type
      | Just (lvl, HasFields fields _) <- M.lookup tn constraints -> do
        case M.lookup l fields of
          Just t' -> unifyWith onDims usage bcs l_type' t'
          Nothing ->
            modifyConstraints $
              M.insert tn (lvl, HasFields (M.insert l l_type' fields) usage)
        return l_type
    Scalar (Record fields)
      | Just t' <- M.lookup l fields -> do
        unify usage l_type' $ toStruct t'
        return t'
      | otherwise ->
        unifyError usage mempty bcs $
          "Attempt to access field" <+> pquote (ppr l) <+> "of value of type"
            <+> ppr (toStructural t) <> "."
    _ -> do
      unify usage (toStruct t) $ Scalar $ Record $ M.singleton l l_type'
      return l_type

-- | Assert that some type must have a field with this name and type.
--
-- This is 'mustHaveFieldWith' with the ordinary size unification
-- ('unifyDims') for the given usage, starting from no breadcrumbs.
mustHaveField ::
  MonadUnify m =>
  Usage ->
  Name ->
  PatternType ->
  m PatternType
mustHaveField usage field t =
  mustHaveFieldWith dimUnifier usage noBreadCrumbs field t
  where
    dimUnifier = unifyDims usage

-- | Replace dimension mismatches with 'AnyDim'.
--
-- The two types are walked together with 'matchDims'.
--
-- * Where the dimensions agree, the dimension from the first type is kept.
-- * Where they differ, 'AnyDim' is used instead, and the offending pair
--   (dimension of the first type, dimension of the second type) is
--   recorded.
--
-- The recorded pairs are returned in the order 'matchDims' visits them.
anyDimOnMismatch ::
  Monoid as =>
  TypeBase (DimDecl VName) as ->
  TypeBase (DimDecl VName) as ->
  (TypeBase (DimDecl VName) as, [(DimDecl VName, DimDecl VName)])
anyDimOnMismatch t1 t2 = runWriter (matchDims pick t1 t2)
  where
    pick d1 d2 =
      if d1 == d2
        then pure d1
        else AnyDim <$ tell [(d1, d2)]

-- | Walk two types together with 'matchDims'. Where the dimensions of
-- both types agree, the first type's dimension is kept. Every mismatch
-- is replaced by a fresh size variable.
--
-- Fresh sizes are created with 'newDimVar' at the given location. They
-- are rigid, with 'RigidCond' of the two types as their source. The
-- same pair of mismatching dimensions is always mapped to the same
-- fresh size.
--
-- Returns the resulting type together with the names of all sizes
-- created, one per distinct mismatching pair.
newDimOnMismatch ::
  (Monoid as, MonadUnify m) =>
  SrcLoc ->
  TypeBase (DimDecl VName) as ->
  TypeBase (DimDecl VName) as ->
  m (TypeBase (DimDecl VName) as, [VName])
newDimOnMismatch loc t1 t2 =
  fmap M.elems <$> runStateT (matchDims pick t1 t2) M.empty
  where
    rigidity = Rigid $ RigidCond (toStruct t1) (toStruct t2)

    pick d1 d2
      | d1 == d2 = pure d1
      | otherwise = do
          -- Reuse the size already invented for this exact mismatch, so
          -- that repeated occurrences of the same pair share one name.
          known <- gets $ M.lookup (d1, d2)
          d <- maybe (invent (d1, d2)) pure known
          pure $ NamedDim $ qualName d

    -- Create a fresh rigid size and remember it for this mismatch.
    invent key = do
      d <- lift $ newDimVar loc rigidity "differ"
      modify $ M.insert key d
      pure d

-- | Like unification, but creates new size variables where mismatches
-- occur.  Returns the new dimensions thus created.
--
-- The types are first unified with every dimension check disabled.
-- The fully normalised results are then compared by 'newDimOnMismatch',
-- which turns each dimension mismatch into a fresh size at the
-- location of the usage.
unifyMostCommon ::
  MonadUnify m =>
  Usage ->
  PatternType ->
  PatternType ->
  m (PatternType, [VName])
unifyMostCommon usage t1 t2 = do
  unifyWith ignoreDims usage noBreadCrumbs (toStruct t1) (toStruct t2)
  t1' <- normTypeFully t1
  t2' <- normTypeFully t2
  newDimOnMismatch (srclocOf usage) t1' t2'
  where
    -- Dimension mismatches are deliberately accepted here, because
    -- they are turned into fresh size variables afterwards.
    ignoreDims _ _ _ _ _ = pure ()

-- Simple MonadUnify implementation.

-- | State of 'UnifyM': the current constraints, paired with the counter
-- that 'newVar' reads and increments to tag fresh names.
type UnifyMState = (Constraints, Int)

-- | A minimal monad in which unification can run on its own. It keeps
-- state over 'UnifyMState' and fails with a 'TypeError'.
newtype UnifyM a = UnifyM (StateT UnifyMState (Except TypeError) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadError TypeError,
      MonadState UnifyMState
    )

-- | Create a fresh name from the given description.
--
-- The current counter is used both as the (subscripted) suffix of the
-- human-readable name and as the 'VName' tag. The counter is then
-- incremented; the constraints are left unchanged.
newVar :: String -> UnifyM VName
newVar name =
  state $ \(constraints, counter) ->
    (VName (mkTypeVarName name counter) counter, (constraints, counter + 1))

instance MonadUnify UnifyM where
  getConstraints = gets fst

  putConstraints new = modify $ \(_, counter) -> (new, counter)

  -- Fresh type variables get 'NoConstraint' at 'Lifted' liftedness.
  -- Everything is recorded at level 0, matching 'curLevel' below.
  newTypeVar loc name = do
    v <- newVar name
    let usage = Usage Nothing loc
    modifyConstraints $ M.insert v (0, NoConstraint Lifted usage)
    pure . Scalar $ TypeVar mempty Nonunique (typeName v) []

  -- Rigid sizes are recorded as unknowable, remembering their source.
  -- Non-rigid ones become 'Size' constraints with no known value.
  newDimVar loc rigidity name = do
    dim <- newVar name
    let constraint = case rigidity of
          Rigid src -> UnknowableSize loc src
          Nonrigid -> Size Nothing $ Usage Nothing loc
    modifyConstraints $ M.insert dim (0, constraint)
    pure dim

  curLevel = pure 0

  unifyError loc notes bcs doc =
    throwError . TypeError (srclocOf loc) notes $ doc <> ppr bcs

  -- A type mismatch is reported as an ordinary unification error
  -- with a fixed message listing both types.
  matchError loc notes bcs t1 t2 =
    unifyError loc notes bcs $
      "Types"
        </> indent 2 (ppr t1)
        </> "and"
        </> indent 2 (ppr t2)
        </> "do not match."

-- | Construct the name of a new type variable from a base description
-- and a tag number.
--
-- This is distinct from actually constructing a 'VName'. The tag is
-- intended for human consumption; the machine does not care about it.
-- The decimal digits of the tag are appended as Unicode subscripts, so
-- description @t@ with tag 12 gives @t₁₂@. Characters that are not
-- digits, such as a minus sign, are dropped.
mkTypeVarName :: String -> Int -> Name
mkTypeVarName desc i =
  nameFromString $
    desc ++ [sub | digit <- show i, Just sub <- [lookup digit subscripts]]
  where
    subscripts = zip "0123456789" "₀₁₂₃₄₅₆₇₈₉"

-- | Run a 'UnifyM' computation. Errors thrown inside it become 'Left'.
--
-- The fresh-name counter starts at zero, and there is one initial
-- constraint per type parameter, all at level 0:
--
-- * size parameters become 'Size' constraints with no known value;
-- * type parameters become 'NoConstraint' with their declared
--   liftedness.
runUnifyM :: [TypeParam] -> UnifyM a -> Either TypeError a
runUnifyM tparams (UnifyM m) =
  runExcept $ evalStateT m (initial, 0)
  where
    initial = M.fromList $ map entry tparams

    entry (TypeParamDim p loc) = (p, (0, Size Nothing $ unknown loc))
    entry (TypeParamType l p loc) = (p, (0, NoConstraint l $ unknown loc))

    unknown = Usage Nothing

-- | Perform a unification of two types outside a monadic context.
-- The type parameters are allowed to be instantiated; all other types
-- are considered rigid.
--
-- Empty array dimensions in both types are first instantiated with
-- rigid size variables whose source is 'RigidUnify'. On success, the
-- second type as originally given is returned, fully normalised under
-- the resulting substitution.
doUnification ::
  SrcLoc ->
  [TypeParam] ->
  StructType ->
  StructType ->
  Either TypeError StructType
doUnification loc tparams t1 t2 = runUnifyM tparams $ do
  (t1', _) <- instantiate "n" t1
  (t2', _) <- instantiate "m" t2
  expect (Usage Nothing loc) t1' t2'
  -- NOTE(review): this normalises t2 rather than t2', so the sizes
  -- instantiated above are not substituted into the result. This looks
  -- deliberate but should be confirmed against the callers.
  normTypeFully t2
  where
    instantiate desc = instantiateEmptyArrayDims loc desc (Rigid RigidUnify)