-- | Defines functions to ensure that only valid declarations and type 
--   signatures are fed to the FreeTheorems library. The given functions are
--   intended as second stage after parsing declarations.

module Language.Haskell.FreeTheorems.Frontend (
    Checked
  , Parsed
  , runChecks
  , check
  , checkAgainst
) where



import Data.Generics (everything, extQ, mkQ)
import Data.List (partition, intersect)
import Data.Maybe (mapMaybe)

import Language.Haskell.FreeTheorems.Syntax
import Language.Haskell.FreeTheorems.ValidSyntax (ValidDeclaration (..))
import Language.Haskell.FreeTheorems.Frontend.Error (Checked, Parsed, runChecks)
import Language.Haskell.FreeTheorems.Frontend.TypeExpressions
    (replaceAllTypeSynonyms, closeTypeExpressions)
import Language.Haskell.FreeTheorems.Frontend.CheckLocal
    (checkLocal, checkDataAndNewtypeDeclarations)
import Language.Haskell.FreeTheorems.Frontend.CheckGlobal (checkGlobal)



-- | Checks a list of declarations.
--   It returns a list of all declarations which are valid and an error message
--   for all those declarations which are not valid.
--
--   This is 'checkAgainst' with no previously validated declarations.
check :: [Declaration] -> Checked [ValidDeclaration]
check = checkAgainst []



-- | Checks a list of declarations against a given list of valid
--   declarations.
--   It returns a list of all declarations from the second argument which are
--   valid. Moreover, the result contains an error message for all those
--   declarations which are not valid.
--
--   The declarations given in the second argument may be based on those of the
--   first argument. For example, if the first argument contains a valid
--   declaration of a type \"Foo\" and if the second argument contains the
--   following declaration
--
--   > type Bar = Foo
--
--   then also the declaration of \"Bar\" is valid.
--
--   The checks run as a pipeline in the 'Checked' monad: each stage receives
--   the declarations returned by the previous stage.
checkAgainst ::
    [ValidDeclaration]
    -> [Declaration]
    -> Checked [ValidDeclaration]
checkAgainst vds ds = do
    -- perform local checks:
    --   * free variables of the right-hand side are declared on the left-hand
    --     of declarations
    --   * type variables of the left-hand side are pairwise distinct
    --   * primitive types are not declared
    --   * FixedTypeExpression does not occur anywhere
    --   * type synonyms are not recursive
    --   * data and newtype are not nested
    --   * classes methods are pairwise distinct, don't use the owning class
    --     and have the class variable as free variable
    afterLocal <- checkLocal ds

    -- perform global checks:
    --   * at most one declaration per name
    --   * arity checks of type constructors in all type expressions
    --   * type class hierarchy is acyclic
    --   * type synonym declarations are not mutually recursive
    --   * all used constructors and classes are declared
    afterGlobal <- checkGlobal vds afterLocal

    -- replace all type synonyms, use also the valid type synonyms
    let typeSyns = mapMaybe getTypeSyn (map rawDeclaration vds ++ afterGlobal)
        expanded = replaceAllTypeSynonyms typeSyns afterGlobal

    -- checks in data and newtype declarations: no abstractions, no functions
    afterData <- checkDataAndNewtypeDeclarations expanded

    -- finally, close all type signatures and class methods and transform all
    -- declarations to valid ones
    return (makeValid vds (closeTypeExpressions afterData))
  where
    -- Selects the type synonym declarations; every other declaration kind
    -- is ignored.
    getTypeSyn :: Declaration -> Maybe TypeDeclaration
    getTypeSyn (TypeDecl t) = Just t
    getTypeSyn _            = Nothing



-- | Turns a list of declarations into valid declarations.
--   Additionally, every declaration is checked whether it depends on any
--   algebraic data type with strictness flags.
--
--   A declaration of the second argument is marked strict (the 'Bool' field
--   of 'ValidDeclaration') if it contains a banged type expression itself, or
--   if it mentions (as a type constructor or type class name) a declaration
--   already known to be strict. Known-strict names are seeded from the strict
--   declarations of the first argument and those of the second argument with
--   strictness flags. They are then propagated transitively until no further
--   declaration is affected.
makeValid :: [ValidDeclaration] -> [Declaration] -> [ValidDeclaration]
makeValid vds ds =
  let strict :: [Declaration]
      strict = map rawDeclaration (filter isStrictDeclaration vds)

      knownStrict :: [Identifier]
      knownStrict = map getDeclarationName
                        (strict ++ filter hasStrictnessFlags ds)

      -- Moves every declaration that mentions a strict name into the strict
      -- set and repeats on the remaining ones until a round adds nothing.
      propagate :: [Identifier] -> [Declaration] -> [Identifier]
      propagate ss rest =
        let (ns, os) = partition (dependsOnStrictTypes ss) rest
         in if null ns
              then ss
              else propagate (ss ++ map getDeclarationName ns) os

      allStrict :: [Identifier]
      allStrict = propagate knownStrict ds

   in map (\d -> ValidDeclaration d (getDeclarationName d `elem` allStrict)) ds
  where
    -- True if any BangTypeExpression anywhere inside the declaration is
    -- 'Banged'.
    hasStrictnessFlags :: Declaration -> Bool
    hasStrictnessFlags = everything (||) (False `mkQ` hasBang)

    hasBang :: BangTypeExpression -> Bool
    hasBang (Banged _)   = True
    hasBang (Unbanged _) = False

    -- True if the declaration mentions, as a named type constructor or as a
    -- type class, any of the given identifiers.
    dependsOnStrictTypes :: [Identifier] -> Declaration -> Bool
    dependsOnStrictTypes ss d =
      let ns = everything (++) (([] `mkQ` getCons) `extQ` getClasses) d
       in not (null (ns `intersect` ss))

    -- Only user-named constructors ('Con') can refer to a declaration.
    getCons :: TypeConstructor -> [Identifier]
    getCons (Con n) = [n]
    getCons _       = []

    getClasses :: TypeClass -> [Identifier]
    getClasses (TC n) = [n]