{- CAO Compiler
Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. -}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-
Module : $Header$
Description : Indistinguishable functions.
Copyright : (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho
License : GPL
Maintainer : Paulo Silva
Stability : experimental
Portability : non-portable
-}
module Language.CAO.Transformation.Indist
( mkIndistFun
, indist
) where
import Control.Applicative
import Data.List
import qualified Data.Map as M
import Data.Set ( Set )
import qualified Data.Set as Set
import Data.Maybe ( catMaybes )
import qualified Data.Traversable as T
import qualified Data.Foldable as F
import Language.CAO.Common.Error
import Language.CAO.Common.Fresh
import Language.CAO.Common.Monad
import Language.CAO.Common.Outputable
import Language.CAO.Common.SrcLoc
import Language.CAO.Common.State
import Language.CAO.Common.Var
import Language.CAO.Syntax
import Language.CAO.Syntax.Utils ( getVars, getLVars, sameKind, fvs, defVar )
import Language.CAO.Analysis.CFG
import Language.CAO.Analysis.SsaBack ( introduceDefs, rmVars )
--------------------------------------------------------------------------------
-- * Indistinguishable functions
--------------------------------------------------------------------------------
-- | Apply countermeasures to two function definitions.
--
-- The two strings are turned into function names with 'mkFunName'. When each
-- name identifies exactly one definition in @cfgs@, both CFGs are rewritten by
-- 'mkIndistCfg' and put back at the positions they had in @cfgs@. Otherwise a
-- warning is issued through 'indistWarn' and @cfgs@ is returned unchanged.
mkIndistFun :: CaoMonad m => String -> String -> [CaoCFG] -> m [CaoCFG]
mkIndistFun (mkFunName -> fn1) (mkFunName -> fn2) cfgs
    | Just ((p1, p2), (cfg1, cfg2), cfgs2) <- mcfgs, valid cfg1, valid cfg2 = do
        (cfg1', cfg2') <- mkIndistCfg (fn1, cfg1) (fn2, cfg2)
        return $ insertPos [(p1, cfg1'), (p2, cfg2')] cfgs2
    | otherwise = indistWarn fn1 fn2 >> return cfgs
  where
    mcfgs :: Maybe ((Int, Int), (CaoCFG, CaoCFG), [CaoCFG])
    mcfgs = do
        (p1, cfg1, cfgs')  <- lookupDef fn1 cfgs
        -- Looking fn2 up in the remainder rejects the case fn1 == fn2 and
        -- yields the list without both definitions.
        (_, cfg2, cfgs'') <- lookupDef fn2 cfgs'
        -- The index of fn2 must be relative to the complete list: an index
        -- into cfgs' ignores the slot freed by fn1, which made 'insertPos'
        -- put the rewritten definitions back in the wrong order.
        (p2, _, _) <- lookupDef fn2 cfgs
        return ((p1, p2), (cfg1, cfg2), cfgs'')
    -- TODO: stub
    valid _ = True
-- | Make the bodies of two function CFGs indistinguishable.
--
-- Each CFG (after 'removeSsaDecl') must consist of a single block between the
-- entry and the exit node, as recognised by 'innerNode'. The two blocks are
-- rewritten with 'indist' and stored back in the original CFGs, which are then
-- passed through 'rmVars' and 'introduceDefs' and finally through
-- 'mkIndistDecls'. If either CFG does not have that shape a warning is issued
-- and both CFGs are returned untouched.
mkIndistCfg :: CaoMonad m => (Name, CaoCFG) -> (Name, CaoCFG) -> m (CaoCFG, CaoCFG)
mkIndistCfg (name1, cfg1) (name2, cfg2) =
    case (,) <$> body blocks1 <*> body blocks2 of
        Nothing -> indistWarn name1 name2 >> return (cfg1, cfg2)
        Just ((n1, b1, c1), (n2, b2, c2)) -> do
            (b1', b2') <- indist b1 b2
            let rebuild cfg n b c bs =
                    introduceDefs $ rmVars $ cfg { blocks = M.insert n (b, c) bs }
            mkIndistDecls (rebuild cfg1 n1 b1' c1 blocks1)
                          (rebuild cfg2 n2 b2' c2 blocks2)
  where
    blocks1 = blocks $ removeSsaDecl cfg1
    blocks2 = blocks $ removeSsaDecl cfg2
    -- The only block between the entry node and the exit node, if any.
    body = innerNode entryNode [exitNode]
-- | Balance the variable declarations of two function CFGs.
--
-- When both CFGs consist of a single block between the entry and the exit node
-- (see 'innerNode'), the two blocks are processed by 'indistDecls' and stored
-- back. Otherwise the CFGs are returned unchanged.
mkIndistDecls :: CaoMonad m => CaoCFG -> CaoCFG -> m (CaoCFG, CaoCFG)
mkIndistDecls cfg1 cfg2 = maybe (return (cfg1, cfg2)) balance inner
  where
    blocks1 = blocks cfg1
    blocks2 = blocks cfg2

    inner :: Maybe ((NodeId, BasicBlock, Connections), (NodeId, BasicBlock, Connections))
    inner = do
        node1 <- innerNode entryNode [exitNode] blocks1
        node2 <- innerNode entryNode [exitNode] blocks2
        return (node1, node2)

    balance ((n1, b1, c1), (n2, b2, c2)) = do
        (b1', b2') <- indistDecls b1 b2
        return ( cfg1 { blocks = M.insert n1 (b1', c1) blocks1 }
               , cfg2 { blocks = M.insert n2 (b2', c2) blocks2 }
               )
-- | Give two basic blocks the same number of variable declarations.
--
-- Pre: all operations are already "indistinguishable".
--
-- The declarations of each block are moved to the front of the block. The
-- block with fewer declarations is padded with dummy declarations created (by
-- 'dummyDecl') from the surplus declarations of the other block.
indistDecls :: CaoMonad m => BasicBlock -> BasicBlock -> m (BasicBlock, BasicBlock)
indistDecls b1 b2 = do
    (decls1', decls2') <- case compare count1 count2 of
        EQ -> return (decls1, decls2)
        GT -> do
            padding <- mapM dummyDecl (drop count2 decls1)
            return (decls1, decls2 ++ padding)
        LT -> do
            padding <- mapM dummyDecl (drop count1 decls2)
            return (decls1 ++ padding, decls2)
    return (decls1' ++ others1, decls2' ++ others2)
  where
    (decls1, others1) = partition isDecl b1
    (decls2, others2) = partition isDecl b2
    count1 = length decls1
    count2 = length decls2
    isDecl stmt = case stmt of
        L _ (VDecl _) -> True
        _             -> False
-- | Create a dummy declaration mirroring a variable declaration: every
-- variable of the declaration is replaced by a fresh local variable of the
-- same type.
--
-- Calling it with anything other than a 'VDecl' statement is a programming
-- error and aborts with a message showing the offending statement.
dummyDecl :: CaoMonad m => LStmt Var -> m (LStmt Var)
dummyDecl (unLoc -> VDecl vd)
    = genLoc . VDecl <$> T.mapM (freshVar Local . varType) vd
dummyDecl s
    = error $ "Language.CAO.Transformation.Indist.dummyDecl: failed to create \
              \a dummy declaration for this kind of statement: " ++ showPpr s
-- | Find the block that sits directly after node @e@.
--
-- Succeeds when @e@ is connected to exactly one node @n@ and the connections
-- of @n@ are equal to @next@; the result is @n@ with its block and
-- connections.
innerNode :: NodeId -> [NodeId] -> M.Map NodeId (BasicBlock, Connections)
          -> Maybe (NodeId, BasicBlock, Connections)
innerNode e next m =
    case M.lookup e m of                       -- entry
        Just (_, [n]) ->
            case M.lookup n m of               -- inner
                -- connections are OK, TODO:ordering
                Just (b, rest) | rest == next -> Just (n, b, rest)
                _                             -> Nothing
        _ -> Nothing
-- | Look for the CFG whose definition defines exactly the name @n@.
--
-- Succeeds only when a single CFG matches; the result holds its position in
-- @cfgs@, the CFG itself and the CFGs that did not match.
lookupDef :: Name -> [CaoCFG] -> Maybe (Int, CaoCFG, [CaoCFG])
lookupDef n cfgs =
    case partitionPos hasName cfgs of
        ([(i, cfg)], others) -> Just (i, cfg, others)
        _                    -> Nothing
  where
    hasName cfg = map varName (defVar (definition cfg)) == [n]
-- | Like 'partition', but the elements satisfying the predicate are paired
-- with their (0-based) position in the input list.
--
-- Both result lists keep the relative order of the input. (The previous
-- accumulator-based version returned both lists reversed.)
partitionPos :: (a -> Bool) -> [a] -> ([(Int, a)], [a])
partitionPos f lst = foldr classify ([], []) (zip [0 ..] lst)
  where
    -- The lazy pattern keeps the fold as lazy as 'partition' itself.
    classify (i, x) ~(ys, ns)
        | f x       = ((i, x) : ys, ns)
        | otherwise = (ys, x : ns)
-- | Insert elements at the given positions. The insertions are performed
-- one after the other, in ascending order of position.
insertPos :: [(Int, a)] -> [a] -> [a]
insertPos lst xs = foldl' place xs ordered
  where
    ordered = sortBy (\a b -> fst a `compare` fst b) lst
    place acc (i, x) = insertAt i x acc
-- | Insert an element at the given (0-based) position. If the list ends before
-- the position is reached the element is appended.
insertAt :: Int -> a -> [a] -> [a]
insertAt n x lst
    | n == 0    = x : lst
    | otherwise = case lst of
        []     -> [x]
        y : ys -> y : insertAt (n - 1) x ys
-- | Warn (at the default source location) that two functions could not be
-- made indistinguishable.
indistWarn :: CaoMonad m => Name -> Name -> m ()
indistWarn v1 v2 = caoWarning defSrcLoc (IndistFail v1 v2)
-- | Turn two CFG basic blocks into indistinguishable ones.
--
-- Both blocks are converted to statement dependency graphs, which are then
-- handed to 'mkIndist'.
--
-- Notes: (b1', b2') <- b1 `indist` b2
indist :: CaoMonad m => BasicBlock -> BasicBlock -> m (BasicBlock, BasicBlock)
indist b1 b2 = mkIndist graph1 graph2
  where
    graph1 = mkStmtGraph b1
    graph2 = mkStmtGraph b2
-- | Algorithm for indistinguishable functions.
--
-- Starts the search ('doMkSTree') from a node with no emitted statements and
-- both graphs untouched, and returns the pair of blocks of the cheapest
-- solution found (the first one in case of a tie).
--
-- Aborts with a descriptive error when the search produces no solution at
-- all; previously this case surfaced as an anonymous irrefutable-pattern
-- failure.
--
-- TODO: check best place for dummy ops
mkIndist :: CaoMonad m => StmtGraph -> StmtGraph
         -> m (BasicBlock, BasicBlock)
mkIndist g1 g2 = do
    tr <- doMkSTree [SN { cost  = 0
                        , stmt1 = []
                        , stmt2 = []
                        , rest1 = g1
                        , rest2 = g2
                        }]
    case sortBy (\(c1, _, _) (c2, _, _) -> compare c1 c2) tr of
        (_, b1, b2) : _ -> return (b1, b2)
        []              -> error $ "Language.CAO.Transformation.Indist.mkIndist: "
                                ++ "the search produced no candidate solution"
--------------------------------------------------------------------------------
-- ** Solution
--------------------------------------------------------------------------------
-- | A node of the search performed by 'doMkSTree': a partial solution.
data SNode = SN { cost :: Int          -- ^ Cost of the dummy operations added so far
                , stmt1 :: BasicBlock  -- ^ Statements already placed in the first block, most recent first
                , stmt2 :: BasicBlock  -- ^ Statements already placed in the second block, most recent first
                , rest1 :: StmtGraph   -- ^ Statements of the first block still to be placed
                , rest2 :: StmtGraph   -- ^ Statements of the second block still to be placed
                }
-- | Ranking value of a search node: the cost accumulated so far plus the
-- weight of both remaining graphs.
fCost :: SNode -> Int
fCost SN { cost = spent, rest1 = g1, rest2 = g2 } = spent + fDist g1 + fDist g2
-- | Order search nodes by their 'fCost'.
cmpNd :: SNode -> SNode -> Ordering
cmpNd a b = fCost a `compare` fCost b
{-
Not used but can be useful in the future
nextNode :: SNode -> SNode -> SNode
nextNode (SN sc b1 b2 _ _) (SN sc2 s1 s2 g1' g2')
= SN (sc + sc2) (s1 ++ b1) (s2 ++ b2) g1' g2'
-}
-- | Search for complete solutions.
--
-- If the first node has no statements left in either graph it is reported as
-- a solution (cost and both blocks, in program order) and the search goes on
-- with the remaining nodes. Otherwise every node is expanded with
-- 'nextNodes' and only the 200 best successors (according to 'cmpNd') are
-- explored further.
doMkSTree :: CaoMonad m => [SNode] -> m [(Int, BasicBlock, BasicBlock)]
doMkSTree nodes = case nodes of
    [] -> return []
    (sn : others)
        | nullG (rest1 sn) && nullG (rest2 sn) -> do
            solutions <- doMkSTree others
            -- statements were accumulated most recent first
            return $ (cost sn, reverse (stmt1 sn), reverse (stmt2 sn)) : solutions
        | otherwise -> do
            successors <- mapM nextNodes nodes
            doMkSTree $ take 200 $ sortBy cmpNd $ concat successors
-- | Weight recorded in a statement graph.
fDist :: StmtGraph -> Int
fDist graph = case graph of
    SGraph weight _ -> weight
-- | Successors of a search node.
--
-- Three families of successors are produced, in this order:
--
-- 1. one statement taken from each graph, for every pair of available
--    statements of the same kind;
-- 2. one (non-return) statement taken from the first graph only;
-- 3. one (non-return) statement taken from the second graph only.
--
-- In cases 2 and 3, when the statement is an assignment, a dummy operation
-- built by 'mkDummyOp' is placed in the other block, its cost is added to the
-- node and its variables are recorded with 'storeTmpVar'.
nextNodes :: CaoMonad m => SNode -> m [SNode]
nextNodes sn = do
    leftOnly  <- mapM takeLeft  (withoutRets candidates1)
    rightOnly <- mapM takeRight (withoutRets candidates2)
    return $ paired ++ leftOnly ++ rightOnly
  where
    candidates1 = anyStmt (rest1 sn)
    candidates2 = anyStmt (rest2 sn)

    paired =
        [ sn { stmt1 = s1 : stmt1 sn
             , stmt2 = s2 : stmt2 sn
             , rest1 = g1'
             , rest2 = g2'
             }
        | ((s1, g1'), (s2, g2')) <- combinations candidates1 candidates2
        ]

    withoutRets = filter (not . isRet . fst)

    -- Cost and (at most one) dummy statement needed to balance @s@.
    dummyFor s
        | needsDummy s = do
            (n, vs, s') <- mkDummyOp s
            F.mapM_ storeTmpVar vs
            return (n, [s'])
        | otherwise = return (0, [])

    takeLeft (s, g) = do
        (extra, dummy) <- dummyFor s
        return $ sn { cost  = cost sn + extra
                    , stmt1 = s : stmt1 sn
                    , stmt2 = dummy ++ stmt2 sn
                    , rest1 = g
                    }

    takeRight (s, g) = do
        (extra, dummy) <- dummyFor s
        return $ sn { cost  = cost sn + extra
                    , stmt2 = s : stmt2 sn
                    , stmt1 = dummy ++ stmt1 sn
                    , rest2 = g
                    }

    -- TODO: Refactor
    isRet stmt = case stmt of
        L _ (Ret _) -> True
        _           -> False

    needsDummy stmt = case stmt of
        L _ (Assign _ _) -> True
        _                -> False
-- | All pairs made of one alternative from each list whose statements are of
-- the same kind ('sameKind').
combinations :: [(LStmt Var, StmtGraph)] -> [(LStmt Var, StmtGraph)]
             -> [((LStmt Var, StmtGraph), (LStmt Var, StmtGraph))]
combinations l1 l2 = concatMap matchesFor l1
  where
    matchesFor left@(s1, _) = [ (left, right) | right@(s2, _) <- l2, sameKind s1 s2 ]
--------------------------------------------------------------------------------
-- ** Dependency graphs
--------------------------------------------------------------------------------
-- | Location of a statement: the key under which it is stored in a
-- 'StmtGraph' ('mkStmtGraph' numbers the statements of a block from 1).
type LOC = Int
-- | Weight of a statement graph, computed from the 'stmtCost' of its statements.
type Weight = Int
-- | Statement dependency graph: a weight together with a map from the
-- location of each statement to the statement itself and the list of
-- locations it depends on.
--
-- For the block
--
-- a := b;
-- b := c;
-- r := s;
-- z := b + r;
--
-- the map is
--
-- 1 -> (a := b, [])
-- 2 -> (b := c, [1])
-- 3 -> (r := s, [])
-- 4 -> (z := b + r, [2,3])
data StmtGraph = SGraph Weight (M.Map LOC (LStmt Var, [LOC]))
-- | Prints one line per location, in the form @location -> entry@.
instance PP StmtGraph where
    ppr (SGraph _ m) = vsep [ ppr loc <+> text "->" <+> ppr entry
                            | (loc, entry) <- M.assocs m ]
-- | Does the dependency graph contain no statements?
nullG :: StmtGraph -> Bool
nullG graph = case graph of
    SGraph _ stmts -> M.null stmts
{-
Not used but useful in the future.
-- | emptyGraph
emptyGraph :: StmtGraph
emptyGraph = SGraph 0 M.empty
-}
-- | Build the dependency graph of a basic block. Statements are numbered from
-- 1 in the order in which they appear in the block.
mkStmtGraph :: BasicBlock -> StmtGraph
mkStmtGraph ss = SGraph total $! deps
  where
    (total, deps) = calculateDeps M.empty M.empty (zip [1 ..] ss)
-- | Compute the total weight and the dependency map of a list of numbered
-- statements.
--
-- The first map records, for each variable, the location of the latest
-- statement listing it in 'getLVars'; the second does the same for 'getVars'.
-- A statement depends on the locations recorded in the first map for the
-- variables it lists in 'getVars', and on those recorded in the second map for
-- the variables it lists in 'getLVars'.
calculateDeps :: M.Map Var LOC -> M.Map Var LOC -> [(LOC, LStmt Var)]
              -> (Weight, M.Map LOC (LStmt Var, [LOC]))
calculateDeps _ _ [] = (0, M.empty)
calculateDeps lvars vars ((loc, stmt) : rest) =
    ( restWeight + stmtCost stmt
    , restGraph `seq` M.insert loc (stmt, nub (onLVars ++ onVars)) restGraph
    )
  where
    targets = getLVars stmt
    used    = getVars stmt
    -- remember that the variable was last seen at this statement
    seenHere m v = M.insert v loc m
    lvars' = foldl' seenHere lvars targets
    vars'  = foldl' seenHere vars used
    onLVars = catMaybes [ M.lookup v lvars | v <- used ]
    onVars  = catMaybes [ M.lookup v vars | v <- targets ]
    (restWeight, restGraph) = calculateDeps lvars' vars' rest
{-
Not used but useful in the future
takeBlock :: StmtGraph -> (BasicBlock, StmtGraph)
takeBlock (SGraph w a) = ng `seq` (stmts, SGraph w' ng)
where noDeps = M.filter (null . snd) a
stmts = map fst $ M.elems noDeps
locs = M.keys noDeps
(w',ng) = M.foldWithKey fAdjDeps (w,a) a
fAdjDeps :: LOC -> (LStmt Var, [LOC])
-> (Weight, M.Map LOC (LStmt Var, [LOC]))
-> (Weight, M.Map LOC (LStmt Var, [LOC]))
fAdjDeps k (stmt, deps) (wgt, mp)
| k `elem` locs = (wgt - stmtCost stmt, mp `seq` M.delete k mp)
| otherwise = (wgt, mp `seq` M.insert k (stmt, deps \\ locs) mp)
-}
-- | All the ways of taking one statement out of the graph.
--
-- Only statements without pending dependencies can be taken. Each alternative
-- pairs the statement with the graph that remains: the statement is removed,
-- its location is dropped from the dependencies of every other statement and
-- its cost is subtracted from the weight.
anyStmt :: StmtGraph -> [(LStmt Var, StmtGraph)]
anyStmt (SGraph w a) = map takeOut ready
  where
    ready = M.assocs $ M.filter (null . snd) a

    takeOut :: (LOC, (LStmt Var, [LOC])) -> (LStmt Var, StmtGraph)
    takeOut (k, (s, _)) =
        (s, SGraph (w - stmtCost s) $! M.map (dropDep k) (M.delete k a))

    -- 'M.map' over the map without @k@ replaces the former 'M.foldWithKey'
    -- (deprecated and no longer provided by recent containers), which rebuilt
    -- the map by re-inserting every key.
    dropDep :: LOC -> (LStmt Var, [LOC]) -> (LStmt Var, [LOC])
    dropDep gone (stmt, deps) = (stmt, filter (/= gone) deps)
{-
Not used but useful in the future
-- | Traverse StmtGraph
toStmtList :: StmtGraph -> [LStmt Var]
toStmtList g
| nullG g = []
| otherwise = s' ++ toStmtList g'
where (s', g') = takeBlock g
stmtsOf :: StmtGraph -> [LStmt Var]
stmtsOf (SGraph _ a) = map fst $ M.elems a
-}
--------------------------------------------------------------------------------
-- ** Operations
--------------------------------------------------------------------------------
---- | Compare two statement blocks.
----
---- The result is an integer whose value denotes the cost of introducing the
---- necessary dummy ops to turn both blocks indistinguishable
--compareBlocks :: BasicBlock -> BasicBlock -> Int
--compareBlocks = undefined
-- | Create a dummy operation mirroring an assignment or a function call
-- statement.
--
-- The result holds the summed cost reported by 'mkDummyLExpr' for the
-- expressions, the set of variables collected while building the dummy
-- l-values and expressions, and the dummy statement itself.
--
-- Any other kind of statement is not supported and aborts with a message
-- showing the offending statement.
mkDummyOp :: CaoMonad m => LStmt Var -> m (Int, Set Var, LStmt Var)
mkDummyOp (unLoc -> Assign lvs es) = do
    (vs' , lvs')     <- unzip  <$> mapM mkDummyLv lvs
    (ns, vs'', es')  <- unzip3 <$> mapM mkDummyLExpr es
    return (sum ns, Set.unions $ vs' ++ vs'', genLoc $ Assign lvs' es')
mkDummyOp (unLoc -> FCallS fn es) = do
    (ns, vs, es') <- unzip3 <$> mapM mkDummyLExpr es
    return (sum ns, Set.unions vs, genLoc $ FCallS fn es')
mkDummyOp s
    = error $ "Language.CAO.Transformation.Indist.mkDummyOp: failed to create \
              \a dummy operation for this kind of statement: " ++ showPpr s
-- mkDummyOp (Ret es) = Ret <$> mapM mkDummyLExpr es
-- mkDummyOp (Ite i t me) =
-- mkDummyOp (Seq (SeqIter id) [LStmt id]
-- mkDummyOp (While e1 ss)
-- mkDummyOp (VDecl vd)
-- | Create a dummy l-value with the same shape as the given one: the
-- underlying variable is replaced by a fresh local variable of the same type,
-- which is also returned (as a singleton set).
mkDummyLv :: CaoMonad m => LVal Var -> m (Set Var, LVal Var)
mkDummyLv (LVVar (L _ v)) = do
    v' <- freshVar Local (varType v)
    return (Set.singleton v', LVVar (genLoc v'))
mkDummyLv (LVStruct lv n) = do
    (vs, lv') <- mkDummyLv lv
    return (vs, LVStruct lv' n)
mkDummyLv (LVCont t lv p) = do
    (vs, lv') <- mkDummyLv lv
    return (vs, LVCont t lv' p)
-- | Located version of 'mkDummyExpr': the dummy expression keeps the location
-- of the original one.
mkDummyLExpr :: CaoMonad m => TLExpr Var -> m (Int, Set Var, TLExpr Var)
mkDummyLExpr (L l e) = do
    (n, vs, e') <- mkDummyExpr e
    return (n, vs, L l e')
-- | Apply a function to the second component of a monadic pair.
fixT2 :: CaoMonad m => (a -> b) -> m (c, a) -> m (c, b)
fixT2 f m = do
    (a, b) <- m
    return (a, f b)
-- | Apply a function to the third component of a monadic triple.
fixT3 :: CaoMonad m => (a -> b) -> m (r, s, a) -> m (r, s, b)
fixT3 f m = do
    (a, b, c) <- m
    return (a, b, f c)
-- | Create a dummy expression: every variable is replaced by a fresh local
-- variable of the same type. The result holds the cost of the expression
-- (only arithmetic binary operations have a cost, given by 'costAOp'), the
-- free variables of the dummy expression and the dummy expression itself.
--
-- TODO: complete with other exprs, fix cost of ops
mkDummyExpr :: CaoMonad m => TExpr Var -> m (Int, Set Var, TExpr Var)
mkDummyExpr expr = case expr of
    TyE t inner@(BinaryOp (ArithOp op) _ _) -> do
        inner' <- T.mapM (freshVar Local . varType) inner
        return (costAOp op, fvs inner', TyE t inner')
    _ -> do
        expr' <- T.mapM (freshVar Local . varType) expr
        return (0, fvs expr', expr') -- TODO: Complete!!!
{-
Not used but useful in the future
-- | BasicBlock cost
blockCost :: BasicBlock -> Int
blockCost = sum . map stmtCost
-}
-- | Cost of a statement: the summed cost of the expressions of an assignment
-- or of a function call statement, and 0 for any other statement.
stmtCost :: LStmt Var -> Int
stmtCost stmt = case unLoc stmt of
    Assign _ es -> exprsCost es
    FCallS _ es -> exprsCost es
    _           -> 0
  where
    exprsCost es = sum (map costLExpr es)
-- | Cost of a located, typed expression (see 'costExpr').
costLExpr :: TLExpr Var -> Int
costLExpr located = case located of
    L _ (TyE _ e) -> costExpr e
-- | Cost of an expression: the cost of the operator for an arithmetic binary
-- operation, 0 otherwise.
--
-- TODO: complete with other exprs, fix cost of ops
costExpr :: Expr Var -> Int
costExpr e = case e of
    BinaryOp (ArithOp op) _ _ -> costAOp op
    _                         -> 0
-- | Cost assigned to each arithmetic operator.
costAOp :: AOp -> Int
costAOp op = case op of
    Plus  -> 1
    Minus -> 1
    Times -> 10
    Div   -> 10
    ModOp -> 10
    Power -> 100
-- TODO: create dependency funcs. Place statements with no dependencies. Check
-- all possible reorderings with the cost of the necessary dummy instructions
-- and pick the lowest. Remove dependencies from graph and continue.