module Language.Haskell.FreeTheorems.Frontend (
Checked
, Parsed
, runChecks
, check
, checkAgainst
) where
import Data.List (intersect, partition)
import Data.Maybe (mapMaybe)

import Data.Generics (everything, extQ, mkQ)

import Language.Haskell.FreeTheorems.Syntax
import Language.Haskell.FreeTheorems.ValidSyntax (ValidDeclaration (..))
import Language.Haskell.FreeTheorems.Frontend.Error (Checked, Parsed, runChecks)
import Language.Haskell.FreeTheorems.Frontend.TypeExpressions
  (replaceAllTypeSynonyms, closeTypeExpressions)
import Language.Haskell.FreeTheorems.Frontend.CheckLocal
  (checkLocal, checkDataAndNewtypeDeclarations)
import Language.Haskell.FreeTheorems.Frontend.CheckGlobal (checkGlobal)
-- | Checks a list of parsed declarations with no previously validated
--   declarations in scope.
--
--   This is exactly 'checkAgainst' applied to an empty list of valid
--   declarations.
check :: [Declaration] -> Checked [ValidDeclaration]
check = checkAgainst []
-- | Checks a list of parsed declarations against a list of already
--   validated declarations and converts the result into valid declarations.
--
--   The stages run in this order, each in the 'Checked' monad:
--
--   1. 'checkLocal' is applied to the new declarations.
--
--   2. 'checkGlobal' is applied with the previously validated declarations.
--
--   3. Every type synonym ('TypeDecl') found either among the previously
--      validated declarations or among the declarations that survived
--      step 2 is expanded in those surviving declarations
--      ('replaceAllTypeSynonyms').
--
--   4. 'checkDataAndNewtypeDeclarations' is applied to the expanded
--      declarations.
--
--   5. 'closeTypeExpressions' is applied, and 'makeValid' turns the result
--      into 'ValidDeclaration's.
checkAgainst ::
     [ValidDeclaration]          -- ^ previously validated declarations
  -> [Declaration]               -- ^ declarations to be checked
  -> Checked [ValidDeclaration]
checkAgainst vds ds = do
  locallyChecked  <- checkLocal ds
  globallyChecked <- checkGlobal vds locallyChecked
  dataChecked     <- checkDataAndNewtypeDeclarations
                       (expandTypeSynonyms globallyChecked)
  return (makeValid vds (closeTypeExpressions dataChecked))
  where
    -- Synonyms may come from the already validated declarations as well as
    -- from the new ones, so both sources are consulted.
    expandTypeSynonyms :: [Declaration] -> [Declaration]
    expandTypeSynonyms decls =
      replaceAllTypeSynonyms
        (mapMaybe getTypeSyn (map rawDeclaration vds ++ decls))
        decls

    getTypeSyn :: Declaration -> Maybe TypeDeclaration
    getTypeSyn (TypeDecl t) = Just t
    getTypeSyn _            = Nothing
-- | Wraps each checked declaration into a 'ValidDeclaration', together with
--   a flag saying whether the declaration is considered strict (presumably
--   the value later read back via 'isStrictDeclaration').
--
--   A declaration from the second argument is marked strict when its name
--   ends up in the strict set, which is built as follows:
--
--   * it starts with the names of previously validated declarations that
--     are already marked strict, plus the names of new declarations that
--     contain a 'Banged' type expression anywhere inside them;
--
--   * then any new declaration that mentions a strict name, either as a
--     type constructor ('Con') or a type class ('TC'), is added. This is
--     repeated until no further declarations are added.
makeValid :: [ValidDeclaration] -> [Declaration] -> [ValidDeclaration]
makeValid vds ds =
  [ ValidDeclaration d (getDeclarationName d `elem` allStrict) | d <- ds ]
  where
    -- Names that are strict without looking at any dependencies.
    knownStrict :: [Identifier]
    knownStrict =
      map getDeclarationName
        (map rawDeclaration (filter isStrictDeclaration vds)
           ++ filter hasStrictnessFlags ds)

    allStrict :: [Identifier]
    allStrict = closeStrict knownStrict ds

    -- Fixed-point iteration: move every pending declaration that refers to
    -- a strict name into the strict set, and stop once a round adds nothing.
    -- (Formerly called 'rec', which is a keyword under RecursiveDo/Arrows.)
    closeStrict :: [Identifier] -> [Declaration] -> [Identifier]
    closeStrict strictNames pending
      | null newlyStrict = strictNames
      | otherwise        =
          closeStrict (strictNames ++ map getDeclarationName newlyStrict) rest
      where
        (newlyStrict, rest) =
          partition (dependsOnStrictTypes strictNames) pending

    -- True if a banged type expression occurs anywhere inside the
    -- declaration.
    hasStrictnessFlags :: Declaration -> Bool
    hasStrictnessFlags d = everything (||) (False `mkQ` hasBang) d
      where
        hasBang :: BangTypeExpression -> Bool
        hasBang (Banged _)   = True
        hasBang (Unbanged _) = False

    -- True if the declaration mentions one of the given names as a type
    -- constructor or a type class.
    dependsOnStrictTypes :: [Identifier] -> Declaration -> Bool
    dependsOnStrictTypes strictNames d =
      not (null (referencedNames d `intersect` strictNames))

    referencedNames :: Declaration -> [Identifier]
    referencedNames d =
      everything (++) (([] `mkQ` getCons) `extQ` getClasses) d
      where
        getCons :: TypeConstructor -> [Identifier]
        getCons (Con n) = [n]
        getCons _       = []

        getClasses :: TypeClass -> [Identifier]
        getClasses (TC n) = [n]