{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict            #-}
module Tokstyle.C.Linter.Cast (analyse) where

import           Control.Monad                   (unless, zipWithM_)
import           Data.Functor.Identity           (Identity)
import qualified Data.Map                        as Map
import           Language.C.Analysis.AstAnalysis (ExprSide (..), defaultMD,
                                                  tExpr)
import           Language.C.Analysis.ConstEval   (constEval, intValue)
import           Language.C.Analysis.DefTable    (lookupTag)
import           Language.C.Analysis.SemError    (typeMismatch)
import           Language.C.Analysis.SemRep      (EnumType (..),
                                                  EnumTypeRef (..),
                                                  Enumerator (..), GlobalDecls,
                                                  TagDef (..), Type (..),
                                                  TypeName (..))
import           Language.C.Analysis.TravMonad   (MonadTrav, Trav, TravT,
                                                  getDefTable, recordError,
                                                  throwTravError)
import           Language.C.Analysis.TypeUtils   (canonicalType)
import           Language.C.Data.Error           (userErr)
import           Language.C.Data.Ident           (Ident (..))
import           Language.C.Pretty               (pretty)
import           Language.C.Syntax.AST           (CConstant (..), CExpr,
                                                  CExpression (..), annotation)
import           Language.C.Syntax.Constants     (CInteger (..))
import qualified Tokstyle.C.Env                  as Env
import           Tokstyle.C.Env                  (Env)
import           Tokstyle.C.Patterns
import           Tokstyle.C.TraverseAst          (AstActions (..), astActions,
                                                  traverseAst)
import           Tokstyle.C.TravUtils            (getJust)


-- | Check that two enumerators, paired up by position, have the same
-- constant value.
--
-- Each enumerator's value expression is folded with 'constEval' (default
-- machine description, no extra bindings) and converted with 'intValue'.
-- 'getJust' is handed a generic message for the case where no integer value
-- comes out. If the two values differ, a 'typeMismatch' carrying both
-- enumerators' locations and their enclosing types is thrown with
-- 'throwTravError'.
sameEnum :: MonadTrav m => Type -> Type -> (Ident, CExpr) -> (Ident, CExpr) -> m ()
sameEnum leftTy rightTy (leftId, leftExpr) (rightId, rightExpr) = do
    lhs <- valueOf leftExpr
    rhs <- valueOf rightExpr
    if lhs == rhs
        then pure ()
        else throwTravError $ typeMismatch
            ("invalid cast: enumerator value for `" <> describe leftId lhs
                <> "` does not match `" <> describe rightId rhs <> "`")
            (annotation leftExpr, leftTy)
            (annotation rightExpr, rightTy)
  where
    -- Fold an enumerator's value expression down to an integer constant.
    valueOf expr = do
        folded <- constEval defaultMD Map.empty expr
        getJust "invalid cast: could not determine enumerator values" (intValue folded)

    -- Render "name = value" for the mismatch message.
    describe ident val = show (pretty ident) <> " = " <> show val

-- | Check a cast between two enum types.
--
-- Both types are canonicalised with 'canonicalType' and their enumerator
-- lists fetched with 'enumerators'. The lists must be the same length;
-- otherwise a user error is thrown with 'throwTravError'. Enumerators are
-- then compared pairwise, by declaration position, with 'sameEnum'.
-- The cast operand (third argument) is not inspected.
checkEnumCast :: MonadTrav m => Type -> Type -> CExpr -> m ()
checkEnumCast castTy exprTy _ = do
    castEnums <- enumerators (canonicalType castTy)
    exprEnums <- enumerators (canonicalType exprTy)
    -- zipWithM_ would silently ignore surplus enumerators, so reject
    -- differing lengths up front.
    unless (length castEnums == length exprEnums) $
        throwTravError $ userErr $
            "enum types `" <> show (pretty castTy) <> "` and `"
            <> show (pretty exprTy) <> "` have a different number of enumerators"
    zipWithM_ (sameEnum castTy exprTy) castEnums exprEnums

-- | Return the enumerators of an enum type as (name, value expression)
-- pairs, in the order stored in the enum definition.
--
-- Only a 'DirectType' referring to an enum tag is accepted. The tag is
-- resolved in the current definition table. Any other type, or a tag that
-- does not resolve to an 'EnumDef', is reported via 'throwTravError'.
enumerators :: MonadTrav m => Type -> m [(Ident, CExpr)]
enumerators ty = case ty of
    DirectType (TyEnum (EnumTypeRef name _)) _ _ -> do
        tag <- lookupTag name <$> getDefTable
        case tag of
            Just (Right (EnumDef (EnumType _ enums _ _))) ->
                pure [ (ident, expr) | Enumerator ident expr _ _ <- enums ]
            _ ->
                throwTravError . userErr $
                    "couldn't find enum type `" <> show (pretty name) <> "`"
    _ ->
        throwTravError . userErr $
            "invalid enum type `" <> show (pretty ty) <> "`"


