-- | Defines global checks, i.e. checks which need to look at more than one
--   declaration at a time.

module Language.Haskell.FreeTheorems.Frontend.CheckGlobal (checkGlobal) where



import Control.Monad (when)
import Control.Monad.Error (throwError)
import Control.Monad.Writer (tell)
import Data.Generics (Typeable, Data, everything, everywhereM, extQ, mkQ, mkM)
import Data.List (intersperse, partition, nub, intersect)
import qualified Data.Map as Map (Map, empty, insert, lookup)
import Data.Maybe (mapMaybe, fromJust)
import qualified Data.Set as Set
    ( Set, empty, singleton, union, fromList, isSubsetOf, member, difference
    , partition, null, elems, size )

import Language.Haskell.FreeTheorems.BasicSyntax
import Language.Haskell.FreeTheorems.ValidSyntax
import Language.Haskell.FreeTheorems.Frontend.Error





------- Global checks ---------------------------------------------------------


-- | Performs global checks, i.e. checks which look at more than one
--   declaration at a time. The following restrictions are checked:
--
--   * Every symbol is declared at most once.
--
--   * Every type constructor is used with the arity it was declared with.
--
--   * Type synonyms are not (mutually) recursive.
--
--   * The type class hierarchy is acyclic.
--
--   * In every type expression, only declared type constructors and only
--     declared type classes occur.
--
--   The first argument holds declarations which were validated before; they
--   serve as context only and are not part of the result. The result contains
--   those declarations of the second argument which pass all checks.

checkGlobal :: [ValidDeclaration] -> [Declaration] -> Checked [Declaration]
checkGlobal vds ds = do
  -- no name may be declared more than once, neither within 'ds' nor in 'ds'
  -- and 'vds' taken together
  unique <- checkUnique vds ds

  -- every type constructor must be applied to as many arguments as declared
  wellApplied <- checkArities vds unique

  -- reject recursive type synonyms and all declarations depending on them
  synonymsOk <- checkAcyclicTypeSynonyms wellApplied

  -- reject type classes which are part of, or based on, a cyclic hierarchy
  classesOk <- checkAcyclicTypeClasses synonymsOk

  -- finally, keep only declarations which mention declared type constructors
  -- and declared type classes exclusively
  checkAllConsAndClassesDeclared vds classesOk





------- Check that declarations are unique ------------------------------------


-- | Checks that every name has at most one declaration, i.e. that there are
--   no two declarations with the same name.
--
--   The first argument gives a list of already checked declarations against
--   which the second argument is tested. The resulting list contains only
--   those declarations of the second argument whose name occurs exactly once
--   among the names of both arguments; declarations of the first argument are
--   not part of the result. One error message is recorded for every name
--   which is declared more than once.

checkUnique :: [ValidDeclaration] -> [Declaration] -> Checked [Declaration]
checkUnique vds ds =
  let -- all known declaration names, both from 'vds' and from 'ds'
      knownNames :: [Identifier]
      knownNames = map getDeclarationName (map rawDeclaration vds ++ ds)

      -- how often every name occurs in 'knownNames'; built once, so that the
      -- uniqueness test below is a map lookup instead of a scan of the list
      nameCounts :: Map.Map Identifier Int
      nameCounts = foldr countName Map.empty knownNames

      countName :: Identifier -> Map.Map Identifier Int -> Map.Map Identifier Int
      countName n m = Map.insert n (maybe 1 (+ 1) (Map.lookup n m)) m

      -- test if the name of a declaration occurs more than once in 'knownNames'
      occursMoreThanOnce :: Declaration -> Bool
      occursMoreThanOnce d =
        maybe False (> 1) (Map.lookup (getDeclarationName d) nameCounts)

      -- 'ms' holds all declarations whose names occur more than once, 'us'
      -- holds all unique declarations
      (ms, us) = partition occursMoreThanOnce ds

      -- the names which occur more than once, each reported only once
      multiples :: [String]
      multiples = map unpackIdent . nub . map getDeclarationName $ ms

      duplicateError s = [pp ("Multiple declarations for `" ++ s ++ "'.")]

   in do -- 'mapM_' over an empty list records nothing, so no guard is needed
         mapM_ (tell . duplicateError) multiples
         return us





