{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Data.SBV.Compilers.CodeGen (
SBVCodeGen(..), cgSym
, cgInput, cgInputArr
, cgOutput, cgOutputArr
, cgReturn, cgReturnArr
, svCgInput, svCgInputArr
, svCgOutput, svCgOutputArr
, svCgReturn, svCgReturnArr
, cgPerformRTCs, cgSetDriverValues
, cgAddPrototype, cgAddDecl, cgAddLDFlags, cgIgnoreSAssert, cgOverwriteFiles
, cgIntegerSize, cgSRealType, CgSRealType(..)
, CgTarget(..), CgConfig(..), CgState(..), CgPgmBundle(..), CgPgmKind(..), CgVal(..)
, defaultCgConfig, initCgState, isCgDriver, isCgMakefile
, cgGenerateDriver, cgGenerateMakefile, codeGen, renderCgPgmBundle
) where
import Control.Monad (filterM, replicateM, unless)
import Control.Monad.Reader (ask)
import Control.Monad.Trans (MonadIO, lift, liftIO)
import Control.Monad.State.Lazy (MonadState, StateT(..), modify')
import Data.Char (toLower, isSpace)
import Data.List (nub, isPrefixOf, intercalate, (\\))
import System.Directory (createDirectoryIfMissing, doesDirectoryExist, doesFileExist)
import System.FilePath ((</>))
import System.IO (hFlush, stdout)
import Text.PrettyPrint.HughesPJ (Doc, vcat)
import qualified Text.PrettyPrint.HughesPJ as P (render)
import Data.SBV.Core.Data
import Data.SBV.Core.Symbolic (svToSymSW, svMkSymVar, outputSVal)
#if MIN_VERSION_base(4,11,0)
import Control.Monad.Fail as Fail
#endif
-- | A code-generation backend.
--
-- An instance provides a name for itself and a translation from the state
-- collected by an 'SBVCodeGen' computation into a bundle of output files.
class CgTarget a where
  -- | The backend's name.
  targetName :: a -> String
  -- | Produce the program bundle. The arguments are, in order:
  --
  -- * the backend itself;
  -- * the final configuration;
  -- * the name passed to 'codeGen';
  -- * the collected code-generation state;
  -- * the symbolic 'Result' of running the computation.
  translate :: a -> CgConfig -> String -> CgState -> Result -> CgPgmBundle
-- | Options controlling code generation. Start from 'defaultCgConfig' and
-- adjust with the @cg*@ setters inside an 'SBVCodeGen' computation.
data CgConfig = CgConfig
  { cgRTC                :: Bool
    -- ^ Set by 'cgPerformRTCs' (by name: whether run-time checks are emitted).
  , cgInteger            :: Maybe Int
    -- ^ Integer size chosen via 'cgIntegerSize' (8, 16, 32 or 64), if any.
  , cgReal               :: Maybe CgSRealType
    -- ^ C type used for reals, chosen via 'cgSRealType', if any.
  , cgDriverVals         :: [Integer]
    -- ^ Values set via 'cgSetDriverValues' for the generated driver.
  , cgGenDriver          :: Bool
    -- ^ Whether a driver program is generated ('cgGenerateDriver').
  , cgGenMakefile        :: Bool
    -- ^ Whether a Makefile is generated ('cgGenerateMakefile').
  , cgIgnoreAsserts      :: Bool
    -- ^ Set by 'cgIgnoreSAssert'.
  , cgOverwriteGenerated :: Bool
    -- ^ When set, 'renderCgPgmBundle' overwrites existing files without
    -- prompting and suppresses its progress messages.
  }
-- | The baseline configuration.
--
-- * Driver and Makefile generation are on.
-- * Run-time checks are off and assertions are not ignored.
-- * No integer size, real type or driver values are chosen.
-- * Existing files are not overwritten without a prompt.
defaultCgConfig :: CgConfig
defaultCgConfig = CgConfig
  { cgGenDriver          = True
  , cgGenMakefile        = True
  , cgRTC                = False
  , cgIgnoreAsserts      = False
  , cgOverwriteGenerated = False
  , cgInteger            = Nothing
  , cgReal               = Nothing
  , cgDriverVals         = []
  }
-- | A value recorded for code generation: either a single symbolic word or
-- an array of them.
data CgVal
  = CgAtomic SW    -- ^ A scalar value.
  | CgArray  [SW]  -- ^ An array value (the @*Arr@ combinators reject empty arrays).
-- | State threaded through an 'SBVCodeGen' computation.
data CgState = CgState
  { cgInputs      :: [(String, CgVal)]
    -- ^ Named inputs. Each new input is prepended; 'codeGen' reverses the list.
  , cgOutputs     :: [(String, CgVal)]
    -- ^ Named outputs. Each new output is prepended; 'codeGen' reverses the list.
  , cgReturns     :: [CgVal]
    -- ^ Return values, most recently added first.
  , cgPrototypes  :: [String]
    -- ^ Extra prototype lines added via 'cgAddPrototype'.
  , cgDecls       :: [String]
    -- ^ Extra declarations added via 'cgAddDecl'.
  , cgLDFlags     :: [String]
    -- ^ Extra flags added via 'cgAddLDFlags' (by name, linker flags).
  , cgFinalConfig :: CgConfig
    -- ^ Configuration as modified by the @cg*@ setters.
  }
-- | The empty starting state, holding 'defaultCgConfig'. 'codeGen' replaces
-- that configuration with the caller's before running the computation.
initCgState :: CgState
initCgState = CgState
  { cgFinalConfig = defaultCgConfig
  , cgInputs      = []
  , cgOutputs     = []
  , cgReturns     = []
  , cgPrototypes  = []
  , cgDecls       = []
  , cgLDFlags     = []
  }
-- | The code-generation monad: a 'StateT' layer carrying 'CgState' over the
-- 'Symbolic' monad. Instances are obtained by newtype deriving.
newtype SBVCodeGen a = SBVCodeGen (StateT CgState Symbolic a)
  deriving ( Functor
           , Applicative
           , Monad
           , MonadIO
           , MonadState CgState
#if MIN_VERSION_base(4,11,0)
           , Fail.MonadFail
#endif
           )
-- | Run a 'Symbolic' action inside 'SBVCodeGen' without touching the state.
cgSym :: Symbolic a -> SBVCodeGen a
cgSym action = SBVCodeGen (lift action)
-- | Obtain the 'SW' for a typed symbolic value, via 'sbvToSymSW'.
cgSBVToSW :: SBV a -> SBVCodeGen SW
cgSBVToSW v = cgSym (sbvToSymSW v)
-- | Set the 'cgRTC' flag of the final configuration.
cgPerformRTCs :: Bool -> SBVCodeGen ()
cgPerformRTCs flag = modify' setFlag
  where setFlag st = st { cgFinalConfig = (cgFinalConfig st) { cgRTC = flag } }
-- | Choose the integer size recorded in 'cgInteger'.
--
-- Only 8, 16, 32 and 64 are accepted; any other value raises 'error'.
cgIntegerSize :: Int -> SBVCodeGen ()
cgIntegerSize i
  | i `elem` supported = modify' setSize
  | otherwise          = error $ "SBV.cgIntegerSize: Argument must be one of 8, 16, 32, or 64. Received: " ++ show i
  where
    supported  = [8, 16, 32, 64]
    setSize st = st { cgFinalConfig = (cgFinalConfig st) { cgInteger = Just i } }
-- | C floating-point types that can represent reals. The 'Show' instance
-- renders the corresponding C type name.
data CgSRealType
  = CgFloat       -- ^ @float@
  | CgDouble      -- ^ @double@
  | CgLongDouble  -- ^ @long double@
  deriving (Eq)
-- | Renders each constructor as its C type spelling.
instance Show CgSRealType where
  show rt = case rt of
    CgFloat      -> "float"
    CgDouble     -> "double"
    CgLongDouble -> "long double"
-- | Choose the C type used for reals (stored in 'cgReal').
cgSRealType :: CgSRealType -> SBVCodeGen ()
cgSRealType realTy = modify' setReal
  where setReal st = st { cgFinalConfig = (cgFinalConfig st) { cgReal = Just realTy } }
-- | Turn generation of the driver program on or off ('cgGenDriver').
cgGenerateDriver :: Bool -> SBVCodeGen ()
cgGenerateDriver flag = modify' setFlag
  where setFlag st = st { cgFinalConfig = (cgFinalConfig st) { cgGenDriver = flag } }
-- | Turn generation of the Makefile on or off ('cgGenMakefile').
cgGenerateMakefile :: Bool -> SBVCodeGen ()
cgGenerateMakefile flag = modify' setFlag
  where setFlag st = st { cgFinalConfig = (cgFinalConfig st) { cgGenMakefile = flag } }
-- | Replace the driver values ('cgDriverVals') with the given list.
cgSetDriverValues :: [Integer] -> SBVCodeGen ()
cgSetDriverValues vals = modify' setVals
  where setVals st = st { cgFinalConfig = (cgFinalConfig st) { cgDriverVals = vals } }
-- | Set the 'cgIgnoreAsserts' flag of the final configuration.
cgIgnoreSAssert :: Bool -> SBVCodeGen ()
cgIgnoreSAssert flag = modify' setFlag
  where setFlag st = st { cgFinalConfig = (cgFinalConfig st) { cgIgnoreAsserts = flag } }
-- | Append prototype lines to 'cgPrototypes'. When earlier prototypes
-- exist, an empty line is inserted between the old and new groups.
cgAddPrototype :: [String] -> SBVCodeGen ()
cgAddPrototype ss = modify' addProtos
  where
    addProtos st = st { cgPrototypes = merge (cgPrototypes st) }
    merge []  = ss
    merge old = old ++ ("" : ss)
-- | Set 'cgOverwriteGenerated'. When it is on, 'renderCgPgmBundle' writes
-- files without asking and without printing progress messages.
cgOverwriteFiles :: Bool -> SBVCodeGen ()
cgOverwriteFiles flag = modify' setFlag
  where setFlag st = st { cgFinalConfig = (cgFinalConfig st) { cgOverwriteGenerated = flag } }
-- | Append extra declarations to 'cgDecls'.
cgAddDecl :: [String] -> SBVCodeGen ()
cgAddDecl extra = modify' addDecls
  where addDecls st = st { cgDecls = cgDecls st ++ extra }
-- | Append extra flags to 'cgLDFlags'.
cgAddLDFlags :: [String] -> SBVCodeGen ()
cgAddLDFlags extra = modify' addFlags
  where addFlags st = st { cgLDFlags = cgLDFlags st ++ extra }
-- | Create a named scalar input of kind @k@ (untyped 'SVal' interface).
--
-- The variable is created with 'svMkSymVar' (@Just ALL@), and its 'SW' is
-- prepended to 'cgInputs'.
svCgInput :: Kind -> String -> SBVCodeGen SVal
svCgInput k nm = do
  v  <- cgSym $ ask >>= \st -> liftIO (svMkSymVar (Just ALL) k Nothing st)
  sw <- cgSym $ svToSymSW v
  modify' $ \s -> s { cgInputs = (nm, CgAtomic sw) : cgInputs s }
  pure v
-- | Create a named array input of @sz@ elements of kind @k@ (untyped 'SVal'
-- interface). The element 'SW's are prepended to 'cgInputs' as a single
-- 'CgArray' entry.
--
-- Raises 'error' when @sz < 1@. The message names this function; it used to
-- name the typed 'cgInputArr', which pointed users at the wrong API.
svCgInputArr :: Kind -> Int -> String -> SBVCodeGen [SVal]
svCgInputArr k sz nm
  | sz < 1    = error $ "SBV.svCgInputArr: Array inputs must have at least one element, given " ++ show sz ++ " for " ++ show nm
  | otherwise = do
      rs  <- cgSym $ ask >>= liftIO . replicateM sz . svMkSymVar (Just ALL) k Nothing
      sws <- cgSym $ mapM svToSymSW rs
      modify' (\s -> s { cgInputs = (nm, CgArray sws) : cgInputs s })
      return rs
-- | Register @v@ as a named scalar output (untyped 'SVal' interface).
--
-- The value is passed to 'outputSVal', and its 'SW' is prepended to
-- 'cgOutputs'.
svCgOutput :: String -> SVal -> SBVCodeGen ()
svCgOutput nm v = do
  _  <- cgSym $ outputSVal v
  sw <- cgSym $ svToSymSW v
  modify' $ \st -> st { cgOutputs = (nm, CgAtomic sw) : cgOutputs st }
-- | Register a non-empty list as a named array output (untyped 'SVal'
-- interface). Every element is passed to 'outputSVal', and the element
-- 'SW's are prepended to 'cgOutputs' as one 'CgArray' entry.
--
-- Raises 'error' on an empty list. The message names this function rather
-- than the typed 'cgOutputArr'.
svCgOutputArr :: String -> [SVal] -> SBVCodeGen ()
svCgOutputArr nm vs
  | sz < 1    = error $ "SBV.svCgOutputArr: Array outputs must have at least one element, received " ++ show sz ++ " for " ++ show nm
  | otherwise = do
      -- The results of outputSVal are not needed, so mapM_ avoids building a list.
      cgSym (mapM_ outputSVal vs)
      sws <- cgSym (mapM svToSymSW vs)
      modify' (\s -> s { cgOutputs = (nm, CgArray sws) : cgOutputs s })
  where sz = length vs
-- | Register @v@ as a scalar return value (untyped 'SVal' interface).
--
-- The value is passed to 'outputSVal', and its 'SW' is prepended to
-- 'cgReturns'.
svCgReturn :: SVal -> SBVCodeGen ()
svCgReturn v = do
  _  <- cgSym $ outputSVal v
  sw <- cgSym $ svToSymSW v
  modify' $ \st -> st { cgReturns = CgAtomic sw : cgReturns st }
-- | Register a non-empty list as an array return value (untyped 'SVal'
-- interface). Every element is passed to 'outputSVal', and the element
-- 'SW's are prepended to 'cgReturns' as one 'CgArray' entry.
--
-- Raises 'error' on an empty list. The message names this function rather
-- than the typed 'cgReturnArr'.
svCgReturnArr :: [SVal] -> SBVCodeGen ()
svCgReturnArr vs
  | sz < 1    = error $ "SBV.svCgReturnArr: Array returns must have at least one element, received " ++ show sz
  | otherwise = do
      -- The results of outputSVal are not needed, so mapM_ avoids building a list.
      cgSym (mapM_ outputSVal vs)
      sws <- cgSym (mapM svToSymSW vs)
      modify' (\s -> s { cgReturns = CgArray sws : cgReturns s })
  where sz = length vs
-- | Create a named, typed scalar input via 'forall_'. Its 'SW' is prepended
-- to 'cgInputs'.
cgInput :: SymWord a => String -> SBVCodeGen (SBV a)
cgInput nm = do
  x  <- cgSym forall_
  sw <- cgSBVToSW x
  modify' $ \st -> st { cgInputs = (nm, CgAtomic sw) : cgInputs st }
  pure x
-- | Create a named, typed array input of @sz@ elements, each made by
-- 'forall_'. The element 'SW's are prepended to 'cgInputs' as a single
-- 'CgArray' entry. Raises 'error' when @sz < 1@.
cgInputArr :: SymWord a => Int -> String -> SBVCodeGen [SBV a]
cgInputArr sz nm
  | sz < 1    = error $ "SBV.cgInputArr: Array inputs must have at least one element, given " ++ show sz ++ " for " ++ show nm
  | otherwise = do
      -- replicateM states "sz copies of forall_" directly, as svCgInputArr does.
      rs  <- cgSym (replicateM sz forall_)
      sws <- mapM cgSBVToSW rs
      modify' (\s -> s { cgInputs = (nm, CgArray sws) : cgInputs s })
      return rs
-- | Register a typed value as a named scalar output. The value is passed to
-- 'output', and its 'SW' is prepended to 'cgOutputs'.
cgOutput :: String -> SBV a -> SBVCodeGen ()
cgOutput nm v = do
  _  <- cgSym $ output v
  sw <- cgSBVToSW v
  modify' $ \st -> st { cgOutputs = (nm, CgAtomic sw) : cgOutputs st }
-- | Register a non-empty list of typed values as a named array output.
-- Every element is passed to 'output', and the element 'SW's are prepended
-- to 'cgOutputs' as one 'CgArray' entry. Raises 'error' on an empty list.
cgOutputArr :: SymWord a => String -> [SBV a] -> SBVCodeGen ()
cgOutputArr nm vs
  | count < 1 = error $ "SBV.cgOutputArr: Array outputs must have at least one element, received " ++ show count ++ " for " ++ show nm
  | otherwise = do
      cgSym (mapM_ output vs)
      elems <- mapM cgSBVToSW vs
      modify' $ \st -> st { cgOutputs = (nm, CgArray elems) : cgOutputs st }
  where count = length vs
-- | Register a typed value as a scalar return value. The value is passed to
-- 'output', and its 'SW' is prepended to 'cgReturns'.
cgReturn :: SBV a -> SBVCodeGen ()
cgReturn v = do
  _  <- cgSym $ output v
  sw <- cgSBVToSW v
  modify' $ \st -> st { cgReturns = CgAtomic sw : cgReturns st }
-- | Register a non-empty list of typed values as an array return value.
-- Every element is passed to 'output', and the element 'SW's are prepended
-- to 'cgReturns' as one 'CgArray' entry. Raises 'error' on an empty list.
cgReturnArr :: SymWord a => [SBV a] -> SBVCodeGen ()
cgReturnArr vs
  | count < 1 = error $ "SBV.cgReturnArr: Array returns must have at least one element, received " ++ show count
  | otherwise = do
      cgSym (mapM_ output vs)
      elems <- mapM cgSBVToSW vs
      modify' $ \st -> st { cgReturns = CgArray elems : cgReturns st }
  where count = length vs
-- | The output of a backend: a header pair plus a list of files. Each file
-- has a relative path, a 'CgPgmKind' and the documents forming its contents.
--
-- NOTE(review): the first pair's types match 'cgInteger' and 'cgReal';
-- presumably it carries those settings. Confirm with the backends.
data CgPgmBundle
  = CgPgmBundle (Maybe Int, Maybe CgSRealType)
                [(FilePath, (CgPgmKind, [Doc]))]
-- | The role of a generated file.
data CgPgmKind
  = CgMakefile [String]
    -- ^ A Makefile. The payload is a list of strings (presumably extra
    -- flags; confirm in the backend).
  | CgHeader [Doc]
    -- ^ A header file. The payload is a list of documents (presumably
    -- prototypes; confirm in the backend).
  | CgSource
    -- ^ A source file.
  | CgDriver
    -- ^ The driver program.
-- | Is this the driver program?
isCgDriver :: CgPgmKind -> Bool
isCgDriver kind = case kind of
  CgDriver -> True
  _        -> False
-- | Is this a Makefile (whatever its payload)?
isCgMakefile :: CgPgmKind -> Bool
isCgMakefile kind = case kind of
  CgMakefile _ -> True
  _            -> False
-- | Renders every file between BEGIN/END banners, with files separated by
-- newlines. This is what 'renderCgPgmBundle' prints when no output
-- directory is given.
instance Show CgPgmBundle where
  show (CgPgmBundle _ fs) = intercalate "\n" (map showFile fs)
    where
      showFile :: (FilePath, (CgPgmKind, [Doc])) -> String
      showFile (f, (_, ds)) = concat
        [ "== BEGIN: ", show f, " ================\n"
        , render' (vcat ds)
        , "== END: ", show f, " =================="
        ]
-- | Run a code-generation computation and translate it with the given
-- backend.
--
-- The computation runs in 'CodeGen' mode, starting from 'initCgState' with
-- the supplied configuration. The @cg*@ combinators prepend inputs and
-- outputs, so both lists are reversed here to restore declaration order.
--
-- All input and output names must be distinct. Otherwise 'error' is raised,
-- listing each duplicated name once.
--
-- Returns the final configuration (possibly changed by the @cg*@ setters)
-- together with the backend's program bundle.
codeGen :: CgTarget l => l -> CgConfig -> String -> SBVCodeGen () -> IO (CgConfig, CgPgmBundle)
codeGen l cgConfig nm (SBVCodeGen comp) = do
  (((), st'), res) <- runSymbolic CodeGen $ runStateT comp initCgState { cgFinalConfig = cgConfig }
  let st = st' { cgInputs  = reverse (cgInputs st')
               , cgOutputs = reverse (cgOutputs st')
               }
      allNamedVars = map fst (cgInputs st ++ cgOutputs st)
      -- Removing one occurrence of each name leaves only the extra copies.
      -- A name used k times would appear k-1 times there, so nub it to
      -- report each offending name exactly once.
      dupNames = nub (allNamedVars \\ nub allNamedVars)
  unless (null dupNames) $
    error $ "SBV.codeGen: " ++ show nm ++ " has following argument names duplicated: " ++ unwords dupNames
  return (cgFinalConfig st, translate l (cgFinalConfig st) nm st res)
-- | Write a program bundle to disk, or print it to stdout.
--
-- With 'Nothing', the bundle is printed through its 'Show' instance.
--
-- With @Just dir@:
--
-- * @dir@ is created if it is missing.
-- * If any target file already exists and 'cgOverwriteGenerated' is off,
--   the user is asked for confirmation.
-- * Every file is then written, unless the user declined.
--
-- With 'cgOverwriteGenerated' on, no prompt is shown and progress messages
-- are suppressed.
--
-- Only a single word that is a prefix of "yes" (case-insensitive,
-- surrounding whitespace ignored) counts as consent. Previously an empty
-- reply also counted, because @"" `isPrefixOf` "yes"@, so pressing Enter
-- overwrote files.
renderCgPgmBundle :: Maybe FilePath -> (CgConfig, CgPgmBundle) -> IO ()
renderCgPgmBundle Nothing (_, bundle) = print bundle
renderCgPgmBundle (Just dirName) (cfg, CgPgmBundle _ files) = do
  dirExists <- doesDirectoryExist dirName
  unless dirExists $ do
    unless overWrite $ putStrLn $ "Creating directory " ++ show dirName ++ ".."
    createDirectoryIfMissing True dirName
  dups <- filterM (\fn -> doesFileExist (dirName </> fn)) (map fst files)
  goOn <- case (overWrite, dups) of
            (True, _) -> return True
            (_, [])   -> return True
            _         -> confirmOverwrite dups
  if goOn
    then do mapM_ renderFile files
            unless overWrite $ putStrLn "Done."
    else putStrLn "Aborting."
  where
    overWrite = cgOverwriteGenerated cfg

    -- Write one file's rendered documents under dirName.
    renderFile (f, (_, ds)) = do
      let fn = dirName </> f
      unless overWrite $ putStrLn $ "Generating: " ++ show fn ++ ".."
      writeFile fn (render' (vcat ds))

    -- List the files that would be clobbered and ask the user to confirm.
    confirmOverwrite existing = do
      putStrLn $ "Code generation would overwrite the following " ++ (if length existing == 1 then "file:" else "files:")
      mapM_ (\fn -> putStrLn ('\t' : fn)) existing
      putStr "Continue? [yn] "
      hFlush stdout
      resp <- getLine
      -- 'words' trims whitespace. An empty reply, or more than one word,
      -- is treated as "no".
      return $ case words (map toLower resp) of
                 [answer] -> answer `isPrefixOf` "yes"
                 _        -> False
-- | Render a document as text.
--
-- Any line consisting only of whitespace becomes empty, and every line
-- (including the last) ends with a newline.
render' :: Doc -> String
render' doc = unlines [ if all isSpace ln then "" else ln | ln <- lines (P.render doc) ]