{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# OPTIONS_GHC -Wall -Werror #-}
module Data.SBV.Compilers.CodeGen (
SBVCodeGen(..), cgSym
, cgInput, cgInputArr
, cgOutput, cgOutputArr
, cgReturn, cgReturnArr
, svCgInput, svCgInputArr
, svCgOutput, svCgOutputArr
, svCgReturn, svCgReturnArr
, cgPerformRTCs, cgSetDriverValues
, cgAddPrototype, cgAddDecl, cgAddLDFlags, cgIgnoreSAssert, cgOverwriteFiles, cgShowU8UsingHex
, cgIntegerSize, cgSRealType, CgSRealType(..)
, CgTarget(..), CgConfig(..), CgState(..), CgPgmBundle(..), CgPgmKind(..), CgVal(..)
, defaultCgConfig, initCgState, isCgDriver, isCgMakefile
, cgGenerateDriver, cgGenerateMakefile, codeGen, renderCgPgmBundle
) where
import Control.Monad (filterM, replicateM, unless)
import Control.Monad.Trans (MonadIO(liftIO), lift)
import Control.Monad.State.Lazy (MonadState, StateT(..), modify')
import Data.Char (toLower, isSpace)
import Data.List (nub, isPrefixOf, intercalate, (\\))
import System.Directory (createDirectoryIfMissing, doesDirectoryExist, doesFileExist)
import System.FilePath ((</>))
import System.IO (hFlush, stdout)
import Text.PrettyPrint.HughesPJ (Doc, vcat)
import qualified Text.PrettyPrint.HughesPJ as P (render)
import Data.SBV.Core.Data
import Data.SBV.Core.Symbolic (MonadSymbolic(..), svToSymSV, svMkSymVar, outputSVal)
#if MIN_VERSION_base(4,11,0)
import Control.Monad.Fail as Fail
#endif
-- | A code-generation backend. 'codeGen' calls 'translate' to turn the
-- recorded code-generation state into a bundle of files.
class CgTarget a where
  -- | A name for the target.
  targetName :: a -> String
  -- | Produce the program bundle. 'codeGen' passes the final configuration,
  -- the requested function name, the accumulated 'CgState' (inputs and
  -- outputs already in declaration order) and the symbolic 'Result'.
  translate :: a -> CgConfig -> String -> CgState -> Result -> CgPgmBundle
-- | Options that steer code generation. Each field has a dedicated setter
-- in 'SBVCodeGen' that updates the 'cgFinalConfig' of the running state.
data CgConfig = CgConfig {
    cgRTC                :: Bool               -- ^ Set by 'cgPerformRTCs'; presumably toggles run-time checks in the generated code (TODO confirm in backend).
  , cgInteger            :: Maybe Int          -- ^ Set by 'cgIntegerSize', which only accepts 8, 16, 32 or 64.
  , cgReal               :: Maybe CgSRealType  -- ^ Set by 'cgSRealType'.
  , cgDriverVals         :: [Integer]          -- ^ Set by 'cgSetDriverValues'; presumably values fed to the generated driver.
  , cgGenDriver          :: Bool               -- ^ Set by 'cgGenerateDriver'.
  , cgGenMakefile        :: Bool               -- ^ Set by 'cgGenerateMakefile'.
  , cgIgnoreAsserts      :: Bool               -- ^ Set by 'cgIgnoreSAssert'.
  , cgOverwriteGenerated :: Bool               -- ^ Set by 'cgOverwriteFiles'; when True, 'renderCgPgmBundle' overwrites existing files without prompting and prints no progress messages.
  , cgShowU8InHex        :: Bool               -- ^ Set by 'cgShowU8UsingHex'.
  }
-- | The configuration used when none is supplied. Driver and makefile
-- generation are on. Every other flag is off, no integer size or real type
-- is fixed, and there are no driver values.
defaultCgConfig :: CgConfig
defaultCgConfig =
  CgConfig
    { cgGenDriver          = True
    , cgGenMakefile        = True
    , cgRTC                = False
    , cgIgnoreAsserts      = False
    , cgOverwriteGenerated = False
    , cgShowU8InHex        = False
    , cgInteger            = Nothing
    , cgReal               = Nothing
    , cgDriverVals         = []
    }
-- | A value as recorded by the code generator.
data CgVal = CgAtomic SV   -- ^ A single symbolic value (from the scalar @cg*@ functions).
           | CgArray [SV]  -- ^ A list of symbolic values (from the @*Arr@ functions).
-- | The state threaded through 'SBVCodeGen'.
data CgState = CgState {
    cgInputs      :: [(String, CgVal)]  -- ^ Named inputs, consed most-recent-first; 'codeGen' reverses them.
  , cgOutputs     :: [(String, CgVal)]  -- ^ Named outputs, consed most-recent-first; 'codeGen' reverses them.
  , cgReturns     :: [CgVal]            -- ^ Return values, consed most-recent-first ('codeGen' does not reverse these).
  , cgPrototypes  :: [String]           -- ^ Lines added by 'cgAddPrototype'; successive calls are separated by an empty entry.
  , cgDecls       :: [String]           -- ^ Lines added by 'cgAddDecl', in call order.
  , cgLDFlags     :: [String]           -- ^ Flags added by 'cgAddLDFlags', in call order.
  , cgFinalConfig :: CgConfig           -- ^ Configuration as adjusted by the @cg*@ setters; 'codeGen' seeds it with its config argument.
  }
-- | An empty code-generation state carrying 'defaultCgConfig'.
initCgState :: CgState
initCgState =
  CgState
    { cgFinalConfig = defaultCgConfig
    , cgInputs      = []
    , cgOutputs     = []
    , cgReturns     = []
    , cgPrototypes  = []
    , cgDecls       = []
    , cgLDFlags     = []
    }
-- | The code-generation monad: a 'StateT' over 'Symbolic' carrying a
-- 'CgState'. Its instances are all newtype-derived from the underlying
-- 'StateT' stack.
newtype SBVCodeGen a = SBVCodeGen (StateT CgState Symbolic a)
  deriving ( Applicative, Functor, Monad, MonadIO, MonadState CgState
           , MonadSymbolic
#if MIN_VERSION_base(4,11,0)
           , Fail.MonadFail
#endif
           )
-- | Run a 'Symbolic' computation inside 'SBVCodeGen', leaving the
-- code-generation state untouched.
cgSym :: Symbolic a -> SBVCodeGen a
cgSym act = SBVCodeGen (lift act)
-- | Set the 'cgRTC' flag of the final configuration. Presumably this
-- switches run-time checks on or off (TODO confirm in the backend).
cgPerformRTCs :: Bool -> SBVCodeGen ()
cgPerformRTCs onOff = modify' (onConfig (\c -> c { cgRTC = onOff }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Fix the bit-size recorded in 'cgInteger'. Only 8, 16, 32 and 64 are
-- accepted; any other value makes the action fail with 'error'.
cgIntegerSize :: Int -> SBVCodeGen ()
cgIntegerSize i
  | i `elem` [8, 16, 32, 64] = modify' (onConfig (\c -> c { cgInteger = Just i }))
  | otherwise                = error $ "SBV.cgIntegerSize: Argument must be one of 8, 16, 32, or 64. Received: " ++ show i
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | The choice of type recorded in 'cgReal' (via 'cgSRealType').
data CgSRealType
  = CgFloat       -- ^ shown as @float@
  | CgDouble      -- ^ shown as @double@
  | CgLongDouble  -- ^ shown as @long double@
  deriving (Eq)

-- | Shows the type name as it should be spelled in generated code.
instance Show CgSRealType where
  show rt = case rt of
    CgFloat      -> "float"
    CgDouble     -> "double"
    CgLongDouble -> "long double"
-- | Record the chosen 'CgSRealType' in 'cgReal'.
cgSRealType :: CgSRealType -> SBVCodeGen ()
cgSRealType rt = modify' (onConfig (\c -> c { cgReal = Just rt }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Set 'cgGenDriver' (True by default). Presumably this controls whether
-- the backend emits a driver program.
cgGenerateDriver :: Bool -> SBVCodeGen ()
cgGenerateDriver onOff = modify' (onConfig (\c -> c { cgGenDriver = onOff }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Set 'cgGenMakefile' (True by default). Presumably this controls whether
-- the backend emits a makefile.
cgGenerateMakefile :: Bool -> SBVCodeGen ()
cgGenerateMakefile onOff = modify' (onConfig (\c -> c { cgGenMakefile = onOff }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Replace (not append to) 'cgDriverVals'. Presumably these are the
-- values the generated driver feeds to the function.
cgSetDriverValues :: [Integer] -> SBVCodeGen ()
cgSetDriverValues vs = modify' (onConfig (\c -> c { cgDriverVals = vs }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Set 'cgIgnoreAsserts'. Presumably this makes the backend skip
-- assertions when translating (TODO confirm in backend).
cgIgnoreSAssert :: Bool -> SBVCodeGen ()
cgIgnoreSAssert onOff = modify' (onConfig (\c -> c { cgIgnoreAsserts = onOff }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Append lines to 'cgPrototypes'. When earlier prototypes exist, an empty
-- string is inserted first so that successive groups stay separated.
cgAddPrototype :: [String] -> SBVCodeGen ()
cgAddPrototype ss = modify' addTo
  where
    addTo s = s { cgPrototypes = joinWith (cgPrototypes s) }
    joinWith []  = ss
    joinWith old = old ++ ("" : ss)
-- | Set 'cgOverwriteGenerated'. When True, 'renderCgPgmBundle' overwrites
-- existing files without asking and does not print progress messages.
cgOverwriteFiles :: Bool -> SBVCodeGen ()
cgOverwriteFiles onOff = modify' (onConfig (\c -> c { cgOverwriteGenerated = onOff }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Set 'cgShowU8InHex'. Presumably this selects hexadecimal rendering of
-- 8-bit unsigned values in generated code (TODO confirm in backend).
cgShowU8UsingHex :: Bool -> SBVCodeGen ()
cgShowU8UsingHex onOff = modify' (onConfig (\c -> c { cgShowU8InHex = onOff }))
  where onConfig g s = s { cgFinalConfig = g (cgFinalConfig s) }
-- | Append lines to 'cgDecls', preserving call order.
cgAddDecl :: [String] -> SBVCodeGen ()
cgAddDecl ss = modify' appendDecls
  where appendDecls s = s { cgDecls = cgDecls s ++ ss }
-- | Append flags to 'cgLDFlags', preserving call order. Presumably these
-- end up on the link line of the generated makefile.
cgAddLDFlags :: [String] -> SBVCodeGen ()
cgAddLDFlags ss = modify' appendFlags
  where appendFlags s = s { cgLDFlags = cgLDFlags s ++ ss }
-- | Create a fresh universally-quantified ('ALL') symbolic variable of the
-- given 'Kind', record it as the named scalar input and return it.
svCgInput :: Kind -> String -> SBVCodeGen SVal
svCgInput k nm = do
    st <- symbolicEnv
    r  <- liftIO (svMkSymVar (Just ALL) k Nothing st)
    sv <- svToSymSV r
    r <$ modify' (recordInput sv)
  where recordInput sv s = s { cgInputs = (nm, CgAtomic sv) : cgInputs s }
-- | Create @sz@ fresh universally-quantified symbolic variables of the given
-- 'Kind', record them as the named array input and return them.
-- Fails with 'error' when @sz < 1@.
svCgInputArr :: Kind -> Int -> String -> SBVCodeGen [SVal]
svCgInputArr k sz nm
  | sz >= 1 = do
      st  <- symbolicEnv
      rs  <- liftIO (replicateM sz (svMkSymVar (Just ALL) k Nothing st))
      sws <- mapM svToSymSV rs
      rs <$ modify' (recordInput sws)
  | otherwise = error $ "SBV.cgInputArr: Array inputs must have at least one element, given " ++ show sz ++ " for " ++ show nm
  where recordInput sws s = s { cgInputs = (nm, CgArray sws) : cgInputs s }
-- | Mark the value as an output of the symbolic computation and record it
-- as the named scalar output.
svCgOutput :: String -> SVal -> SBVCodeGen ()
svCgOutput nm v = outputSVal v >> svToSymSV v >>= modify' . recordOutput
  where recordOutput sv s = s { cgOutputs = (nm, CgAtomic sv) : cgOutputs s }
-- | Mark every value as an output and record them as the named array
-- output. Fails with 'error' on an empty list.
svCgOutputArr :: String -> [SVal] -> SBVCodeGen ()
svCgOutputArr nm vs
  | sz >= 1   = mapM_ outputSVal vs >> mapM svToSymSV vs >>= modify' . recordOutput
  | otherwise = error $ "SBV.cgOutputArr: Array outputs must have at least one element, received " ++ show sz ++ " for " ++ show nm
  where
    sz = length vs
    recordOutput sws s = s { cgOutputs = (nm, CgArray sws) : cgOutputs s }
-- | Mark the value as an output and record it as a scalar return value.
svCgReturn :: SVal -> SBVCodeGen ()
svCgReturn v = outputSVal v >> svToSymSV v >>= modify' . recordReturn
  where recordReturn sv s = s { cgReturns = CgAtomic sv : cgReturns s }
-- | Mark every value as an output and record them as an array return value.
-- Fails with 'error' on an empty list.
svCgReturnArr :: [SVal] -> SBVCodeGen ()
svCgReturnArr vs
  | sz >= 1   = mapM_ outputSVal vs >> mapM svToSymSV vs >>= modify' . recordReturn
  | otherwise = error $ "SBV.cgReturnArr: Array returns must have at least one element, received " ++ show sz
  where
    sz = length vs
    recordReturn sws s = s { cgReturns = CgArray sws : cgReturns s }
-- | Create a fresh 'forall_' variable, record it as the named scalar input
-- and return it.
cgInput :: SymVal a => String -> SBVCodeGen (SBV a)
cgInput nm = forall_ >>= \r -> do
    sv <- sbvToSymSV r
    r <$ modify' (recordInput sv)
  where recordInput sv s = s { cgInputs = (nm, CgAtomic sv) : cgInputs s }
-- | Create @sz@ fresh 'forall_' variables, record them as the named array
-- input and return them. Fails with 'error' when @sz < 1@.
cgInputArr :: SymVal a => Int -> String -> SBVCodeGen [SBV a]
cgInputArr sz nm
  | sz >= 1 = do
      rs  <- replicateM sz forall_
      sws <- mapM sbvToSymSV rs
      rs <$ modify' (recordInput sws)
  | otherwise = error $ "SBV.cgInputArr: Array inputs must have at least one element, given " ++ show sz ++ " for " ++ show nm
  where recordInput sws s = s { cgInputs = (nm, CgArray sws) : cgInputs s }
-- | Mark the value as an output and record it as the named scalar output.
cgOutput :: String -> SBV a -> SBVCodeGen ()
cgOutput nm v = output v >> sbvToSymSV v >>= modify' . recordOutput
  where recordOutput sv s = s { cgOutputs = (nm, CgAtomic sv) : cgOutputs s }
-- | Mark every value as an output and record them as the named array
-- output. Fails with 'error' on an empty list.
cgOutputArr :: SymVal a => String -> [SBV a] -> SBVCodeGen ()
cgOutputArr nm vs
  | sz >= 1   = mapM_ output vs >> mapM sbvToSymSV vs >>= modify' . recordOutput
  | otherwise = error $ "SBV.cgOutputArr: Array outputs must have at least one element, received " ++ show sz ++ " for " ++ show nm
  where
    sz = length vs
    recordOutput sws s = s { cgOutputs = (nm, CgArray sws) : cgOutputs s }
-- | Mark the value as an output and record it as a scalar return value.
cgReturn :: SBV a -> SBVCodeGen ()
cgReturn v = output v >> sbvToSymSV v >>= modify' . recordReturn
  where recordReturn sv s = s { cgReturns = CgAtomic sv : cgReturns s }
-- | Mark every value as an output and record them as an array return value.
-- Fails with 'error' on an empty list.
cgReturnArr :: SymVal a => [SBV a] -> SBVCodeGen ()
cgReturnArr vs
  | sz >= 1   = mapM_ output vs >> mapM sbvToSymSV vs >>= modify' . recordReturn
  | otherwise = error $ "SBV.cgReturnArr: Array returns must have at least one element, received " ++ show sz
  where
    sz = length vs
    recordReturn sws s = s { cgReturns = CgArray sws : cgReturns s }
-- | The output of a 'CgTarget'. The first component is a pair of an
-- optional integer size and real type (presumably the effective
-- 'cgInteger' / 'cgReal' settings; it is ignored by the 'Show' instance
-- and by 'renderCgPgmBundle'). The second is the list of generated files,
-- each given as (path, (kind, contents)). 'renderCgPgmBundle' joins the
-- path onto the output directory.
data CgPgmBundle = CgPgmBundle (Maybe Int, Maybe CgSRealType) [(FilePath, (CgPgmKind, [Doc]))]
-- | The role of one generated file in a 'CgPgmBundle'.
data CgPgmKind = CgMakefile [String]  -- ^ A makefile; the payload is presumably the extra linker flags (cf. 'cgLDFlags') — TODO confirm in backend.
               | CgHeader [Doc]       -- ^ A header; the payload's meaning is set by the backend (presumably prototypes) — confirm.
               | CgSource             -- ^ A source file.
               | CgDriver             -- ^ A driver program (tested by 'isCgDriver').
-- | True exactly for 'CgDriver'.
isCgDriver :: CgPgmKind -> Bool
isCgDriver k = case k of
  CgDriver -> True
  _        -> False
-- | True exactly for 'CgMakefile', whatever its payload.
isCgMakefile :: CgPgmKind -> Bool
isCgMakefile k = case k of
  CgMakefile _ -> True
  _            -> False
-- | Shows every file in the bundle between BEGIN/END banners, with files
-- separated by newlines. The size/real-type component is not shown.
instance Show CgPgmBundle where
  show (CgPgmBundle _ fs) = intercalate "\n" (map showFile fs)
    where
      showFile :: (FilePath, (CgPgmKind, [Doc])) -> String
      showFile (f, (_, ds)) = concat
        [ "== BEGIN: ", show f, " ================\n"
        , render' (vcat ds)
        , "== END: ", show f, " =================="
        ]
-- | Run a code-generation computation and translate it with the given
-- target.
--
-- The computation starts from 'initCgState' with @cgConfig@ as its
-- configuration. Recorded inputs and outputs are restored to declaration
-- order before being handed to 'translate'. If any name is used more than
-- once across inputs and outputs, the call fails with 'error'.
--
-- Returns the computation's result, the final configuration (after any
-- @cg*@ setters ran) and the translated program bundle.
codeGen :: CgTarget l => l -> CgConfig -> String -> SBVCodeGen a -> IO (a, CgConfig, CgPgmBundle)
codeGen l cgConfig nm (SBVCodeGen comp) = do
    ((retVal, st'), res) <- runSymbolic CodeGen $ runStateT comp initCgState { cgFinalConfig = cgConfig }
    -- Inputs and outputs were consed on as they were declared; undo that.
    let st = st' { cgInputs  = reverse (cgInputs st')
                 , cgOutputs = reverse (cgOutputs st')
                 }
        allNamedVars = map fst (cgInputs st ++ cgOutputs st)
        -- A name used k times survives k-1 times in the difference; nub
        -- it so each duplicated name is reported only once.
        dupNames = nub (allNamedVars \\ nub allNamedVars)
    unless (null dupNames) $
        error $ "SBV.codeGen: " ++ show nm ++ " has following argument names duplicated: " ++ unwords dupNames
    return (retVal, cgFinalConfig st, translate l (cgFinalConfig st) nm st res)
-- | Emit a program bundle.
--
-- With 'Nothing', the bundle is printed to stdout via its 'Show' instance.
-- With @Just dir@, the directory is created if needed and each file is
-- written under it. If some target files already exist and
-- 'cgOverwriteGenerated' is off, the user is asked to confirm before
-- anything is written. When 'cgOverwriteGenerated' is on, no prompt is
-- shown and the progress messages are suppressed.
renderCgPgmBundle :: Maybe FilePath -> (CgConfig, CgPgmBundle) -> IO ()
renderCgPgmBundle Nothing (_, bundle) = print bundle
renderCgPgmBundle (Just dirName) (cfg, CgPgmBundle _ files) = do
    dirExists <- doesDirectoryExist dirName
    unless dirExists $ do
        chatter $ "Creating directory " ++ show dirName ++ ".."
        createDirectoryIfMissing True dirName
    dups <- filterM (doesFileExist . (dirName </>)) (map fst files)
    goOn <- if overWrite || null dups then return True else confirm dups
    if goOn
       then mapM_ renderFile files >> chatter "Done."
       else putStrLn "Aborting."
  where
    overWrite = cgOverwriteGenerated cfg

    -- Progress output, silenced when overwriting unconditionally.
    chatter = unless overWrite . putStrLn

    -- List the files that would be clobbered and ask the user.
    confirm fns = do
        let noun = case fns of
                     [_] -> "file:"
                     _   -> "files:"
        putStrLn $ "Code generation would overwrite the following " ++ noun
        mapM_ (putStrLn . ('\t' :)) fns
        putStr "Continue? [yn] "
        hFlush stdout
        resp <- getLine
        -- NOTE(review): any prefix of "yes" is accepted, including the empty
        -- reply, so just pressing Enter proceeds with the overwrite.
        return (map toLower resp `isPrefixOf` "yes")

    renderFile (f, (_, ds)) = do
        let fn = dirName </> f
        chatter $ "Generating: " ++ show fn ++ ".."
        writeFile fn (render' (vcat ds))
-- | Render a document, replacing every whitespace-only line with an empty
-- line. The output ends with a newline (it is built with 'unlines').
render' :: Doc -> String
render' d = unlines [ if all isSpace x then "" else x | x <- lines (P.render d) ]