{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict            #-}
module Tokstyle.C.Linter.SizeArg (analyse) where

import           Data.Functor.Identity           (Identity)
import qualified Data.List                       as List
import qualified Data.Map                        as Map
import           Language.C.Analysis.AstAnalysis (ExprSide (..), defaultMD,
                                                  tExpr)
import           Language.C.Analysis.ConstEval   (constEval, intValue)
import           Language.C.Analysis.SemError    (invalidAST, typeMismatch)
import           Language.C.Analysis.SemRep      (GlobalDecls, ParamDecl (..),
                                                  Type (..))
import           Language.C.Analysis.TravMonad   (Trav, TravT, catchTravError,
                                                  recordError, throwTravError)
import           Language.C.Analysis.TypeUtils   (canonicalType)
import           Language.C.Data.Ident           (Ident (..))
import           Language.C.Pretty               (pretty)
import           Language.C.Syntax.AST           (CExpr, CExpression (..),
                                                  annotation)
import           Tokstyle.C.Env                  (Env)
import           Tokstyle.C.Patterns
import           Tokstyle.C.TraverseAst          (AstActions (..), astActions,
                                                  traverseAst)


-- | Check constant size arguments against the fixed size of the array
-- argument that immediately precedes them in a call.
--
-- The list holds one @(declared parameter, argument expression, argument
-- type)@ triple per call argument. Whenever an argument whose type matches
-- the 'ArrayTypeSize' pattern is directly followed by a parameter (matched by
-- 'ParamName') whose name contains @"size"@ or @"len"@, both the array size
-- expression and the size argument are constant-evaluated. If both evaluate
-- to integer constants and the size argument is strictly greater than the
-- array size, a 'typeMismatch' error is recorded at the size argument. The
-- rest of the argument list is always checked afterwards.
checkArraySizes :: Ident -> [(ParamDecl, CExpr, Type)] -> Trav Env ()
checkArraySizes funId ((_, _, arrTy@(ArrayTypeSize arrSize)):(ParamName sizeParam, sizeArg, sizeTy):args)
    | any (`List.isInfixOf` sizeParam) ["size", "len"] = do
        -- Ignore any name lookup errors here. VLAs have locally defined
        -- array sizes, but we don't check VLAs. The handler only covers this
        -- one array/size pair; the recursive call below is deliberately
        -- outside it so that a failed evaluation here does not silently skip
        -- checking the remaining arguments of the call.
        catchTravError checkPair (const $ return ())
        checkArraySizes funId args
  where
    -- Evaluate both sizes and record an error if the constant size argument
    -- exceeds the constant array size.
    checkPair :: Trav Env ()
    checkPair = do
        arrSizeVal <- intValue <$> constEval defaultMD Map.empty arrSize
        sizeArgVal <- intValue <$> constEval defaultMD Map.empty sizeArg
        case (arrSizeVal, sizeArgVal) of
            (Just arrSizeConst, Just sizeArgConst)
                | arrSizeConst < sizeArgConst ->
                    let annot = (annotation sizeArg, sizeTy) in
                    recordError $ typeMismatch (
                        "size parameter `" <> sizeParam <> "` is passed constant value `"
                        <> show (pretty sizeArg) <> "` (= " <> show sizeArgConst <> "),\n"
                        <> "  which is greater than the array size of `" <> show (pretty arrTy) <> "`,\n"
                        <> "  potentially causing buffer overrun in `" <> show (pretty funId) <> "`") annot annot
            -- Not constant, or the size argument is less than or equal to
            -- the array size: nothing to report.
            _ -> return ()
checkArraySizes funId (_:xs) = checkArraySizes funId xs
checkArraySizes _ [] = return ()


-- | AST actions for this linter. Every function call whose callee is a plain
-- name ('CVar') has its callee type computed; when that type matches
-- 'FunPtrParams', the declared parameters are zipped with the arguments and
-- their canonicalised types and handed to 'checkArraySizes' before the
-- traversal continues. A callee type that does not match 'FunPtrParams' is
-- reported via 'throwTravError' as an invalid AST. All other expressions are
-- passed straight to the continuation.
linter :: AstActions (TravT Env Identity)
linter = astActions { doExpr = inspectExpr }
  where
    inspectExpr :: CExpr -> Trav Env () -> Trav Env ()
    inspectExpr node@(CCall callee@(CVar calleeId _) callArgs _) next = do
        calleeTy <- tExpr [] RValue callee
        case calleeTy of
            FunPtrParams declParams -> do
                argTys <- mapM (fmap canonicalType . tExpr [] RValue) callArgs
                checkArraySizes calleeId (zip3 declParams callArgs argTys)
                next
            unexpected ->
                throwTravError $ invalidAST (annotation node) $ show unexpected
    -- Anything that is not a call through a named function is left alone.
    inspectExpr _ next = next


-- | Entry point of the size-argument linter: traverse the given global
-- declarations with 'linter', recording any errors it finds in the
-- traversal monad.
analyse :: GlobalDecls -> Trav Env ()
analyse decls = traverseAst linter decls