------- Check arities of type constructors ------------------------------------


-- | Checks the arity of all type constructors occurring in the declarations
--   of the second argument. The arities are taken from both the already
--   validated declarations (first argument) and the declarations of the
--   second argument. If an undeclared type constructor is found, no arity
--   check is performed for it, because any declaration containing undeclared
--   type constructors is filtered out in a later step of checking (see
--   'checkGlobal'). Every declaration is checked on its own via 'foldChecks'.

checkArities :: [ValidDeclaration] -> [Declaration] -> Checked [Declaration]
checkArities vds ds = foldChecks checkDecl ds
  where
    -- every declaration which has an arity, paired with its name
    arities :: [(Identifier, Int)]
    arities =
      [ (getDeclarationName d, n)
      | d <- map rawDeclaration vds ++ ds
      , Just n <- [getDeclarationArity d]
      ]

    -- 'foldr' inserts earlier list entries last, so they take precedence
    arityMap :: Map.Map Identifier Int
    arityMap = foldr (uncurry Map.insert) Map.empty arities

    -- checks one declaration, attributing any error to that declaration
    checkDecl :: Declaration -> ErrorOr Declaration
    checkDecl d = inDecl d (checkArity arityMap d)



-- | Checks the arities of all type constructors occurring anywhere inside the
--   given value according to the given arity map. Built-in type constructors
--   have fixed arities, a tuple type constructor of size @n@ expects @n@
--   arguments (and @n@ must be at least 2), and user-defined type
--   constructors missing from the map are not checked at all. The first
--   violation encountered is returned as the error.

checkArity :: (Typeable a, Data a) => Map.Map Identifier Int -> a -> ErrorOr a
checkArity arityMap = everywhereM (mkM checkTypeExpression)
  where
    -- printed names and fixed arities of the built-in type constructors
    builtinArity :: TypeConstructor -> Maybe (String, Int)
    builtinArity con = case con of
      ConUnit    -> Just ("()",      0)
      ConList    -> Just ("[]",      1)
      ConInt     -> Just ("Int",     0)
      ConInteger -> Just ("Integer", 0)
      ConFloat   -> Just ("Float",   0)
      ConDouble  -> Just ("Double",  0)
      ConChar    -> Just ("Char",    0)
      _          -> Nothing

    -- checks a single node; 'everywhereM' applies this to every type
    -- expression inside the traversed value
    checkTypeExpression :: TypeExpression -> ErrorOr TypeExpression
    checkTypeExpression t = case t of
      TypeCon (ConTuple n) args -> do
        _ <- expectArity t ("(" ++ replicate (n - 1) ',' ++ ")") n (length args)
        errorIf (n < 2) $
          pp "A tuple type constructor must have at least two arguments."
        return t
      TypeCon (Con c) args ->
        case Map.lookup c arityMap of
          Nothing       -> return t
          Just declared -> expectArity t (unpackIdent c) declared (length args)
      TypeCon con args
        | Just (name, arity) <- builtinArity con ->
            expectArity t name arity (length args)
      _ -> return t

    -- fails with an error message unless 'found' equals 'expected'
    expectArity :: TypeExpression -> String -> Int -> Int -> ErrorOr TypeExpression
    expectArity t conName expected found = do
      errorIf (found /= expected) $
        pp ("Type constructor `" ++ conName ++ "' was declared to have "
            ++ describe expected ++ ", but it is used with " ++ describe found
            ++ ".")
      return t

    -- renders a number of type arguments in English
    describe :: Int -> String
    describe 0 = "no argument"
    describe 1 = "1 argument"
    describe k = show k ++ " arguments"
   




------- Acyclic tests ---------------------------------------------------------


