{-# LANGUAGE OverloadedStrings #-}
module Language.Futhark.TypeChecker.Match
( unmatched,
Match,
)
where
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Util (maybeHead, nubOrd)
import Futhark.Util.Pretty hiding (bool, group, space)
import Language.Futhark hiding (ExpBase (Constr))
-- | The head constructor of a simplified pattern.  Literal patterns are
-- represented as constructors that take no arguments.
data Constr
  = -- | A sum-type constructor such as @#foo@.
    Constr Name
  | -- | The constructor of a tuple pattern.
    ConstrTuple
  | -- | The constructor of a record pattern with the given field names.
    ConstrRecord [Name]
  | -- | A literal pattern, treated as a 0-ary constructor.
    ConstrLit PatLit
  deriving (Eq, Ord, Show)
-- | A representation of the essentials of a pattern: either a wildcard or
-- a constructor applied to sub-patterns.  Each node carries the type of
-- the values it matches.
data Match
  = -- | Matches any value of the given type.
    MatchWild StructType
  | -- | A constructor applied to argument patterns, at the given type.
    MatchConstr Constr [Match] StructType
  deriving (Eq, Ord, Show)
-- | The type of the values matched by a 'Match'.
matchType :: Match -> StructType
matchType m =
  case m of
    MatchWild t -> t
    MatchConstr _ _ t -> t
-- | Render a 'Match' in Futhark-like pattern syntax.
--
-- The precedence argument only affects constructor applications that
-- have arguments: these are parenthesised when the precedence is 10 or
-- higher.  Constructor arguments are themselves printed at precedence
-- 10, while tuple elements and record fields use (-1).
pprMatch :: Int -> Match -> Doc
pprMatch prec m =
  case m of
    MatchWild {} -> "_"
    MatchConstr (ConstrLit lit) _ _ -> ppr lit
    MatchConstr (Constr name) args _ ->
      let needsParens = prec >= 10 && not (null args)
          argDocs = foldMap (\arg -> " " <> pprMatch 10 arg) args
       in parensIf needsParens ("#" <> ppr name <> argDocs)
    MatchConstr ConstrTuple elems _ ->
      parens (commasep [pprMatch (-1) e | e <- elems])
    MatchConstr (ConstrRecord fields) vals _ ->
      braces (commasep (zipWith field fields vals))
  where
    -- A record field rendered as @name=pattern@.
    field name v = text (nameToString name) <> equals <> pprMatch (-1) v
-- | Matches are printed at the lowest precedence, so a top-level
-- constructor application is never wrapped in parentheses.
instance Pretty Match where
  ppr m = pprMatch (-1) m
-- | Reduce a type-checked pattern to a 'Match'.  Variable bindings become
-- wildcards, and parentheses, attributes and type ascriptions are looked
-- through.
patternToMatch :: Pat -> Match
patternToMatch pat =
  case pat of
    Id _ (Info t) _ -> MatchWild (toStruct t)
    Wildcard (Info t) _ -> MatchWild (toStruct t)
    PatParens inner _ -> patternToMatch inner
    PatAttr _ inner _ -> patternToMatch inner
    PatAscription inner _ _ -> patternToMatch inner
    PatLit lit (Info t) _ ->
      MatchConstr (ConstrLit lit) [] (toStruct t)
    TuplePat elems _ ->
      MatchConstr ConstrTuple (map patternToMatch elems) $
        patternStructType pat
    RecordPat fields _ ->
      -- Fields go through a Map and 'sortFields', presumably so that the
      -- same record written with fields in different orders yields equal
      -- 'ConstrRecord' values.
      let (names, subpats) = unzip $ sortFields $ M.fromList fields
       in MatchConstr (ConstrRecord names) (map patternToMatch subpats) $
            patternStructType pat
    PatConstr name (Info t) args _ ->
      MatchConstr (Constr name) (map patternToMatch args) (toStruct t)
-- | The name of the sum-type constructor at the head of a 'Match', if any.
-- Wildcards and tuple, record and literal constructors give 'Nothing'.
isConstr :: Match -> Maybe Name
isConstr m =
  case m of
    MatchConstr (Constr name) _ _ -> Just name
    _ -> Nothing
-- | Does a column of head patterns mention every head constructor of its
-- type?
--
-- * If the first element has a sum type and every element is a sum-type
--   constructor, the column is complete when every constructor of the
--   type occurs.  A wildcard in the column makes this check not apply.
-- * Otherwise the column is complete if it contains both @true@ and
--   @false@, or consists only of record patterns, or only of tuple
--   patterns.
--
-- Note that an empty list counts as complete, since @all isRecord []@
-- holds.
complete :: [Match] -> Bool
complete xs =
  case sumColumn of
    Just (all_cs, xs_cs) -> all (`elem` xs_cs) (M.keys all_cs)
    Nothing ->
      (any (isBool True) xs && any (isBool False) xs)
        || all isRecord xs
        || all isTuple xs
  where
    -- The constructors of the column's sum type, together with the
    -- constructor names in the column; Nothing if either is unavailable.
    sumColumn = do
      x <- maybeHead xs
      Scalar (Sum all_cs) <- Just (matchType x)
      xs_cs <- mapM isConstr xs
      Just (all_cs, xs_cs)

    isBool :: Bool -> Match -> Bool
    isBool b1 (MatchConstr (ConstrLit (PatLitPrim (BoolValue b2))) _ _) = b1 == b2
    isBool _ _ = False

    isRecord :: Match -> Bool
    isRecord m = case m of
      MatchConstr ConstrRecord {} _ _ -> True
      _ -> False

    isTuple :: Match -> Bool
    isTuple m = case m of
      MatchConstr ConstrTuple _ _ -> True
      _ -> False
-- | Specialise a pattern matrix by the head of @c1@.
--
-- Processing stops at the first empty row.  For each earlier row:
--
-- * if its head is a constructor equal to that of @c1@, the head is
--   replaced by its argument patterns;
-- * if its head is a wildcard, the head is replaced by one wildcard per
--   type in @ats@;
-- * otherwise the row is dropped.
specialise :: [StructType] -> Match -> [[Match]] -> [[Match]]
specialise ats c1 = mapMaybe specialiseRow . takeWhile (not . null)
  where
    specialiseRow :: [Match] -> Maybe [Match]
    specialiseRow row = case row of
      c2 : rest -> (++ rest) <$> match c1 c2
      [] -> Nothing

    -- The patterns that replace the row head @got@, if the row survives.
    match :: Match -> Match -> Maybe [Match]
    match want got = case (want, got) of
      (MatchConstr want_c _ _, MatchConstr got_c args _)
        | want_c == got_c -> Just args
      (_, MatchWild {}) -> Just (map MatchWild ats)
      _ -> Nothing
-- | The default matrix.  Rows whose head is a wildcard are kept with that
-- head removed.  Rows headed by a constructor, and empty rows, are
-- dropped.
defaultMat :: [[Match]] -> [[Match]]
defaultMat rows = [rest | MatchWild {} : rest <- rows]
-- | Find rows of patterns that are matched by no row of the pattern
-- matrix @pmat@, whose rows have @n@ columns.  This follows the
-- usefulness algorithm of Luc Maranget, "Warnings for pattern matching".
--
-- When the matrix runs out of rows, the types of the remaining columns
-- are not known here, so a returned row may lack its trailing columns.
-- Constructor arguments lost this way are padded with wildcards in
-- 'completeCase'.
findUnmatched :: [[Match]] -> Int -> [[Match]]
findUnmatched pmat n
  | ((p : _) : _) <- pmat,
    Just heads <- mapM maybeHead pmat =
      -- Only the constructors in the head column decide completeness, and
      -- only they are specialised on.  Wildcard rows are still carried
      -- along by 'specialise' (and by 'defaultMat').  Letting wildcards
      -- take part loses real missing cases, e.g. (#a,true) | (#b,true) |
      -- (_,true) for a sum type #a | #b.  It also reports spurious ones,
      -- e.g. (true,_) | (false,_) | (_,true).
      let constr_heads = [h | h@MatchConstr {} <- heads]
       in if not (null constr_heads) && complete constr_heads
            then completeCase constr_heads
            else incompleteCase (matchType p) heads
  where
    -- Specialise the matrix by each constructor of a complete head
    -- column, and rebuild the constructor around the unmatched rows
    -- found for its arguments.
    completeCase :: [Match] -> [[Match]]
    completeCase cs = do
      -- cs holds only constructors, so this pattern filters nothing out.
      c@(MatchConstr c' args t) <- cs
      let ats = map matchType args
          a_k = length ats
      u <- findUnmatched (specialise ats c pmat) (a_k + n - 1)
      let (r, rest) = splitAt a_k u
          -- A row coming from an exhausted matrix may be truncated.  Fill
          -- in the missing trailing arguments with wildcards of the right
          -- type so the example has the constructor's arity.
          r' = r ++ map MatchWild (drop (length r) ats)
      pure $ MatchConstr c' r' t : rest

    -- The head column does not cover its type: recurse on the default
    -- matrix and prefix the results with what the column is missing.
    incompleteCase :: StructType -> [Match] -> [[Match]]
    incompleteCase pt cs = do
      u <- findUnmatched (defaultMat pmat) (n - 1)
      if null cs
        then pure $ MatchWild pt : u
        else case pt of
          Scalar (Sum all_cs) -> do
            -- Report each constructor of the type absent from the column.
            let sigma = mapMaybe isConstr cs
                notCovered (k, _) = k `notElem` sigma
            (cname, ts) <- filter notCovered $ M.toList all_cs
            pure $ MatchConstr (Constr cname) (map MatchWild ts) pt : u
          _ ->
            -- Missing values of other types are not enumerated; a
            -- wildcard stands for them.
            pure $ MatchWild pt : u
-- No rows left: a single (possibly truncated) unmatched row.
findUnmatched [] _ = [[]]
-- Some row has no columns, and it matches everything.
findUnmatched _ _ = []
{-# NOINLINE unmatched #-}

-- | Find example patterns that none of the given patterns match.
unmatched :: [Pat] -> [Match]
unmatched orig_ps = nubOrd examples
  where
    -- One single-column row per original pattern.
    matrix = [[patternToMatch p] | p <- orig_ps]
    -- Keep the first column of each unmatched row.  Headless rows are
    -- dropped by 'mapMaybe'.  Different branches of the search may find
    -- the same example, hence the 'nubOrd' above.
    examples = mapMaybe maybeHead (findUnmatched matrix 1)