-- | Decide whether casting the operand @e@ from type @exprTy@ to the target
-- type @castTy@ is acceptable.
--
-- Equations are tried top to bottom, so their order matters. Accepted casts
-- return @()@, and enum-to-enum casts are delegated to 'checkEnumCast'.
-- Anything not matched by an earlier equation is reported as a
-- 'typeMismatch' located at @e@. That report uses 'recordError' (rather
-- than being thrown).
checkCast :: MonadTrav m => Type -> Type -> CExpr -> m ()
-- Cast to void: OK.
checkCast (DirectType TyVoid _ _) _ _ = return ()
-- Casting between `void*` and `T*`: OK
checkCast PtrType{} TY_void_ptr _ = return ()
checkCast TY_void_ptr PtrType{} _ = return ()
-- Casting between `char*` and `uint8_t*`: OK
checkCast TY_uint8_t_ptr TY_char_ptr    _ = return ()
checkCast TY_uint8_t_ptr TY_char_arr    _ = return ()
checkCast TY_char_ptr    TY_uint8_t_ptr _ = return ()
checkCast TY_char_ptr    TY_uint8_t_arr _ = return ()
-- Casting literal 0 to `T*`: OK
checkCast PtrType{} _ (CConst (CIntConst (CInteger 0 _ _) _)) = return ()
-- Casting sockaddr_storage to any of the sockaddr_... types: OK
checkCast TY_sockaddr_ptr     TY_sockaddr_storage_ptr _ = return ()
checkCast TY_sockaddr_in_ptr  TY_sockaddr_storage_ptr _ = return ()
checkCast TY_sockaddr_in6_ptr TY_sockaddr_storage_ptr _ = return ()
-- Casting between numeric types: OK
checkCast castTy exprTy _ | isNumeric castTy && isNumeric exprTy = return ()
-- Casting from enum to int: OK
checkCast castTy exprTy _ | isIntegral castTy && isEnum exprTy = return ()
-- Casting between enums: check whether they have the same enumerators.
checkCast castTy exprTy e | isEnum castTy && isEnum exprTy = checkEnumCast castTy exprTy e
-- Casting to `Messenger**`: NOT OK, but toxav does this.
-- TODO(iphydf): Fix this.
checkCast (PtrType (PtrType (TY_typedef "Messenger") _ _) _ _) _ _ = return ()
-- Casting to `void**`: probably not ok, but toxav also does this.
-- TODO(iphydf): Investigate.
checkCast (PtrType TY_void_ptr _ _) _ _ = return ()
-- Casting from int to enum: actually NOT OK, but we do this a lot, so meh.
-- TODO(iphydf): Fix these.
checkCast castTy exprTy _ | isEnum castTy && isIntegral exprTy = return ()

-- Any other casts: NOT OK
checkCast castTy exprTy e =
    recordError $ typeMismatch
        ("disallowed cast from " <> show (pretty exprTy) <> " to " <> show (pretty castTy))
        -- Target side first, source side second, as 'sameEnum' does, so the
        -- mismatch record keeps both types instead of castTy twice.
        (annotation e, castTy)
        (annotation e, exprTy)


-- | Call contexts inside which casts are not checked at all, because weird
-- casts like @int*@ -> @char*@ may legitimately happen there.
--
-- Entries have the form @"call:<function name>"@, the same shape as the
-- context 'linter' pushes when it enters a call to a named function.
exemptions :: [String]
exemptions =
    [ "call:getsockopt"
    , "call:setsockopt"
    , "call:bs_list_add"
    , "call:bs_list_remove"
    , "call:bs_list_find"
    , "call:random_bytes"
    , "call:randombytes"
    ]


-- | AST actions implementing the cast lint.
--
-- For every cast expression @(T)e@, the linter computes the type of the
-- whole cast and the type of the operand @e@ (both as r-values) and passes
-- them to 'checkCast'. The check is skipped when the first entry of the
-- context stack is listed in 'exemptions'. Calls to a named function push
-- a @"call:<name>"@ context before running @act@ and pop it afterwards.
linter :: AstActions (TravT Env Identity)
linter = astActions
    { doExpr = \node act -> case node of
        cast@(CCast _ e _) -> do
            castTy <- tExpr [] RValue cast
            exprTy <- tExpr [] RValue e
            ctx <- Env.getCtx
            unless (isExempt ctx) $
                checkCast castTy exprTy e
            act

        CCall (CVar (Ident fname _ _) _) _ _ -> do
            Env.pushCtx $ "call:" <> fname
            act
            Env.popCtx

        _ -> act
    }
  where
    -- Only the first (presumably most recently pushed) context is checked.
    -- An empty stack is treated as "not exempt" instead of crashing in
    -- 'head'.
    isExempt (current:_) = current `elem` exemptions
    isExempt []          = False


-- | Entry point: traverse all global declarations with the cast 'linter'.
analyse :: GlobalDecls -> Trav Env ()
analyse decls = traverseAst linter decls