-- | Checks that type synonym declarations are not recursive, neither directly
--   (a synonym referring to itself) nor mutually. Type synonyms involved in a
--   cycle are rejected, together with every declaration which depends on a
--   rejected type synonym. One error message built from the names of all
--   rejected declarations is recorded (see 'checkDependencyGraph').

checkAcyclicTypeSynonyms :: [Declaration] -> Checked [Declaration]
checkAcyclicTypeSynonyms ds =
  let -- the names of all declared type synonyms, as a set for fast lookups
      typeSynonymNames :: Set.Set Identifier
      typeSynonymNames = Set.fromList [ typeName td | TypeDecl td <- ds ]

      -- the type synonym referenced by a single type expression node, if any
      occurringTypeSynonyms :: TypeExpression -> Set.Set Identifier
      occurringTypeSynonyms (TypeCon (Con c) _)
        | c `Set.member` typeSynonymNames = Set.singleton c
      occurringTypeSynonyms _ = Set.empty

      -- all type synonyms a declaration refers to anywhere inside it
      getDependencies :: Declaration -> Set.Set Identifier
      getDependencies =
        everything Set.union (Set.empty `mkQ` occurringTypeSynonyms)

      -- the error message for all unaccepted declarations
      errMsg :: String
      errMsg = "Declarations of type synonyms must not be mutually recursive."

      -- filter all mutually recursive declarations
   in checkDependencyGraph ds getDependencies errMsg "type synonym"



-- | Checks that the type class hierarchy is acyclic. Type classes which are
--   part of a cycle are rejected, together with every type class having a
--   rejected superclass. One error message built from the names of all
--   rejected declarations is recorded (see 'checkDependencyGraph').
--
--   Undeclared type classes occurring as superclasses are ignored. They will
--   be filtered out in the next step (see 'checkGlobal').

checkAcyclicTypeClasses :: [Declaration] -> Checked [Declaration]
checkAcyclicTypeClasses ds =
  let -- the names of all declared type classes, as a set for fast lookups
      classNames :: Set.Set Identifier
      classNames = Set.fromList [ className cd | ClassDecl cd <- ds ]

      -- given a class declaration, this function returns the set of all
      -- superclasses having a known declaration; other declarations have no
      -- dependencies here
      getSuperClasses :: Declaration -> Set.Set Identifier
      getSuperClasses (ClassDecl cd) =
        Set.fromList [ c | TC c <- superClasses cd, c `Set.member` classNames ]
      getSuperClasses _ = Set.empty

      -- the error message for all unaccepted declarations; the hierarchy is
      -- required to be acyclic (the message previously said "must not be
      -- acyclic", stating the opposite)
      errMsg :: String
      errMsg =
        "The type class hierarchy formed by the type classes and their "
        ++ "superclasses must be acyclic."

      -- filter all acyclic type classes
   in checkDependencyGraph ds getSuperClasses errMsg "type class"



-- | Partitions the declarations with 'recursivePartition' using the given
--   dependency function. If some declarations are rejected, a single error
--   message is recorded, consisting of the given message followed by
--   'violating' applied to the given tag and the names of the rejected
--   declarations. Returns the accepted declarations.

checkDependencyGraph :: 
    [Declaration] 
    -> (Declaration -> Set.Set Identifier) 
    -> String
    -> String
    -> Checked [Declaration]

checkDependencyGraph ds getDependencies errMsg tag
  | null rejected = return accepted
  | otherwise     = do
      tell [pp (errMsg ++ violating tag (map rejectedName rejected))]
      return accepted
  where
    (accepted, rejected) = recursivePartition ds getDependencies
    rejectedName         = unpackIdent . getDeclarationName . fst



-- | Partitions a list of declarations using a dependency function, which
--   yields the names of the declarations a given declaration depends on.
--
--   Starting with no accepted declarations, every declaration whose
--   dependencies are all accepted already is accepted, too. This step is
--   repeated until no more declarations are added. Hence, a declaration is
--   rejected if it takes part in a dependency cycle (including depending on
--   itself), if it depends on a rejected declaration, or if it depends on a
--   name not declared in the list.
--
--   The result contains the accepted declarations and every rejected
--   declaration paired with those of its dependencies which were not
--   accepted; both lists keep the order of the input list.
--   This function terminates if the first argument is a finite list.

recursivePartition :: 
    [Declaration] 
    -> (Declaration -> Set.Set Identifier) 
    -> ([Declaration], [(Declaration, Set.Set Identifier)])

recursivePartition decls getDependencies =
  let -- to increase efficiency, calculate the dependencies beforehand and use
      -- the declaration names as keys (declaration names are unique here)
      depMap :: Map.Map Identifier (Set.Set Identifier)
      depMap = foldr addDeps Map.empty decls

      addDeps d = Map.insert (getDeclarationName d) (getDependencies d)

      -- checks if all dependencies of the declaration named 'n' are contained
      -- in 'accepted'; 'n' always stems from 'decls', so the lookup succeeds
      dependsOnlyOn :: Identifier -> Set.Set Identifier -> Bool
      dependsOnlyOn n accepted =
        maybe False (`Set.isSubsetOf` accepted) (Map.lookup n depMap)

      -- moves names from 'remaining' to 'accepted' until nothing moves any
      -- more, then returns the accepted names
      select :: Set.Set Identifier -> Set.Set Identifier -> Set.Set Identifier
      select accepted remaining =
        let (ready, remaining') =
              Set.partition (`dependsOnlyOn` accepted) remaining
         in if Set.null ready
              then accepted
              else select (accepted `Set.union` ready) remaining'

      acceptedNames :: Set.Set Identifier
      acceptedNames =
        select Set.empty (Set.fromList (map getDeclarationName decls))

      -- 'ok' holds the accepted declarations, 'err' all erroneous ones
      (ok, err) =
        partition (\d -> getDeclarationName d `Set.member` acceptedNames) decls

      -- reduce the dependencies of an erroneous declaration to the names of
      -- erroneous declarations only
      getErrDeps :: Declaration -> Set.Set Identifier
      getErrDeps d =
        maybe Set.empty (`Set.difference` acceptedNames)
              (Map.lookup (getDeclarationName d) depMap)

   in (ok, [ (d, getErrDeps d) | d <- err ])





------- Check declared type constructors and classes --------------------------


-- | A declaration name tagged with the kind of entity it names, so that type
--   constructors, type classes and other names can live in one set without
--   being confused with each other.

data Name
  = CON Identifier    -- ^ name of a type synonym, data type or newtype
  | CLA Identifier    -- ^ name of a type class
  | OTH Identifier    -- ^ any other name, i.e. that of a type signature
  deriving (Eq, Ord)


-- | Returns the name of a declaration, tagged with the kind of entity the
--   declaration introduces.

getDeclarationName' :: Declaration -> Name
getDeclarationName' decl = case decl of
  TypeDecl td    -> CON (typeName td)
  DataDecl dd    -> CON (dataName dd)
  NewtypeDecl nd -> CON (newtypeName nd)
  ClassDecl cd   -> CLA (className cd)
  TypeSig sig    -> OTH (signatureName sig)


-- | Strips the kind tag from a 'Name'.

unpackName :: Name -> Identifier
unpackName name = case name of
  CON ident -> ident
  CLA ident -> ident
  OTH ident -> ident



-- | Checks that all declarations depend only on declared type constructors and
--   declared type classes. A name counts as declared if one of the already
--   validated declarations (first argument) or an accepted declaration of the
--   second argument declares it (see 'partitionDeclared'). For every rejected
--   declaration, one error message listing the offending names is recorded.
--   Returns the accepted declarations of the second argument.

checkAllConsAndClassesDeclared :: 
    [ValidDeclaration] -> [Declaration] -> Checked [Declaration]
checkAllConsAndClassesDeclared vds ds = do
    tell (mapMaybe report rejected)
    return accepted
  where
    (accepted, rejected) =
      partitionDeclared ds dependencies (map rawDeclaration vds)

    -- all type class names and all type constructor names occurring anywhere
    -- inside a declaration
    dependencies :: Declaration -> Set.Set Name
    dependencies =
      everything Set.union
                 (const Set.empty `extQ` occurringCon `extQ` occurringClass)

    -- the type constructor name at a single type expression node, if any
    occurringCon :: TypeExpression -> Set.Set Name
    occurringCon (TypeCon (Con c) _) = Set.singleton (CON c)
    occurringCon _                   = Set.empty

    -- the name of a type class
    occurringClass :: TypeClass -> Set.Set Name
    occurringClass (TC c) = Set.singleton (CLA c)

    -- turns a rejected declaration and its offending names into a message
    report (d, names) = getError (undeclared d (Set.elems names))

    undeclared :: Declaration -> [Name] -> ErrorOr ()
    undeclared d names =
      inDecl d . throwError . pp $
        "The following type constructors or type classes are not "
        ++ "declared or their declaration contains errors: "
        ++ concat (intersperse ", " (map (unpackIdent . unpackName) names))
  


-- | Partitions a list of declarations into those which don't rely, directly
--   or indirectly, on undeclared type constructors or type classes, and the
--   rest. A name counts as declared if it is the name of one of the extra
--   declarations (third argument) or of an accepted declaration of the first
--   argument.
--
--   Contrary to 'recursivePartition', this starts by accepting all
--   declarations. It then repeatedly removes those with a dependency which is
--   neither an extra name nor still accepted, until nothing changes. Thus,
--   (mutually) recursive declarations are accepted as long as they do not
--   depend on anything undeclared.
--
--   The result contains the accepted declarations and every rejected
--   declaration paired with those of its dependencies which are neither extra
--   names nor accepted; both lists keep the order of the input list.

partitionDeclared :: 
    [Declaration] 
    -> (Declaration -> Set.Set Name) 
    -> [Declaration]
    -> ([Declaration], [(Declaration, Set.Set Name)])

partitionDeclared decls getDependencies extraDecls =
  let -- to increase efficiency, calculate the dependencies beforehand and use
      -- the declaration names as keys (declaration names are unique here)
      depMap :: Map.Map Name (Set.Set Name)
      depMap = foldr addDeps Map.empty decls

      addDeps d = Map.insert (getDeclarationName' d) (getDependencies d)

      -- the names declared by the extra declarations
      extras :: Set.Set Name
      extras = Set.fromList (map getDeclarationName' extraDecls)

      -- checks if all dependencies of the declaration named 'n' are contained
      -- in 'allowed'; 'n' always stems from 'decls', so the lookup succeeds
      dependsOnlyOn :: Name -> Set.Set Name -> Bool
      dependsOnlyOn n allowed =
        maybe False (`Set.isSubsetOf` allowed) (Map.lookup n depMap)

      -- removes candidates depending on something which is neither an extra
      -- name nor a candidate, until nothing is removed any more; the union is
      -- computed once per round instead of once per candidate
      select :: Set.Set Name -> Set.Set Name
      select candidates =
        let allowed         = extras `Set.union` candidates
            (kept, dropped) = Set.partition (`dependsOnlyOn` allowed) candidates
         in if Set.null dropped then candidates else select kept

      acceptedNames :: Set.Set Name
      acceptedNames = select (Set.fromList (map getDeclarationName' decls))

      -- 'ok' holds the accepted declarations, 'err' all erroneous ones
      (ok, err) =
        partition (\d -> getDeclarationName' d `Set.member` acceptedNames) decls

      -- reduce the dependencies of an erroneous declaration to those names
      -- which are neither extra names nor accepted
      finalAllowed :: Set.Set Name
      finalAllowed = extras `Set.union` acceptedNames

      getErrDeps :: Declaration -> Set.Set Name
      getErrDeps d =
        maybe Set.empty (`Set.difference` finalAllowed)
              (Map.lookup (getDeclarationName' d) depMap)

   in (ok, [ (d, getErrDeps d) | d <- err ])